Merge "Add a method to help mocking registerNetworkAgent" into main
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index 5d99b74..3d7ea69 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -1242,6 +1242,22 @@
     @ConnectivityManagerFeature
     private Long mEnabledConnectivityManagerFeatures = null;
 
+    /**
+     * Static helpers for tests that mock {@link ConnectivityManager}.
+     * @hide
+     */
+    public static class MockHelpers {
+        /**
+         * Builds the result type of {@link ConnectivityManager#registerNetworkAgent} for stubs.
+         * Currently returns {@code network} unchanged; {@code registry} is ignored.
+         * @hide
+         */
+        public static Network registerNetworkAgentResult(
+                @Nullable final Network network, @Nullable final INetworkAgentRegistry registry) {
+            return network;
+        }
+    }
+
     private TetheringManager getTetheringManager() {
         synchronized (mTetheringEventCallbacks) {
             if (mTetheringManager == null) {
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index 8fcc703..5e035a2 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -172,6 +172,7 @@
 import org.junit.Test
 import org.junit.runner.RunWith
 import org.mockito.ArgumentMatchers.any
+import org.mockito.ArgumentMatchers.anyInt
 import org.mockito.ArgumentMatchers.argThat
 import org.mockito.ArgumentMatchers.eq
 import org.mockito.Mockito.doReturn
@@ -1066,7 +1067,20 @@
     fun testAgentStartsInConnecting() {
         val mockContext = mock(Context::class.java)
         val mockCm = mock(ConnectivityManager::class.java)
+        val mockedResult = ConnectivityManager.MockHelpers.registerNetworkAgentResult(
+            mock(Network::class.java),
+            mock(INetworkAgentRegistry::class.java)
+        )
         doReturn(mockCm).`when`(mockContext).getSystemService(Context.CONNECTIVITY_SERVICE)
+        doReturn(mockedResult).`when`(mockCm).registerNetworkAgent(
+            any(),
+            any(),
+            any(),
+            any(),
+            any(),
+            any(),
+            anyInt()
+        )
         val agent = createNetworkAgent(mockContext)
         agent.register()
         verify(mockCm).registerNetworkAgent(