Merge "Remove synchronized lock in MdnsServiceTypeClient" into main
diff --git a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
index c5f9631..ec63e41 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
@@ -119,7 +119,14 @@
       // the session and delegates writing. The corresponding handler will write
       // with the setting specified in the trace config.
       NetworkTraceHandler::Trace([&](NetworkTraceHandler::TraceContext ctx) {
-        ctx.GetDataSourceLocked()->Write(packets, ctx);
+        perfetto::LockedHandle<NetworkTraceHandler> handle =
+            ctx.GetDataSourceLocked();
+        // The underlying handle can be invalidated between when Trace starts
+        // and GetDataSourceLocked is called, but not while the LockedHandle
+        // exists and holds the lock. Check validity prior to use.
+        if (handle.valid()) {
+          handle->Write(packets, ctx);
+        }
       });
     });
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/AbstractSocketNetlink.java b/service-t/src/com/android/server/connectivity/mdns/AbstractSocketNetlinkMonitor.java
similarity index 95%
rename from service-t/src/com/android/server/connectivity/mdns/AbstractSocketNetlink.java
rename to service-t/src/com/android/server/connectivity/mdns/AbstractSocketNetlinkMonitor.java
index b792e46..bba3338 100644
--- a/service-t/src/com/android/server/connectivity/mdns/AbstractSocketNetlink.java
+++ b/service-t/src/com/android/server/connectivity/mdns/AbstractSocketNetlinkMonitor.java
@@ -19,7 +19,7 @@
 /**
  * The interface for netlink monitor.
  */
