Create TestRule: SetFeatureFlagsRule

This rule works by reading the annotation and dynamically
invoking a user-specified method before evaluation to handle the feature flag configuration.

Usage:
class MyTestClass {
  @get:Rule
  val setFeatureFlagsRule = SetFeatureFlagsRule(setFlagsMethod = (name, enabled) -> {
    // Custom handling code.
  })

  // ... test methods with @FeatureFlag annotations
  @FeatureFlag("FooBar1", true)
  @FeatureFlag("FooBar2", false)
  @Test
  fun testFooBar() {}
}

Test: atest ConnectivityCoverageTests:android.net.connectivity.com.android.server.net.NetworkStatsServiceTest
Bug: N/A
Change-Id: I439e8ee40b8e81b6eb3857925912a3b843e8dfa1
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/SetFeatureFlagsRule.kt b/staticlibs/testutils/devicetests/com/android/testutils/SetFeatureFlagsRule.kt
new file mode 100644
index 0000000..4185b05
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/SetFeatureFlagsRule.kt
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils.com.android.testutils
+
+import org.junit.rules.TestRule
+import org.junit.runner.Description
+import org.junit.runners.model.Statement
+
+/**
+ * A JUnit Rule that sets feature flags based on `@FeatureFlag` annotations.
+ *
+ * This rule enables dynamic control of feature flag states during testing.
+ *
+ * **Usage:**
+ * ```kotlin
+ * class MyTestClass {
+ *   @get:Rule
+ *   val setFeatureFlagsRule = SetFeatureFlagsRule(setFlagsMethod = (name, enabled) -> {
+ *     // Custom handling code.
+ *   })
+ *
+ *   // ... test methods with @FeatureFlag annotations
+ *   @FeatureFlag("FooBar1", true)
+ *   @FeatureFlag("FooBar2", false)
+ *   @Test
+ *   fun testFooBar() {}
+ * }
+ * ```
+ */
+class SetFeatureFlagsRule(val setFlagsMethod: (name: String, enabled: Boolean) -> Unit) : TestRule {
+    /**
+     * This annotation marks a test method as requiring a specific feature flag to be configured.
+     *
+     * Use this on test methods to dynamically control feature flag states during testing.
+     *
+     * @param name The name of the feature flag.
+     * @param enabled The desired state (true for enabled, false for disabled) of the feature flag.
+     */
+    @Target(AnnotationTarget.FUNCTION)
+    @Retention(AnnotationRetention.RUNTIME)
+    annotation class FeatureFlag(val name: String, val enabled: Boolean = true)
+
+    /**
+     * This method is the core of the rule, executed by the JUnit framework before each test method.
+     *
+     * It retrieves the test method's metadata.
+     * If any `@FeatureFlag` annotation is found, it passes every feature flag's name
+     * and enabled state into the user-specified lambda to apply custom actions.
+     */
+    override fun apply(base: Statement, description: Description): Statement {
+        return object : Statement() {
+            override fun evaluate() {
+                val testMethod = description.testClass.getMethod(description.methodName)
+                val featureFlagAnnotations = testMethod.getAnnotationsByType(
+                    FeatureFlag::class.java
+                )
+
+                for (featureFlagAnnotation in featureFlagAnnotations) {
+                    setFlagsMethod(featureFlagAnnotation.name, featureFlagAnnotation.enabled)
+                }
+
+                // Execute the test method, which includes methods annotated with
+                // @Before, @Test and @After.
+                base.evaluate()
+            }
+        }
+    }
+}
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
index 8efab46..91b7bf0 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -95,9 +95,6 @@
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
-import static java.lang.annotation.ElementType.METHOD;
-import static java.lang.annotation.RetentionPolicy.RUNTIME;
-
 import android.annotation.NonNull;
 import android.app.AlarmManager;
 import android.content.Context;
@@ -165,6 +162,8 @@
 import com.android.testutils.HandlerUtils;
 import com.android.testutils.TestBpfMap;
 import com.android.testutils.TestableNetworkStatsProviderBinder;
