Merge "Add CTS for DnsOptions API"
diff --git a/Cronet/tests/common/Android.bp b/Cronet/tests/common/Android.bp
index 1ed0881..f8bdb08 100644
--- a/Cronet/tests/common/Android.bp
+++ b/Cronet/tests/common/Android.bp
@@ -35,5 +35,5 @@
         "NetHttpTestsLibPreJarJar",
     ],
     jarjar_rules: ":framework-tethering-jarjar-rules",
-    compile_multilib: "both",
+    compile_multilib: "both", // Include both the 32 and 64 bit versions
 }
diff --git a/Cronet/tests/cts/Android.bp b/Cronet/tests/cts/Android.bp
index 9f92dba..d260694 100644
--- a/Cronet/tests/cts/Android.bp
+++ b/Cronet/tests/cts/Android.bp
@@ -75,7 +75,6 @@
         "CronetTestJavaDefaults",
     ],
     sdk_version: "test_current",
-    compile_multilib: "both", // Include both the 32 and 64 bit versions
     static_libs: ["CtsNetHttpTestsLib"],
     // Tag this as a cts test artifact
     test_suites: [
diff --git a/Cronet/tests/mts/Android.bp b/Cronet/tests/mts/Android.bp
index cfafc5b..1cabd63 100644
--- a/Cronet/tests/mts/Android.bp
+++ b/Cronet/tests/mts/Android.bp
@@ -37,7 +37,6 @@
         "CronetTestJavaDefaults",
         "mts-target-sdk-version-current",
      ],
-     compile_multilib: "both",
      sdk_version: "test_current",
      static_libs: ["NetHttpTestsLibPreJarJar"],
      jarjar_rules: ":framework-tethering-jarjar-rules",