-public interface AbstractSocketNetlink {
+public interface AbstractSocketNetlinkMonitor {
 
     /**
      * Returns if the netlink monitor is supported or not. By default, it is not supported.
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index e963ab7..6925b49 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -82,7 +82,7 @@
     @NonNull private final Dependencies mDependencies;
     @NonNull private final NetworkCallback mNetworkCallback;
     @NonNull private final TetheringEventCallback mTetheringEventCallback;
-    @NonNull private final AbstractSocketNetlink mSocketNetlinkMonitor;
+    @NonNull private final AbstractSocketNetlinkMonitor mSocketNetlinkMonitor;
     @NonNull private final SharedLog mSharedLog;
     private final ArrayMap<Network, SocketInfo> mNetworkSockets = new ArrayMap<>();
     private final ArrayMap<String, SocketInfo> mTetherInterfaceSockets = new ArrayMap<>();
@@ -253,7 +253,8 @@
             return iface.getIndex();
         }
         /*** Creates a SocketNetlinkMonitor */
-        public AbstractSocketNetlink createSocketNetlinkMonitor(@NonNull final Handler handler,
+        public AbstractSocketNetlinkMonitor createSocketNetlinkMonitor(
+                @NonNull final Handler handler,
                 @NonNull final SharedLog log,
                 @NonNull final NetLinkMonitorCallBack cb) {
             return SocketNetLinkMonitorFactory.createNetLinkMonitor(handler, log, cb);
diff --git a/service-t/src/com/android/server/connectivity/mdns/SocketNetLinkMonitorFactory.java b/service-t/src/com/android/server/connectivity/mdns/SocketNetLinkMonitorFactory.java
index 6bc7941..77c8f9c 100644
--- a/service-t/src/com/android/server/connectivity/mdns/SocketNetLinkMonitorFactory.java
+++ b/service-t/src/com/android/server/connectivity/mdns/SocketNetLinkMonitorFactory.java
@@ -30,7 +30,7 @@
     /**
      * Creates a new netlink monitor.
      */
-    public static AbstractSocketNetlink createNetLinkMonitor(@NonNull final Handler handler,
+    public static AbstractSocketNetlinkMonitor createNetLinkMonitor(@NonNull final Handler handler,
             @NonNull SharedLog log, @NonNull MdnsSocketProvider.NetLinkMonitorCallBack cb) {
         return new SocketNetlinkMonitor(handler, log, cb);
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/internal/SocketNetlinkMonitor.java b/service-t/src/com/android/server/connectivity/mdns/internal/SocketNetlinkMonitor.java
index 451909c..c21c903 100644
--- a/service-t/src/com/android/server/connectivity/mdns/internal/SocketNetlinkMonitor.java
+++ b/service-t/src/com/android/server/connectivity/mdns/internal/SocketNetlinkMonitor.java
@@ -28,13 +28,13 @@
 import com.android.net.module.util.netlink.NetlinkMessage;
 import com.android.net.module.util.netlink.RtNetlinkAddressMessage;
 import com.android.net.module.util.netlink.StructIfaddrMsg;
-import com.android.server.connectivity.mdns.AbstractSocketNetlink;
+import com.android.server.connectivity.mdns.AbstractSocketNetlinkMonitor;
 import com.android.server.connectivity.mdns.MdnsSocketProvider;
 
 /**
  * The netlink monitor for MdnsSocketProvider.
  */
-public class SocketNetlinkMonitor extends NetlinkMonitor implements AbstractSocketNetlink {
+public class SocketNetlinkMonitor extends NetlinkMonitor implements AbstractSocketNetlinkMonitor {
 
     public static final String TAG = SocketNetlinkMonitor.class.getSimpleName();
 
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 368860e..d03cac6 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -495,8 +495,11 @@
         final AutomaticOnOffKeepalive autoKi;
         try {
             autoKi = target.withKeepaliveInfo(res.second);
-            // Close the duplicated fd.
-            target.close();
+            // Only automatic keepalives duplicate the fd.
+            if (target.mAutomaticOnOffState != STATE_ALWAYS_ON) {
+                // Close the duplicated fd.
+                target.close();
+            }
         } catch (InvalidSocketException e) {
             Log.wtf(TAG, "Fail to create AutomaticOnOffKeepalive", e);
             return;
diff --git a/service/src/com/android/server/connectivity/KeepaliveTracker.java b/service/src/com/android/server/connectivity/KeepaliveTracker.java
index 125c269..2ce186f 100644
--- a/service/src/com/android/server/connectivity/KeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveTracker.java
@@ -553,6 +553,8 @@
 
     private KeepaliveInfo handleUpdateKeepaliveForClat(KeepaliveInfo ki)
             throws InvalidSocketException, InvalidPacketException {
+        // Translation applies only to NAT-T keepalives.
+        if (ki.mType != KeepaliveInfo.TYPE_NATT) return ki;
         // Only try to translate address if the packet source address is the clat's source address.
         if (!ki.mPacket.getSrcAddress().equals(ki.getNai().getClatv4SrcAddress())) return ki;
 
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideConnOnActivityStartTest.java b/tests/cts/hostside/src/com/android/cts/net/HostsideConnOnActivityStartTest.java
index a7d5590..d112425 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideConnOnActivityStartTest.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideConnOnActivityStartTest.java
@@ -26,16 +26,12 @@
     private static final String TEST_CLASS = TEST_PKG + ".ConnOnActivityStartTest";
     @Before
     public void setUp() throws Exception {
-        super.setUp();
-
         uninstallPackage(TEST_APP2_PKG, false);
         installPackage(TEST_APP2_APK);
     }
 
     @After
     public void tearDown() throws Exception {
-        super.tearDown();
-
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkCallbackTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkCallbackTests.java
index 5d7ad62..d8e7a2c 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkCallbackTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkCallbackTests.java
@@ -23,14 +23,12 @@
 
     @Before
     public void setUp() throws Exception {
-        super.setUp();
         uninstallPackage(TEST_APP2_PKG, false);
         installPackage(TEST_APP2_APK);
     }
 
     @After
     public void tearDown() throws Exception {
-        super.tearDown();
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkPolicyManagerTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkPolicyManagerTests.java
index 40f5f59..3ddb88b 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkPolicyManagerTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkPolicyManagerTests.java
@@ -23,14 +23,12 @@
 public class HostsideNetworkPolicyManagerTests extends HostsideNetworkTestCase {
     @Before
     public void setUp() throws Exception {
-        super.setUp();
         uninstallPackage(TEST_APP2_PKG, false);
         installPackage(TEST_APP2_APK);
     }
 
     @After
     public void tearDown() throws Exception {
-        super.tearDown();
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java
index c896168..566d9da 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java
@@ -16,16 +16,21 @@
 
 package com.android.cts.net;
 
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.fail;
 
 import com.android.ddmlib.Log;
 import com.android.modules.utils.build.testing.DeviceSdkLevel;
 import com.android.tradefed.device.DeviceNotAvailableException;
+import com.android.tradefed.invoker.TestInformation;
+import com.android.tradefed.targetprep.BuildError;
 import com.android.tradefed.targetprep.TargetSetupError;
+import com.android.tradefed.targetprep.suite.SuiteApkInstaller;
 import com.android.tradefed.testtype.DeviceJUnit4ClassRunner;
+import com.android.tradefed.testtype.junit4.AfterClassWithInfo;
 import com.android.tradefed.testtype.junit4.BaseHostJUnit4Test;
-import com.android.tradefed.testtype.junit4.DeviceTestRunOptions;
+import com.android.tradefed.testtype.junit4.BeforeClassWithInfo;
 import com.android.tradefed.util.RunUtil;
 
 import org.junit.runner.RunWith;
@@ -40,34 +45,61 @@
     protected static final String TEST_APP2_PKG = "com.android.cts.net.hostside.app2";
     protected static final String TEST_APP2_APK = "CtsHostsideNetworkTestsApp2.apk";
 
-    protected void setUp() throws Exception {
-        DeviceSdkLevel deviceSdkLevel = new DeviceSdkLevel(getDevice());
-        String testApk = deviceSdkLevel.isDeviceAtLeastT() ? TEST_APK_NEXT
-                : TEST_APK;
+    @BeforeClassWithInfo
+    public static void setUpOnce(TestInformation testInfo) throws Exception {
+        DeviceSdkLevel deviceSdkLevel = new DeviceSdkLevel(testInfo.getDevice());
+        String testApk = deviceSdkLevel.isDeviceAtLeastT() ? TEST_APK_NEXT : TEST_APK;
 
-        uninstallPackage(TEST_PKG, false);
-        installPackage(testApk);
+        uninstallPackage(testInfo, TEST_PKG, false);
+        installPackage(testInfo, testApk);
     }
 
-    protected void tearDown() throws Exception {
-        uninstallPackage(TEST_PKG, true);
+    @AfterClassWithInfo
+    public static void tearDownOnce(TestInformation testInfo) throws DeviceNotAvailableException {
+        uninstallPackage(testInfo, TEST_PKG, true);
+    }
+
+    // Custom static method to install the specified package. It is used to bypass the per-test
+    // auto-cleanup performed by BaseHostJUnit4Test.
+    protected static void installPackage(TestInformation testInfo, String apk)
+            throws DeviceNotAvailableException, TargetSetupError {
+        assertNotNull(testInfo);
+        final int userId = testInfo.getDevice().getCurrentUser();
+        final SuiteApkInstaller installer = new SuiteApkInstaller();
+        // Force the APK to be cleaned up.
+        installer.setCleanApk(true);
+        installer.addTestFileName(apk);
+        installer.setUserId(userId);
+        installer.setShouldGrantPermission(true);
+        installer.addInstallArg("-t");
+        try {
+            installer.setUp(testInfo);
+        } catch (BuildError e) {
+            throw new TargetSetupError(
+                    e.getMessage(), e, testInfo.getDevice().getDeviceDescriptor(), e.getErrorId());
+        }
     }
 
     protected void installPackage(String apk) throws DeviceNotAvailableException, TargetSetupError {
-        final DeviceTestRunOptions installOptions = new DeviceTestRunOptions(
-                null /* packageName */);
-        final int userId = getDevice().getCurrentUser();
-        installPackageAsUser(apk, true /* grantPermission */, userId, "-t");
+        installPackage(getTestInformation(), apk);
     }
 
-    protected void uninstallPackage(String packageName, boolean shouldSucceed)
+    protected static void uninstallPackage(TestInformation testInfo, String packageName,
+            boolean shouldSucceed)
             throws DeviceNotAvailableException {
-        final String result = uninstallPackage(packageName);
+        assertNotNull(testInfo);
+        final String result = testInfo.getDevice().uninstallPackage(packageName);
         if (shouldSucceed) {
             assertNull("uninstallPackage(" + packageName + ") failed: " + result, result);
         }
     }
 
+    protected void uninstallPackage(String packageName,
+            boolean shouldSucceed)
+            throws DeviceNotAvailableException {
+        uninstallPackage(getTestInformation(), packageName, shouldSucceed);
+    }
+
     protected void assertPackageUninstalled(String packageName) throws DeviceNotAvailableException,
             InterruptedException {
         final String command = "cmd package list packages " + packageName;
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java
index 0977deb..57b26bd 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java
@@ -32,16 +32,12 @@
 
     @Before
     public void setUp() throws Exception {
-        super.setUp();
-
         uninstallPackage(TEST_APP2_PKG, false);
         installPackage(TEST_APP2_APK);
     }
 
     @After
     public void tearDown() throws Exception {
-        super.tearDown();
-
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
index 242fd5d..691ac90 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
@@ -26,16 +26,12 @@
 
     @Before
     public void setUp() throws Exception {
-        super.setUp();
-
         uninstallPackage(TEST_APP2_PKG, false);
         installPackage(TEST_APP2_APK);
     }
 
     @After
     public void tearDown() throws Exception {
-        super.tearDown();
-
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
diff --git a/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java b/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
index 83b9b81..7bccbde 100644
--- a/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
@@ -82,9 +82,9 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.io.BufferedInputStream;
 import java.io.IOException;
 import java.io.InputStream;
-import java.io.InputStreamReader;
 import java.net.HttpURLConnection;
 import java.net.URL;
 import java.net.UnknownHostException;
@@ -220,7 +220,7 @@
         } else {
             Log.w(LOG_TAG, "Network: " + networkInfo.toString());
         }
-        InputStreamReader in = null;
+        BufferedInputStream in = null;
         HttpURLConnection urlc = null;
         String originalKeepAlive = System.getProperty("http.keepAlive");
         System.setProperty("http.keepAlive", "false");
@@ -236,10 +236,10 @@
             urlc.connect();
             boolean ping = urlc.getResponseCode() == 200;
             if (ping) {
-                in = new InputStreamReader((InputStream) urlc.getContent());
-                // Since the test doesn't really care about the precise amount of data, instead
-                // of reading all contents, just read few bytes at the beginning.
-                in.read();
+                in = new BufferedInputStream((InputStream) urlc.getContent());
+                while (in.read() != -1) {
+                    // Drain the response; comment avoids an empty-block lint error.
+                }
             }
         } catch (Exception e) {
             Log.i(LOG_TAG, "Badness during exercising remote server: " + e);
@@ -377,9 +377,14 @@
                 .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
                 .build(), callback);
         synchronized (this) {
-            try {
-                wait((int) (TIMEOUT_MILLIS * 2.4));
-            } catch (InterruptedException e) {
+            long now = System.currentTimeMillis();
+            final long deadline = (long) (now + TIMEOUT_MILLIS * 2.4);
+            while (!callback.success && now < deadline) {
+                try {
+                    wait(deadline - now);
+                } catch (InterruptedException e) {
+                }
+                now = System.currentTimeMillis();
             }
         }
         if (callback.success) {
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 6c411cf..49620b0 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -80,8 +80,6 @@
 import com.android.modules.utils.build.SdkLevel.isAtLeastU
 import com.android.net.module.util.ArrayTrackRecord
 import com.android.net.module.util.TrackRecord
-import com.android.networkstack.apishim.NsdShimImpl
-import com.android.networkstack.apishim.common.NsdShim
 import com.android.testutils.ConnectivityModuleTest
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
@@ -133,8 +131,6 @@
 private const val DBG = false
 private const val TEST_PORT = 12345
 
-private val nsdShim = NsdShimImpl.newInstance()
-
 @AppModeFull(reason = "Socket cannot bind in instant app mode")
 @RunWith(DevSdkIgnoreRunner::class)
 @SmallTest
@@ -293,7 +289,7 @@
             val serviceFound = expectCallbackEventually<ServiceFound> {
                 it.serviceInfo.serviceName == serviceName &&
                         (expectedNetwork == null ||
-                                expectedNetwork == nsdShim.getNetwork(it.serviceInfo))
+                                expectedNetwork == it.serviceInfo.network)
             }.serviceInfo
             // Discovered service types have a dot at the end
             assertEquals("$serviceType.", serviceFound.serviceType)
@@ -331,7 +327,7 @@
         }
     }
 
-    private class NsdServiceInfoCallbackRecord : NsdShim.ServiceInfoCallbackShim,
+    private class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback,
             NsdRecord<NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent>() {
         sealed class ServiceInfoCallbackEvent : NsdEvent {
             data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent()
@@ -361,11 +357,9 @@
     fun setUp() {
         handlerThread.start()
 
-        if (TestUtils.shouldTestTApis()) {
-            runAsShell(MANAGE_TEST_NETWORKS) {
-                testNetwork1 = createTestNetwork()
-                testNetwork2 = createTestNetwork()
-            }
+        runAsShell(MANAGE_TEST_NETWORKS) {
+            testNetwork1 = createTestNetwork()
+            testNetwork2 = createTestNetwork()
         }
     }
 
@@ -450,12 +444,10 @@
 
     @After
     fun tearDown() {
-        if (TestUtils.shouldTestTApis()) {
-            runAsShell(MANAGE_TEST_NETWORKS) {
-                // Avoid throwing here if initializing failed in setUp
-                if (this::testNetwork1.isInitialized) testNetwork1.close(cm)
-                if (this::testNetwork2.isInitialized) testNetwork2.close(cm)
-            }
+        runAsShell(MANAGE_TEST_NETWORKS) {
+            // Avoid throwing here if initializing failed in setUp
+            if (this::testNetwork1.isInitialized) testNetwork1.close(cm)
+            if (this::testNetwork2.isInitialized) testNetwork2.close(cm)
         }
         handlerThread.waitForIdle(TIMEOUT_MS)
         handlerThread.quitSafely()
@@ -601,9 +593,6 @@
 
     @Test
     fun testNsdManager_DiscoverOnNetwork() {
-        // This test requires shims supporting T+ APIs (discovering on specific network)
-        assumeTrue(TestUtils.shouldTestTApis())
-
         val si = NsdServiceInfo()
         si.serviceType = serviceType
         si.serviceName = this.serviceName
@@ -614,19 +603,19 @@
 
         tryTest {
             val discoveryRecord = NsdDiscoveryRecord()
-            nsdShim.discoverServices(nsdManager, serviceType, NsdManager.PROTOCOL_DNS_SD,
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
                     testNetwork1.network, Executor { it.run() }, discoveryRecord)
 
             val foundInfo = discoveryRecord.waitForServiceDiscovered(
                     serviceName, serviceType, testNetwork1.network)
-            assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo))
+            assertEquals(testNetwork1.network, foundInfo.network)
 
             // Rewind to ensure the service is not found on the other interface
             discoveryRecord.nextEvents.rewind(0)
             assertNull(discoveryRecord.nextEvents.poll(timeoutMs = 100L) {
                 it is ServiceFound &&
                         it.serviceInfo.serviceName == registeredInfo.serviceName &&
-                        nsdShim.getNetwork(it.serviceInfo) != testNetwork1.network
+                        it.serviceInfo.network != testNetwork1.network
             }, "The service should not be found on this network")
         } cleanup {
             nsdManager.unregisterService(registrationRecord)
@@ -635,9 +624,6 @@
 
     @Test
     fun testNsdManager_DiscoverWithNetworkRequest() {
-        // This test requires shims supporting T+ APIs (discovering on network request)
-        assumeTrue(TestUtils.shouldTestTApis())
-
         val si = NsdServiceInfo()
         si.serviceType = serviceType
         si.serviceName = this.serviceName
@@ -652,7 +638,7 @@
 
         tryTest {
             val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName)
-            nsdShim.discoverServices(nsdManager, serviceType, NsdManager.PROTOCOL_DNS_SD,
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
                     NetworkRequest.Builder()
                             .removeCapability(NET_CAPABILITY_TRUSTED)
                             .addTransportType(TRANSPORT_TEST)
@@ -667,27 +653,27 @@
             assertEquals(registeredInfo1.serviceName, serviceDiscovered.serviceInfo.serviceName)
             // Discovered service types have a dot at the end
             assertEquals("$serviceType.", serviceDiscovered.serviceInfo.serviceType)
-            assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceDiscovered.serviceInfo))
+            assertEquals(testNetwork1.network, serviceDiscovered.serviceInfo.network)
 
             // Unregister, then register the service back: it should be lost and found again
             nsdManager.unregisterService(registrationRecord)
             val serviceLost1 = discoveryRecord.expectCallback<ServiceLost>()
             assertEquals(registeredInfo1.serviceName, serviceLost1.serviceInfo.serviceName)
-            assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceLost1.serviceInfo))
+            assertEquals(testNetwork1.network, serviceLost1.serviceInfo.network)
 
             registrationRecord.expectCallback<ServiceUnregistered>()
             val registeredInfo2 = registerService(registrationRecord, si, executor)
             val serviceDiscovered2 = discoveryRecord.expectCallback<ServiceFound>()
             assertEquals(registeredInfo2.serviceName, serviceDiscovered2.serviceInfo.serviceName)
             assertEquals("$serviceType.", serviceDiscovered2.serviceInfo.serviceType)
-            assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceDiscovered2.serviceInfo))
+            assertEquals(testNetwork1.network, serviceDiscovered2.serviceInfo.network)
 
             // Teardown, then bring back up a network on the test interface: the service should
             // go away, then come back
             testNetwork1.agent.unregister()
             val serviceLost = discoveryRecord.expectCallback<ServiceLost>()
             assertEquals(registeredInfo2.serviceName, serviceLost.serviceInfo.serviceName)
-            assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceLost.serviceInfo))
+            assertEquals(testNetwork1.network, serviceLost.serviceInfo.network)
 
             val newAgent = runAsShell(MANAGE_TEST_NETWORKS) {
                 registerTestNetworkAgent(testNetwork1.iface.interfaceName)
@@ -696,7 +682,7 @@
             val serviceDiscovered3 = discoveryRecord.expectCallback<ServiceFound>()
             assertEquals(registeredInfo2.serviceName, serviceDiscovered3.serviceInfo.serviceName)
             assertEquals("$serviceType.", serviceDiscovered3.serviceInfo.serviceType)
-            assertEquals(newNetwork, nsdShim.getNetwork(serviceDiscovered3.serviceInfo))
+            assertEquals(newNetwork, serviceDiscovered3.serviceInfo.network)
         } cleanupStep {
             nsdManager.stopServiceDiscovery(discoveryRecord)
             discoveryRecord.expectCallback<DiscoveryStopped>()
@@ -707,9 +693,6 @@
 
     @Test
     fun testNsdManager_DiscoverWithNetworkRequest_NoMatchingNetwork() {
-        // This test requires shims supporting T+ APIs (discovering on network request)
-        assumeTrue(TestUtils.shouldTestTApis())
-
         val si = NsdServiceInfo()
         si.serviceType = serviceType
         si.serviceName = this.serviceName
@@ -722,7 +705,7 @@
         val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName)
 
         tryTest {
-            nsdShim.discoverServices(nsdManager, serviceType, NsdManager.PROTOCOL_DNS_SD,
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
                     NetworkRequest.Builder()
                             .removeCapability(NET_CAPABILITY_TRUSTED)
                             .addTransportType(TRANSPORT_TEST)
@@ -754,9 +737,6 @@
 
     @Test
     fun testNsdManager_ResolveOnNetwork() {
-        // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
-        assumeTrue(TestUtils.shouldTestTApis())
-
         val si = NsdServiceInfo()
         si.serviceType = serviceType
         si.serviceName = this.serviceName
@@ -772,21 +752,21 @@
 
             val foundInfo1 = discoveryRecord.waitForServiceDiscovered(
                     serviceName, serviceType, testNetwork1.network)
-            assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo1))
+            assertEquals(testNetwork1.network, foundInfo1.network)
             // Rewind as the service could be found on each interface in any order
             discoveryRecord.nextEvents.rewind(0)
             val foundInfo2 = discoveryRecord.waitForServiceDiscovered(
                     serviceName, serviceType, testNetwork2.network)
-            assertEquals(testNetwork2.network, nsdShim.getNetwork(foundInfo2))
+            assertEquals(testNetwork2.network, foundInfo2.network)
 
-            nsdShim.resolveService(nsdManager, foundInfo1, Executor { it.run() }, resolveRecord)
+            nsdManager.resolveService(foundInfo1, Executor { it.run() }, resolveRecord)
             val cb = resolveRecord.expectCallback<ServiceResolved>()
             cb.serviceInfo.let {
                 // Resolved service type has leading dot
                 assertEquals(".$serviceType", it.serviceType)
                 assertEquals(registeredInfo.serviceName, it.serviceName)
                 assertEquals(si.port, it.port)
-                assertEquals(testNetwork1.network, nsdShim.getNetwork(it))
+                assertEquals(testNetwork1.network, it.network)
                 checkAddressScopeId(testNetwork1.iface, it.hostAddresses)
             }
             // TODO: check that MDNS packets are sent only on testNetwork1.
@@ -799,9 +779,6 @@
 
     @Test
     fun testNsdManager_RegisterOnNetwork() {
-        // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
-        assumeTrue(TestUtils.shouldTestTApis())
-
         val si = NsdServiceInfo()
         si.serviceType = serviceType
         si.serviceName = this.serviceName
@@ -817,27 +794,27 @@
 
         tryTest {
             // Discover service on testNetwork1.
-            nsdShim.discoverServices(nsdManager, serviceType, NsdManager.PROTOCOL_DNS_SD,
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
                 testNetwork1.network, Executor { it.run() }, discoveryRecord)
             // Expect that service is found on testNetwork1
             val foundInfo = discoveryRecord.waitForServiceDiscovered(
                 serviceName, serviceType, testNetwork1.network)
-            assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo))
+            assertEquals(testNetwork1.network, foundInfo.network)
 
             // Discover service on testNetwork2.
-            nsdShim.discoverServices(nsdManager, serviceType, NsdManager.PROTOCOL_DNS_SD,
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
                 testNetwork2.network, Executor { it.run() }, discoveryRecord2)
             // Expect that discovery is started then no other callbacks.
             discoveryRecord2.expectCallback<DiscoveryStarted>()
             discoveryRecord2.assertNoCallback()
 
             // Discover service on all networks (not specify any network).
-            nsdShim.discoverServices(nsdManager, serviceType, NsdManager.PROTOCOL_DNS_SD,
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
                 null as Network? /* network */, Executor { it.run() }, discoveryRecord3)
             // Expect that service is found on testNetwork1
             val foundInfo3 = discoveryRecord3.waitForServiceDiscovered(
                     serviceName, serviceType, testNetwork1.network)
-            assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo3))
+            assertEquals(testNetwork1.network, foundInfo3.network)
         } cleanupStep {
             nsdManager.stopServiceDiscovery(discoveryRecord2)
             discoveryRecord2.expectCallback<DiscoveryStopped>()
@@ -970,9 +947,6 @@
 
     @Test
     fun testStopServiceResolution() {
-        // This test requires shims supporting U+ APIs (NsdManager.stopServiceResolution)
-        assumeTrue(TestUtils.shouldTestUApis())
-
         val si = NsdServiceInfo()
         si.serviceType = this@NsdManagerTest.serviceType
         si.serviceName = this@NsdManagerTest.serviceName
@@ -981,8 +955,8 @@
         val resolveRecord = NsdResolveRecord()
         // Try to resolve an unknown service then stop it immediately.
         // Expected ResolutionStopped callback.
-        nsdShim.resolveService(nsdManager, si, { it.run() }, resolveRecord)
-        nsdShim.stopServiceResolution(nsdManager, resolveRecord)
+        nsdManager.resolveService(si, { it.run() }, resolveRecord)
+        nsdManager.stopServiceResolution(resolveRecord)
         val stoppedCb = resolveRecord.expectCallback<ResolutionStopped>()
         assertEquals(si.serviceName, stoppedCb.serviceInfo.serviceName)
         assertEquals(si.serviceType, stoppedCb.serviceInfo.serviceType)
@@ -990,9 +964,6 @@
 
     @Test
     fun testRegisterServiceInfoCallback() {
-        // This test requires shims supporting U+ APIs (NsdManager.registerServiceInfoCallback)
-        assumeTrue(TestUtils.shouldTestUApis())
-
         val lp = cm.getLinkProperties(testNetwork1.network)
         assertNotNull(lp)
         val addresses = lp.addresses
@@ -1013,13 +984,13 @@
         val cbRecord = NsdServiceInfoCallbackRecord()
         tryTest {
             // Discover service on the network.
-            nsdShim.discoverServices(nsdManager, serviceType, NsdManager.PROTOCOL_DNS_SD,
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
                     testNetwork1.network, Executor { it.run() }, discoveryRecord)
             val foundInfo = discoveryRecord.waitForServiceDiscovered(
                     serviceName, serviceType, testNetwork1.network)
 
             // Register service callback and check the addresses are the same as network addresses
-            nsdShim.registerServiceInfoCallback(nsdManager, foundInfo, { it.run() }, cbRecord)
+            nsdManager.registerServiceInfoCallback(foundInfo, { it.run() }, cbRecord)
             val serviceInfoCb = cbRecord.expectCallback<ServiceUpdated>()
             assertEquals(foundInfo.serviceName, serviceInfoCb.serviceInfo.serviceName)
             val hostAddresses = serviceInfoCb.serviceInfo.hostAddresses
@@ -1035,7 +1006,7 @@
             cbRecord.expectCallback<ServiceUpdatedLost>()
         } cleanupStep {
             // Cancel subscription and check stop callback received.
-            nsdShim.unregisterServiceInfoCallback(nsdManager, cbRecord)
+            nsdManager.unregisterServiceInfoCallback(cbRecord)
             cbRecord.expectCallback<UnregisterCallbackSucceeded>()
         } cleanup {
             nsdManager.stopServiceDiscovery(discoveryRecord)
@@ -1045,9 +1016,6 @@
 
     @Test
     fun testStopServiceResolutionFailedCallback() {
-        // This test requires shims supporting U+ APIs (NsdManager.stopServiceResolution)
-        assumeTrue(TestUtils.shouldTestUApis())
-
         // It's not possible to make ResolutionListener#onStopResolutionFailed callback sending
         // because it is only sent in very edge-case scenarios when the legacy implementation is
         // used, and the legacy implementation is never used in the current AOSP builds. Considering
@@ -1115,7 +1083,7 @@
         si: NsdServiceInfo,
         executor: Executor = Executor { it.run() }
     ): NsdServiceInfo {
-        nsdShim.registerService(nsdManager, si, NsdManager.PROTOCOL_DNS_SD, executor, record)
+        nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, executor, record)
         // We may not always get the name that we tried to register;
         // This events tells us the name that was registered.
         val cb = record.expectCallback<ServiceRegistered>(REGISTRATION_TIMEOUT_MS)
@@ -1124,7 +1092,7 @@
 
     private fun resolveService(discoveredInfo: NsdServiceInfo): NsdServiceInfo {
         val record = NsdResolveRecord()
-        nsdShim.resolveService(nsdManager, discoveredInfo, Executor { it.run() }, record)
+        nsdManager.resolveService(discoveredInfo, Executor { it.run() }, record)
         val resolvedCb = record.expectCallback<ServiceResolved>()
         assertEquals(discoveredInfo.serviceName, resolvedCb.serviceInfo.serviceName)
 
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index eeffbe1..b30c9ce 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -612,17 +612,60 @@
         verifyNoMoreInteractions(ignoreStubs(testInfo.socketKeepaliveCallback));
     }
 
-    @Test
-    public void testStartNattKeepalive_addressTranslationOnClat() throws Exception {
-        final InetAddress v6AddrSrc = InetAddresses.parseNumericAddress("2001:db8::1");
-        final InetAddress v6AddrDst = InetAddresses.parseNumericAddress("2001:db8::2");
-        doReturn(v6AddrDst).when(mNai).translateV4toClatV6(any());
-        doReturn(v6AddrSrc).when(mNai).getClatv6SrcAddress();
+    private void setupTestNaiForClat(InetAddress v6Src, InetAddress v6Dst) throws Exception {
+        doReturn(v6Dst).when(mNai).translateV4toClatV6(any());
+        doReturn(v6Src).when(mNai).getClatv6SrcAddress();
         doReturn(InetAddress.getByAddress(V4_SRC_ADDR)).when(mNai).getClatv4SrcAddress();
         // Setup nai to add clat address
         final LinkProperties stacked = new LinkProperties();
         stacked.setInterfaceName(TEST_V4_IFACE);
+        final InetAddress srcAddress = InetAddress.getByAddress(
+                new byte[] { (byte) 192, 0, 0, (byte) 129 });
+        mNai.linkProperties.addLinkAddress(new LinkAddress(srcAddress, 24));
         mNai.linkProperties.addStackedLink(stacked);
+    }
+
+    private TestKeepaliveInfo doStartTcpKeepalive(InetAddress srcAddr) throws Exception {
+        final KeepalivePacketData kpd = new TcpKeepalivePacketData(
+                srcAddr,
+                12345 /* srcPort */,
+                InetAddress.getByAddress(new byte[] { 8, 8, 8, 8}) /* dstAddr */,
+                12345 /* dstPort */, new byte[] {1},  111 /* tcpSeq */,
+                222 /* tcpAck */, 800 /* tcpWindow */, 2 /* tcpWindowScale */,
+                4 /* ipTos */, 64 /* ipTtl */);
+        final TestKeepaliveInfo testInfo = new TestKeepaliveInfo(kpd);
+
+        final KeepaliveInfo ki = mKeepaliveTracker.new KeepaliveInfo(
+                testInfo.socketKeepaliveCallback, mNai, kpd,
+                TEST_KEEPALIVE_INTERVAL_SEC, KeepaliveInfo.TYPE_TCP, testInfo.fd);
+        mKeepaliveTracker.setReturnedKeepaliveInfo(ki);
+
+        // Setup TCP keepalive.
+        mAOOKeepaliveTracker.startTcpKeepalive(mNai, testInfo.fd, TEST_KEEPALIVE_INTERVAL_SEC,
+                testInfo.socketKeepaliveCallback);
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+        return testInfo;
+    }
+    @Test
+    public void testStartTcpKeepalive_addressTranslationOnClat() throws Exception {
+        setupTestNaiForClat(InetAddresses.parseNumericAddress("2001:db8::1") /* v6Src */,
+                InetAddresses.parseNumericAddress("2001:db8::2") /* v6Dst */);
+        final InetAddress srcAddr = InetAddress.getByAddress(V4_SRC_ADDR);
+        doStartTcpKeepalive(srcAddr);
+        final ArgumentCaptor<TcpKeepalivePacketData> tpdCaptor =
+                ArgumentCaptor.forClass(TcpKeepalivePacketData.class);
+        verify(mNai).onStartTcpSocketKeepalive(
+                eq(TEST_SLOT), eq(TEST_KEEPALIVE_INTERVAL_SEC), tpdCaptor.capture());
+        final TcpKeepalivePacketData tpd = tpdCaptor.getValue();
+        // Verify the addresses still be the same address when clat is started.
+        assertEquals(srcAddr, tpd.getSrcAddress());
+    }
+
+    @Test
+    public void testStartNattKeepalive_addressTranslationOnClat() throws Exception {
+        final InetAddress v6AddrSrc = InetAddresses.parseNumericAddress("2001:db8::1");
+        final InetAddress v6AddrDst = InetAddresses.parseNumericAddress("2001:db8::2");
+        setupTestNaiForClat(v6AddrSrc, v6AddrDst);
 
         final TestKeepaliveInfo testInfo = doStartNattKeepalive();
         final ArgumentCaptor<NattKeepalivePacketData> kpdCaptor =
@@ -899,24 +942,8 @@
                 new byte[] { (byte) 192, 0, 0, (byte) 129 });
         mNai.linkProperties.addLinkAddress(new LinkAddress(srcAddress, 24));
 
-        final KeepalivePacketData kpd = new TcpKeepalivePacketData(
-                InetAddress.getByAddress(new byte[] { (byte) 192, 0, 0, (byte) 129 }) /* srcAddr */,
-                12345 /* srcPort */,
-                InetAddress.getByAddress(new byte[] { 8, 8, 8, 8}) /* dstAddr */,
-                12345 /* dstPort */, new byte[] {1},  111 /* tcpSeq */,
-                222 /* tcpAck */, 800 /* tcpWindow */, 2 /* tcpWindowScale */,
-                4 /* ipTos */, 64 /* ipTtl */);
-        final TestKeepaliveInfo testInfo = new TestKeepaliveInfo(kpd);
-
-        final KeepaliveInfo ki = mKeepaliveTracker.new KeepaliveInfo(
-                testInfo.socketKeepaliveCallback, mNai, kpd,
-                TEST_KEEPALIVE_INTERVAL_SEC, KeepaliveInfo.TYPE_TCP, testInfo.fd);
-        mKeepaliveTracker.setReturnedKeepaliveInfo(ki);
-
-        // Setup TCP keepalive.
-        mAOOKeepaliveTracker.startTcpKeepalive(mNai, testInfo.fd, TEST_KEEPALIVE_INTERVAL_SEC,
-                testInfo.socketKeepaliveCallback);
-        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+        final TestKeepaliveInfo testInfo =
+                doStartTcpKeepalive(InetAddress.getByAddress(V4_SRC_ADDR));
 
         // A closed socket will result in EVENT_HANGUP and trigger error to
         // FileDescriptorEventListener.
@@ -924,6 +951,6 @@
         HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
 
         // The keepalive should be removed in AutomaticOnOffKeepaliveTracker.
-        getAutoKiForBinder(testInfo.binder);
+        assertNull(getAutoKiForBinder(testInfo.binder));
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/VpnTest.java b/tests/unit/java/com/android/server/connectivity/VpnTest.java
index 7829cb6..9ae727d 100644
--- a/tests/unit/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/unit/java/com/android/server/connectivity/VpnTest.java
@@ -1965,7 +1965,16 @@
 
         vpn.startVpnProfile(TEST_VPN_PKG);
         final NetworkCallback nwCb = triggerOnAvailableAndGetCallback(underlyingNetworkCaps);
-        verify(mExecutor, atLeastOnce()).schedule(any(Runnable.class), anyLong(), any());
+        // There are 4 interactions with the executor.
+        // - Network available
+        // - LP change
+        // - NC change
+        // - schedule() calls in scheduleStartIkeSession()
+        // The first 3 calls are triggered from Executor.execute(). The execute() will also call to
+        // schedule() with 0 delay. Verify the exact interaction here so that it won't cause flakes
+        // in the follow-up flow.
+        verify(mExecutor, timeout(TEST_TIMEOUT_MS).times(4))
+                .schedule(any(Runnable.class), anyLong(), any());
         reset(mExecutor);
 
         // Mock the setup procedure by firing callbacks