+import com.android.testutils.com.android.testutils.SetFeatureFlagsRule;
+import com.android.testutils.com.android.testutils.SetFeatureFlagsRule.FeatureFlag;
 
 import libcore.testing.io.TestIoUtils;
 
@@ -173,7 +172,6 @@
 import org.junit.Ignore;
 import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.TestName;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
@@ -183,9 +181,6 @@
 import java.io.FileDescriptor;
 import java.io.PrintWriter;
 import java.io.StringWriter;
-import java.lang.annotation.Retention;
-import java.lang.annotation.Target;
-import java.lang.reflect.Method;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.time.Clock;
@@ -214,8 +209,6 @@
 // NetworkStatsService is not updatable before T, so tests do not need to be backwards compatible
 @DevSdkIgnoreRule.IgnoreUpTo(SC_V2)
 public class NetworkStatsServiceTest extends NetworkStatsBaseTest {
-    @Rule
-    public final TestName mTestNameRule = new TestName();
 
     private static final String TAG = "NetworkStatsServiceTest";
 
@@ -312,6 +305,15 @@
     final TestDependencies mDeps = new TestDependencies();
     final HashMap<String, Boolean> mFeatureFlags = new HashMap<>();
 
+    // This will set feature flags from @FeatureFlag annotations
+    // into the map before setUp() runs.
+    @Rule
+    public final SetFeatureFlagsRule mSetFeatureFlagsRule =
+            new SetFeatureFlagsRule((name, enabled) -> {
+                mFeatureFlags.put(name, enabled);
+                return null;
+            });
+
     private class MockContext extends BroadcastInterceptingContext {
         private final Context mBaseContext;
 
@@ -369,33 +371,6 @@
         return parcel;
     }
 
-    // Tests can use this annotation to set feature flags before constructing
-    // NetworkStatsService, e.g. @FeatureFlag(FeatureName, true/false)).
-    // TODO: Refactor into a Rule, and put in a common place.
-    @Retention(RUNTIME)
-    @Target(METHOD)
-    public @interface FeatureFlag {
-        String name();
-
-        boolean enabled() default true;
-    }
-
-    private void initFeatureFlagsFromAnnotations() {
-        // Setup default and overrides feature flags before creating the service.
-        mFeatureFlags.put(TRAFFICSTATS_RATE_LIMIT_CACHE_ENABLED_FLAG, true);
-
-        final String testMethodName = mTestNameRule.getMethodName();
-        try {
-            final Method method = this.getClass().getMethod(testMethodName);
-            final FeatureFlag[] flags = method.getAnnotationsByType(FeatureFlag.class);
-            for (final FeatureFlag flag : flags) {
-                mFeatureFlags.put(flag.name(), flag.enabled());
-            }
-        } catch (NoSuchMethodException ignored) {
-            // This is expected for parameterized tests
-        }
-    }
-
     @Before
     public void setUp() throws Exception {
         MockitoAnnotations.initMocks(this);
@@ -419,9 +394,6 @@
         PowerManager.WakeLock wakeLock =
                 powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, TAG);
 
-        // This has to be invoked before initialize the service instance.
-        initFeatureFlagsFromAnnotations();
-
         mHandlerThread = new HandlerThread("NetworkStatsServiceTest-HandlerThread");
         // Create a separate thread for observers to run on. This thread cannot be the same
         // as the handler thread, because the observer callback is fired on this thread, and
@@ -2433,7 +2405,7 @@
         doTestTrafficStatsRateLimitCache(false /* cacheEnabled */);
     }
 
-    @FeatureFlag(name = TRAFFICSTATS_RATE_LIMIT_CACHE_ENABLED_FLAG, enabled = true)
+    @FeatureFlag(name = TRAFFICSTATS_RATE_LIMIT_CACHE_ENABLED_FLAG)
     @Test
     public void testTrafficStatsRateLimitCache_enabled() throws Exception {
         doTestTrafficStatsRateLimitCache(true /* cacheEnabled */);