diff --git a/Tethering/common/TetheringLib/cronet_enabled/api/current.txt b/Tethering/common/TetheringLib/cronet_enabled/api/current.txt
index c8bcb3b..777138d 100644
--- a/Tethering/common/TetheringLib/cronet_enabled/api/current.txt
+++ b/Tethering/common/TetheringLib/cronet_enabled/api/current.txt
@@ -1,11 +1,50 @@
 // Signature format: 2.0
 package android.net.http {
 
+  public abstract class BidirectionalStream {
+    ctor public BidirectionalStream();
+    method public abstract void cancel();
+    method public abstract void flush();
+    method public abstract boolean isDone();
+    method public abstract void read(java.nio.ByteBuffer);
+    method public abstract void start();
+    method public abstract void write(java.nio.ByteBuffer, boolean);
+  }
+
+  public abstract static class BidirectionalStream.Builder {
+    ctor public BidirectionalStream.Builder();
+    method public abstract android.net.http.BidirectionalStream.Builder addHeader(String, String);
+    method public abstract android.net.http.BidirectionalStream build();
+    method public abstract android.net.http.BidirectionalStream.Builder delayRequestHeadersUntilFirstFlush(boolean);
+    method public abstract android.net.http.BidirectionalStream.Builder setHttpMethod(String);
+    method public abstract android.net.http.BidirectionalStream.Builder setPriority(int);
+    method public abstract android.net.http.BidirectionalStream.Builder setTrafficStatsTag(int);
+    method public abstract android.net.http.BidirectionalStream.Builder setTrafficStatsUid(int);
+    field public static final int STREAM_PRIORITY_HIGHEST = 4; // 0x4
+    field public static final int STREAM_PRIORITY_IDLE = 0; // 0x0
+    field public static final int STREAM_PRIORITY_LOW = 2; // 0x2
+    field public static final int STREAM_PRIORITY_LOWEST = 1; // 0x1
+    field public static final int STREAM_PRIORITY_MEDIUM = 3; // 0x3
+  }
+
+  public abstract static class BidirectionalStream.Callback {
+    ctor public BidirectionalStream.Callback();
+    method public void onCanceled(android.net.http.BidirectionalStream, android.net.http.UrlResponseInfo);
+    method public abstract void onFailed(android.net.http.BidirectionalStream, android.net.http.UrlResponseInfo, android.net.http.HttpException);
+    method public abstract void onReadCompleted(android.net.http.BidirectionalStream, android.net.http.UrlResponseInfo, java.nio.ByteBuffer, boolean);
+    method public abstract void onResponseHeadersReceived(android.net.http.BidirectionalStream, android.net.http.UrlResponseInfo);
+    method public void onResponseTrailersReceived(android.net.http.BidirectionalStream, android.net.http.UrlResponseInfo, android.net.http.UrlResponseInfo.HeaderBlock);
+    method public abstract void onStreamReady(android.net.http.BidirectionalStream);
+    method public abstract void onSucceeded(android.net.http.BidirectionalStream, android.net.http.UrlResponseInfo);
+    method public abstract void onWriteCompleted(android.net.http.BidirectionalStream, android.net.http.UrlResponseInfo, java.nio.ByteBuffer, boolean);
+  }
+
   public abstract class CallbackException extends android.net.http.HttpException {
     ctor protected CallbackException(String, Throwable);
   }
 
   public class ConnectionMigrationOptions {
+    method @Nullable public Boolean getAllowNonDefaultNetworkUsage();
     method @Nullable public Boolean getEnableDefaultNetworkMigration();
     method @Nullable public Boolean getEnablePathDegradationMigration();
   }
@@ -13,26 +52,52 @@
   public static class ConnectionMigrationOptions.Builder {
     ctor public ConnectionMigrationOptions.Builder();
     method public android.net.http.ConnectionMigrationOptions build();
+    method public android.net.http.ConnectionMigrationOptions.Builder setAllowNonDefaultNetworkUsage(boolean);
     method public android.net.http.ConnectionMigrationOptions.Builder setEnableDefaultNetworkMigration(boolean);
     method public android.net.http.ConnectionMigrationOptions.Builder setEnablePathDegradationMigration(boolean);
   }
 
   public final class DnsOptions {
+    method @Nullable public Boolean getEnableStaleDns();
     method @Nullable public Boolean getPersistHostCache();
     method @Nullable public java.time.Duration getPersistHostCachePeriod();
+    method @Nullable public Boolean getPreestablishConnectionsToStaleDnsResults();
+    method @Nullable public android.net.http.DnsOptions.StaleDnsOptions getStaleDnsOptions();
+    method @Nullable public Boolean getUseHttpStackDnsResolver();
   }
 
   public static final class DnsOptions.Builder {
     ctor public DnsOptions.Builder();
     method public android.net.http.DnsOptions build();
+    method public android.net.http.DnsOptions.Builder setEnableStaleDns(boolean);
     method public android.net.http.DnsOptions.Builder setPersistHostCache(boolean);
     method public android.net.http.DnsOptions.Builder setPersistHostCachePeriod(java.time.Duration);
+    method public android.net.http.DnsOptions.Builder setPreestablishConnectionsToStaleDnsResults(boolean);
+    method public android.net.http.DnsOptions.Builder setStaleDnsOptions(android.net.http.DnsOptions.StaleDnsOptions);
+    method public android.net.http.DnsOptions.Builder setUseHttpStackDnsResolver(boolean);
+  }
+
+  public static class DnsOptions.StaleDnsOptions {
+    method @Nullable public Boolean getAllowCrossNetworkUsage();
+    method @Nullable public Long getFreshLookupTimeoutMillis();
+    method @Nullable public Long getMaxExpiredDelayMillis();
+    method @Nullable public Boolean getUseStaleOnNameNotResolved();
+  }
+
+  public static final class DnsOptions.StaleDnsOptions.Builder {
+    ctor public DnsOptions.StaleDnsOptions.Builder();
+    method public android.net.http.DnsOptions.StaleDnsOptions build();
+    method public android.net.http.DnsOptions.StaleDnsOptions.Builder setAllowCrossNetworkUsage(boolean);
+    method public android.net.http.DnsOptions.StaleDnsOptions.Builder setFreshLookupTimeout(java.time.Duration);
+    method public android.net.http.DnsOptions.StaleDnsOptions.Builder setMaxExpiredDelay(java.time.Duration);
+    method public android.net.http.DnsOptions.StaleDnsOptions.Builder setUseStaleOnNameNotResolved(boolean);
   }
 
   public abstract class HttpEngine {
     method public void bindToNetwork(@Nullable android.net.Network);
     method public abstract java.net.URLStreamHandlerFactory createURLStreamHandlerFactory();
     method public static String getVersionString();
+    method public abstract android.net.http.BidirectionalStream.Builder newBidirectionalStreamBuilder(String, android.net.http.BidirectionalStream.Callback, java.util.concurrent.Executor);
     method public abstract android.net.http.UrlRequest.Builder newUrlRequestBuilder(String, android.net.http.UrlRequest.Callback, java.util.concurrent.Executor);
     method public abstract java.net.URLConnection openConnection(java.net.URL) throws java.io.IOException;
     method public abstract void shutdown();
@@ -100,6 +165,7 @@
     method public android.net.http.QuicOptions.Builder addAllowedQuicHost(String);
     method public android.net.http.QuicOptions build();
     method public android.net.http.QuicOptions.Builder setHandshakeUserAgent(String);
+    method public android.net.http.QuicOptions.Builder setIdleConnectionTimeout(java.time.Duration);
     method public android.net.http.QuicOptions.Builder setInMemoryServerConfigsCacheSize(int);
   }
 
@@ -136,6 +202,8 @@
     method public abstract android.net.http.UrlRequest.Builder disableCache();
     method public abstract android.net.http.UrlRequest.Builder setHttpMethod(String);
     method public abstract android.net.http.UrlRequest.Builder setPriority(int);
+    method public abstract android.net.http.UrlRequest.Builder setTrafficStatsTag(int);
+    method public abstract android.net.http.UrlRequest.Builder setTrafficStatsUid(int);
     method public abstract android.net.http.UrlRequest.Builder setUploadDataProvider(android.net.http.UploadDataProvider, java.util.concurrent.Executor);
     field public static final int REQUEST_PRIORITY_HIGHEST = 4; // 0x4
     field public static final int REQUEST_PRIORITY_IDLE = 0; // 0x0
@@ -180,8 +248,7 @@
 
   public abstract class UrlResponseInfo {
     ctor public UrlResponseInfo();
-    method public abstract java.util.Map<java.lang.String,java.util.List<java.lang.String>> getAllHeaders();
-    method public abstract java.util.List<java.util.Map.Entry<java.lang.String,java.lang.String>> getAllHeadersAsList();
+    method public abstract android.net.http.UrlResponseInfo.HeaderBlock getHeaders();
     method public abstract int getHttpStatusCode();
     method public abstract String getHttpStatusText();
     method public abstract String getNegotiatedProtocol();
@@ -192,5 +259,11 @@
     method public abstract boolean wasCached();
   }
 
+  public abstract static class UrlResponseInfo.HeaderBlock {
+    ctor public UrlResponseInfo.HeaderBlock();
+    method public abstract java.util.List<java.util.Map.Entry<java.lang.String,java.lang.String>> getAsList();
+    method public abstract java.util.Map<java.lang.String,java.util.List<java.lang.String>> getAsMap();
+  }
+
 }
 
diff --git a/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java b/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java
index 60485f1..eed9aeb 100644
--- a/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java
+++ b/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java
@@ -469,9 +469,7 @@
             // TODO: Update this logic to only do a restart if required. Although a restart may
             //  be required due to the capabilities or ipConfiguration values, not all
             //  capabilities changes require a restart.
-            if (mIpClient != null) {
-                restart();
-            }
+            maybeRestart();
         }
 
         boolean isRestricted() {
@@ -549,7 +547,7 @@
                 // Send a callback in case a provisioning request was in progress.
                 return;
             }
-            restart();
+            maybeRestart();
         }
 
         private void ensureRunningOnEthernetHandlerThread() {
@@ -582,7 +580,7 @@
             // If there is a better network, that will become default and apps
             // will be able to use internet. If ethernet gets connected again,
             // and has backhaul connectivity, it will become default.
-            restart();
+            maybeRestart();
         }
 
         /** Returns true if state has been modified */
@@ -656,18 +654,16 @@
                         .build();
         }
 
-        void restart() {
-            if (DBG) Log.d(TAG, "restart IpClient");
-
+        void maybeRestart() {
             if (mIpClient == null) {
-                // If restart() is called from a provisioning failure, it is
+                // If maybeRestart() is called from a provisioning failure, it is
                 // possible that link disappeared in the meantime. In that
                 // case, stop() has already been called and IpClient should not
                 // get restarted to prevent a provisioning failure loop.
-                Log.i(TAG, String.format("restart() was called on stopped interface %s", name));
+                Log.i(TAG, String.format("maybeRestart() called on stopped interface %s", name));
                 return;
             }
-
+            if (DBG) Log.d(TAG, "restart IpClient");
             stop();
             start();
         }
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 760a5f3..cd7a842 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -19,7 +19,7 @@
 import static android.net.NetworkAgent.CMD_START_SOCKET_KEEPALIVE;
 import static android.net.SocketKeepalive.ERROR_INVALID_SOCKET;
 import static android.net.SocketKeepalive.SUCCESS_PAUSED;
-import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
+import static android.provider.DeviceConfig.NAMESPACE_TETHERING;
 import static android.system.OsConstants.AF_INET;
 import static android.system.OsConstants.AF_INET6;
 import static android.system.OsConstants.SOL_SOCKET;
@@ -206,7 +206,8 @@
             this.mKi = Objects.requireNonNull(ki);
             mCallback = ki.mCallback;
             mUnderpinnedNetwork = underpinnedNetwork;
-            if (autoOnOff && mDependencies.isFeatureEnabled(AUTOMATIC_ON_OFF_KEEPALIVE_VERSION)) {
+            if (autoOnOff && mDependencies.isFeatureEnabled(AUTOMATIC_ON_OFF_KEEPALIVE_VERSION,
+                    true /* defaultEnabled */)) {
                 mAutomaticOnOffState = STATE_ENABLED;
                 if (null == ki.mFd) {
                     throw new IllegalArgumentException("fd can't be null with automatic "
@@ -568,7 +569,8 @@
         // Clear calling identity to align the calling uid and package so that it won't fail if cts
         // would like to do the dump()
         final boolean featureEnabled = BinderUtils.withCleanCallingIdentity(
-                () -> mDependencies.isFeatureEnabled(AUTOMATIC_ON_OFF_KEEPALIVE_VERSION));
+                () -> mDependencies.isFeatureEnabled(AUTOMATIC_ON_OFF_KEEPALIVE_VERSION,
+                        true /* defaultEnabled */));
         pw.println("AutomaticOnOff enabled: " + featureEnabled);
         pw.increaseIndent();
         for (AutomaticOnOffKeepalive autoKi : mAutomaticOnOffKeepalives) {
@@ -798,10 +800,13 @@
          * Find out if a feature is enabled from DeviceConfig.
          *
          * @param name The name of the property to look up.
+         * @param defaultEnabled whether to consider the feature enabled in the absence of
+         *                       the flag. This MUST be a statically-known constant.
          * @return whether the feature is enabled
          */
-        public boolean isFeatureEnabled(@NonNull final String name) {
-            return DeviceConfigUtils.isFeatureEnabled(mContext, NAMESPACE_CONNECTIVITY, name);
+        public boolean isFeatureEnabled(@NonNull final String name, final boolean defaultEnabled) {
+            return DeviceConfigUtils.isFeatureEnabled(mContext, NAMESPACE_TETHERING, name,
+                    defaultEnabled);
         }
     }
 }
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index f1ab62d..a5bf000 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -95,7 +95,6 @@
 import android.system.Os;
 import android.system.OsConstants;
 import android.system.StructPollfd;
-import android.telephony.TelephonyManager;
 import android.test.MoreAsserts;
 import android.text.TextUtils;
 import android.util.ArraySet;
@@ -199,8 +198,8 @@
     private RemoteSocketFactoryClient mRemoteSocketFactoryClient;
     private CtsNetUtils mCtsNetUtils;
     private PackageManager mPackageManager;
-    private TelephonyManager mTelephonyManager;
-
+    private Context mTestContext;
+    private Context mTargetContext;
     Network mNetwork;
     NetworkCallback mCallback;
     final Object mLock = new Object();
@@ -230,21 +229,19 @@
     public void setUp() throws Exception {
         mNetwork = null;
         mCallback = null;
+        mTestContext = getInstrumentation().getContext();
+        mTargetContext = getInstrumentation().getTargetContext();
         storePrivateDnsSetting();
-
         mDevice = UiDevice.getInstance(getInstrumentation());
-        mActivity = launchActivity(getInstrumentation().getTargetContext().getPackageName(),
-                MyActivity.class);
+        mActivity = launchActivity(mTargetContext.getPackageName(), MyActivity.class);
         mPackageName = mActivity.getPackageName();
         mCM = (ConnectivityManager) mActivity.getSystemService(Context.CONNECTIVITY_SERVICE);
         mWifiManager = (WifiManager) mActivity.getSystemService(Context.WIFI_SERVICE);
         mRemoteSocketFactoryClient = new RemoteSocketFactoryClient(mActivity);
         mRemoteSocketFactoryClient.bind();
         mDevice.waitForIdle();
-        mCtsNetUtils = new CtsNetUtils(getInstrumentation().getContext());
-        mPackageManager = getInstrumentation().getContext().getPackageManager();
-        mTelephonyManager =
-                getInstrumentation().getContext().getSystemService(TelephonyManager.class);
+        mCtsNetUtils = new CtsNetUtils(mTestContext);
+        mPackageManager = mTestContext.getPackageManager();
     }
 
     @After
@@ -743,7 +740,7 @@
     }
 
     private ContentResolver getContentResolver() {
-        return getInstrumentation().getContext().getContentResolver();
+        return mTestContext.getContentResolver();
     }
 
     private boolean isPrivateDnsInStrictMode() {
@@ -792,7 +789,7 @@
     }
 
     private void setAndVerifyPrivateDns(boolean strictMode) throws Exception {
-        final ContentResolver cr = getInstrumentation().getContext().getContentResolver();
+        final ContentResolver cr = mTestContext.getContentResolver();
         String privateDnsHostname;
 
         if (strictMode) {
@@ -930,7 +927,7 @@
         }
 
         final BlockingBroadcastReceiver receiver = new BlockingBroadcastReceiver(
-                getInstrumentation().getTargetContext(), MyVpnService.ACTION_ESTABLISHED);
+                mTargetContext, MyVpnService.ACTION_ESTABLISHED);
         receiver.register();
 
         // Test the behaviour of a variety of types of network callbacks.
@@ -1526,7 +1523,7 @@
         private boolean received;
 
         public ProxyChangeBroadcastReceiver() {
-            super(getInstrumentation().getContext(), Proxy.PROXY_CHANGE_ACTION);
+            super(mTestContext, Proxy.PROXY_CHANGE_ACTION);
             received = false;
         }
 
@@ -1556,12 +1553,11 @@
                 "" /* allowedApps */, "com.android.providers.downloads", null /* proxyInfo */,
                 null /* underlyingNetworks */, false /* isAlwaysMetered */);
 
-        final Context context = getInstrumentation().getContext();
-        final DownloadManager dm = context.getSystemService(DownloadManager.class);
+        final DownloadManager dm = mTestContext.getSystemService(DownloadManager.class);
         final DownloadCompleteReceiver receiver = new DownloadCompleteReceiver();
         try {
             final int flags = SdkLevel.isAtLeastT() ? RECEIVER_EXPORTED : 0;
-            context.registerReceiver(receiver,
+            mTestContext.registerReceiver(receiver,
                     new IntentFilter(DownloadManager.ACTION_DOWNLOAD_COMPLETE), flags);
 
             // Enqueue a request and check only one download.
@@ -1579,7 +1575,7 @@
             assertEquals(1, dm.remove(id));
             assertEquals(0, getTotalNumberDownloads(dm, new Query()));
         } finally {
-            context.unregisterReceiver(receiver);
+            mTestContext.unregisterReceiver(receiver);
         }
     }
 
@@ -1616,8 +1612,7 @@
 
         // Create a TUN interface
         final FileDescriptor tunFd = runWithShellPermissionIdentity(() -> {
-            final TestNetworkManager tnm = getInstrumentation().getContext().getSystemService(
-                    TestNetworkManager.class);
+            final TestNetworkManager tnm = mTestContext.getSystemService(TestNetworkManager.class);
             final TestNetworkInterface iface = tnm.createTunInterface(List.of(
                     TEST_IP4_DST_ADDR, TEST_IP6_DST_ADDR));
             return iface.getFileDescriptor().getFileDescriptor();
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index 68b20e2..d2a3f91 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -3409,12 +3409,16 @@
         runWithShellPermissionIdentity(() -> {
             // Firewall chain status will be restored after the test.
             final boolean wasChainEnabled = mCm.getFirewallChainEnabled(chain);
+            final int previousUidFirewallRule = mCm.getUidFirewallRule(chain, myUid);
             final DatagramSocket srcSock = new DatagramSocket();
             final DatagramSocket dstSock = new DatagramSocket();
             testAndCleanup(() -> {
                 if (wasChainEnabled) {
                     mCm.setFirewallChainEnabled(chain, false /* enable */);
                 }
+                if (previousUidFirewallRule == ruleToAddMatch) {
+                    mCm.setUidFirewallRule(chain, myUid, ruleToRemoveMatch);
+                }
                 dstSock.setSoTimeout(SOCKET_TIMEOUT_MS);
 
                 // Chain disabled, UID not on chain.
@@ -3444,8 +3448,9 @@
                     // Restore the global chain status
                     mCm.setFirewallChainEnabled(chain, wasChainEnabled);
                 }, /* cleanup */ () -> {
+                    // Restore the uid firewall rule status
                     try {
-                        mCm.setUidFirewallRule(chain, myUid, ruleToRemoveMatch);
+                        mCm.setUidFirewallRule(chain, myUid, previousUidFirewallRule);
                     } catch (IllegalStateException ignored) {
                         // Removing match causes an exception when the rule entry for the uid does
                         // not exist. But this is fine and can be ignored.
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index ea50999..e487295 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -2114,7 +2114,7 @@
         }
 
         @Override
-        public boolean isFeatureEnabled(@NonNull final String name) {
+        public boolean isFeatureEnabled(@NonNull final String name, final boolean defaultEnabled) {
             // Tests for enabling the feature are verified in AutomaticOnOffKeepaliveTrackerTest.
             // Assuming enabled here to focus on ConnectivityService tests.
             return true;
@@ -8895,6 +8895,14 @@
         final TestNetworkCallback callback = new TestNetworkCallback();
         mCm.registerNetworkCallback(request, callback);
 
+        // File a VPN request to prevent the VPN network from being lingered.
+        final NetworkRequest vpnRequest = new NetworkRequest.Builder()
+                .addTransportType(TRANSPORT_VPN)
+                .removeCapability(NET_CAPABILITY_NOT_VPN)
+                .build();
+        final TestNetworkCallback vpnCallback = new TestNetworkCallback();
+        mCm.requestNetwork(vpnRequest, vpnCallback);
+
         // Bring up a VPN
         mMockVpn.establishForMyUid();
         assertUidRangesUpdatedForMyUid(true);
@@ -8953,6 +8961,9 @@
                 && c.getUids().contains(singleUidRange)
                 && c.hasTransport(TRANSPORT_VPN)
                 && !c.hasTransport(TRANSPORT_WIFI));
+
+        mCm.unregisterNetworkCallback(callback);
+        mCm.unregisterNetworkCallback(vpnCallback);
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index 6c29d6e..ddf1d4d 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -20,6 +20,7 @@
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 
@@ -171,7 +172,7 @@
         mHandlerThread.start();
         doReturn(mKeepaliveTracker).when(mDependencies).newKeepaliveTracker(
                 mCtx, mHandlerThread.getThreadHandler());
-        doReturn(true).when(mDependencies).isFeatureEnabled(any());
+        doReturn(true).when(mDependencies).isFeatureEnabled(any(), anyBoolean());
         mAOOKeepaliveTracker = new AutomaticOnOffKeepaliveTracker(
                 mCtx, mHandlerThread.getThreadHandler(), mDependencies);
     }