Merge "[mdns] clean up unused serviceInfo ref in discovery requests" into main
diff --git a/DnsResolver/include/DnsHelperPublic.h b/DnsResolver/include/DnsHelperPublic.h
index 7c9fc9e..44b0012 100644
--- a/DnsResolver/include/DnsHelperPublic.h
+++ b/DnsResolver/include/DnsHelperPublic.h
@@ -25,7 +25,8 @@
  * Perform any required initialization - including opening any required BPF maps. This function
  * needs to be called before using other functions of this library.
  *
- * Returns 0 on success, a negative POSIX error code (see errno.h) on other failures.
+ * Returns 0 on success, or -EOPNOTSUPP when called on an Android version earlier than
+ * T. Returns a negative POSIX error code (see errno.h) on other failures.
  */
 int ADnsHelper_init();
 
@@ -36,7 +37,9 @@
  * |uid| is a Linux/Android UID to be queried. It is a combination of UserID and AppID.
  * |metered| indicates whether the uid is currently using a billing network.
  *
- * Returns 0(false)/1(true) on success, a negative POSIX error code (see errno.h) on other failures.
+ * Returns 0(false)/1(true) on success, or -EUNATCH if ADnsHelper_init has not been called
+ * before calling this function. Returns a negative POSIX error code (see errno.h) for other
+ * failures reported by the bpf syscall.
  */
 int ADnsHelper_isUidNetworkingBlocked(uid_t uid, bool metered);
 
diff --git a/Tethering/src/com/android/networkstack/tethering/Tethering.java b/Tethering/src/com/android/networkstack/tethering/Tethering.java
index 5022b40..552b105 100644
--- a/Tethering/src/com/android/networkstack/tethering/Tethering.java
+++ b/Tethering/src/com/android/networkstack/tethering/Tethering.java
@@ -136,6 +136,7 @@
 import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.BaseNetdUnsolicitedEventListener;
 import com.android.net.module.util.CollectionUtils;
+import com.android.net.module.util.HandlerUtils;
 import com.android.net.module.util.NetdUtils;
 import com.android.net.module.util.SdkUtil.LateSdk;
 import com.android.net.module.util.SharedLog;
@@ -161,11 +162,8 @@
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
-import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executor;
 import java.util.concurrent.RejectedExecutionException;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
 
 /**
  *
@@ -2694,31 +2692,10 @@
             return;
         }
 
-        final CountDownLatch latch = new CountDownLatch(1);
-
-        // Don't crash the system if something in doDump throws an exception, but try to propagate
-        // the exception to the caller.
-        AtomicReference<RuntimeException> exceptionRef = new AtomicReference<>();
-        mHandler.post(() -> {
-            try {
-                doDump(fd, writer, args);
-            } catch (RuntimeException e) {
-                exceptionRef.set(e);
-            }
-            latch.countDown();
-        });
-
-        try {
-            if (!latch.await(DUMP_TIMEOUT_MS, TimeUnit.MILLISECONDS)) {
-                writer.println("Dump timeout after " + DUMP_TIMEOUT_MS + "ms");
-                return;
-            }
-        } catch (InterruptedException e) {
-            exceptionRef.compareAndSet(null, new IllegalStateException("Dump interrupted", e));
+        if (!HandlerUtils.runWithScissorsForDump(mHandler, () -> doDump(fd, writer, args),
+                DUMP_TIMEOUT_MS)) {
+            writer.println("Dump timeout after " + DUMP_TIMEOUT_MS + "ms");
         }
-
-        final RuntimeException e = exceptionRef.get();
-        if (e != null) throw e;
     }
 
     private void maybeDhcpLeasesChanged() {
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
index 82b8845..750bfce 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
@@ -2810,12 +2810,10 @@
         final FileDescriptor mockFd = mock(FileDescriptor.class);
         final PrintWriter mockPw = mock(PrintWriter.class);
         runUsbTethering(null);
-        mLooper.startAutoDispatch();
         mTethering.dump(mockFd, mockPw, new String[0]);
         verify(mConfig).dump(any());
         verify(mEntitleMgr).dump(any());
         verify(mOffloadCtrl).dump(any());
-        mLooper.stopAutoDispatch();
     }
 
     @Test
diff --git a/service-t/jni/com_android_server_net_NetworkStatsService.cpp b/service-t/jni/com_android_server_net_NetworkStatsService.cpp
index bdbb655..81912ae 100644
--- a/service-t/jni/com_android_server_net_NetworkStatsService.cpp
+++ b/service-t/jni/com_android_server_net_NetworkStatsService.cpp
@@ -34,77 +34,64 @@
 
 using android::bpf::bpfGetUidStats;
 using android::bpf::bpfGetIfaceStats;
-using android::bpf::bpfGetIfIndexStats;
 using android::bpf::NetworkTraceHandler;
 
 namespace android {
 
-// NOTE: keep these in sync with TrafficStats.java
-static const uint64_t UNKNOWN = -1;
-
-enum StatsType {
-    RX_BYTES = 0,
-    RX_PACKETS = 1,
-    TX_BYTES = 2,
-    TX_PACKETS = 3,
-};
-
-static uint64_t getStatsType(StatsValue* stats, StatsType type) {
-    switch (type) {
-        case RX_BYTES:
-            return stats->rxBytes;
-        case RX_PACKETS:
-            return stats->rxPackets;
-        case TX_BYTES:
-            return stats->txBytes;
-        case TX_PACKETS:
-            return stats->txPackets;
-        default:
-            return UNKNOWN;
+static jobject statsValueToEntry(JNIEnv* env, StatsValue* stats) {
+    // Find the Java class that represents the structure
+    jclass gEntryClass = env->FindClass("android/net/NetworkStats$Entry");
+    if (gEntryClass == nullptr) {
+        return nullptr;
     }
+
+    // Create a new instance of the Java class
+    jobject result = env->AllocObject(gEntryClass);
+    if (result == nullptr) {
+        return nullptr;
+    }
+
+    // Set the values of the structure fields in the Java object
+    env->SetLongField(result, env->GetFieldID(gEntryClass, "rxBytes", "J"), stats->rxBytes);
+    env->SetLongField(result, env->GetFieldID(gEntryClass, "txBytes", "J"), stats->txBytes);
+    env->SetLongField(result, env->GetFieldID(gEntryClass, "rxPackets", "J"), stats->rxPackets);
+    env->SetLongField(result, env->GetFieldID(gEntryClass, "txPackets", "J"), stats->txPackets);
+
+    return result;
 }
 
-static jlong nativeGetTotalStat(JNIEnv* env, jclass clazz, jint type) {
+static jobject nativeGetTotalStat(JNIEnv* env, jclass clazz) {
     StatsValue stats = {};
 
     if (bpfGetIfaceStats(NULL, &stats) == 0) {
-        return getStatsType(&stats, (StatsType) type);
+        return statsValueToEntry(env, &stats);
     } else {
-        return UNKNOWN;
+        return nullptr;
     }
 }
 
-static jlong nativeGetIfaceStat(JNIEnv* env, jclass clazz, jstring iface, jint type) {
+static jobject nativeGetIfaceStat(JNIEnv* env, jclass clazz, jstring iface) {
     ScopedUtfChars iface8(env, iface);
     if (iface8.c_str() == NULL) {
-        return UNKNOWN;
+        return nullptr;
     }
 
     StatsValue stats = {};
 
     if (bpfGetIfaceStats(iface8.c_str(), &stats) == 0) {
-        return getStatsType(&stats, (StatsType) type);
+        return statsValueToEntry(env, &stats);
     } else {
-        return UNKNOWN;
+        return nullptr;
     }
 }
 
-static jlong nativeGetIfIndexStat(JNIEnv* env, jclass clazz, jint ifindex, jint type) {
-    StatsValue stats = {};
-    if (bpfGetIfIndexStats(ifindex, &stats) == 0) {
-        return getStatsType(&stats, (StatsType) type);
-    } else {
-        return UNKNOWN;
-    }
-}
-
-static jlong nativeGetUidStat(JNIEnv* env, jclass clazz, jint uid, jint type) {
+static jobject nativeGetUidStat(JNIEnv* env, jclass clazz, jint uid) {
     StatsValue stats = {};
 
     if (bpfGetUidStats(uid, &stats) == 0) {
-        return getStatsType(&stats, (StatsType) type);
+        return statsValueToEntry(env, &stats);
     } else {
-        return UNKNOWN;
+        return nullptr;
     }
 }
 
@@ -113,11 +100,26 @@
 }
 
 static const JNINativeMethod gMethods[] = {
-        {"nativeGetTotalStat", "(I)J", (void*)nativeGetTotalStat},
-        {"nativeGetIfaceStat", "(Ljava/lang/String;I)J", (void*)nativeGetIfaceStat},
-        {"nativeGetIfIndexStat", "(II)J", (void*)nativeGetIfIndexStat},
-        {"nativeGetUidStat", "(II)J", (void*)nativeGetUidStat},
-        {"nativeInitNetworkTracing", "()V", (void*)nativeInitNetworkTracing},
+        {
+            "nativeGetTotalStat",
+            "()Landroid/net/NetworkStats$Entry;",
+            (void*)nativeGetTotalStat
+        },
+        {
+            "nativeGetIfaceStat",
+            "(Ljava/lang/String;)Landroid/net/NetworkStats$Entry;",
+            (void*)nativeGetIfaceStat
+        },
+        {
+            "nativeGetUidStat",
+            "(I)Landroid/net/NetworkStats$Entry;",
+            (void*)nativeGetUidStat
+        },
+        {
+            "nativeInitNetworkTracing",
+            "()V",
+            (void*)nativeInitNetworkTracing
+        },
 };
 
 int register_android_server_net_NetworkStatsService(JNIEnv* env) {
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index d858a85..6c25d76 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -538,13 +538,13 @@
         }
 
         private void maybeStartDaemon() {
-            if (mMDnsManager == null) {
-                Log.wtf(TAG, "maybeStartDaemon: mMDnsManager is null");
+            if (mIsDaemonStarted) {
+                if (DBG) Log.d(TAG, "Daemon is already started.");
                 return;
             }
 
-            if (mIsDaemonStarted) {
-                if (DBG) Log.d(TAG, "Daemon is already started.");
+            if (mMDnsManager == null) {
+                Log.wtf(TAG, "maybeStartDaemon: mMDnsManager is null");
                 return;
             }
             mMDnsManager.registerEventListener(mMDnsEventCallback);
@@ -555,13 +555,13 @@
         }
 
         private void maybeStopDaemon() {
-            if (mMDnsManager == null) {
-                Log.wtf(TAG, "maybeStopDaemon: mMDnsManager is null");
+            if (!mIsDaemonStarted) {
+                if (DBG) Log.d(TAG, "Daemon has not been started.");
                 return;
             }
 
-            if (!mIsDaemonStarted) {
-                if (DBG) Log.d(TAG, "Daemon has not been started.");
+            if (mMDnsManager == null) {
+                Log.wtf(TAG, "maybeStopDaemon: mMDnsManager is null");
                 return;
             }
             mMDnsManager.unregisterEventListener(mMDnsEventCallback);
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 3ac5e29..eb75461 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -1981,36 +1981,56 @@
         if (callingUid != android.os.Process.SYSTEM_UID && callingUid != uid) {
             return UNSUPPORTED;
         }
-        return nativeGetUidStat(uid, type);
+        return getEntryValueForType(nativeGetUidStat(uid), type);
     }
 
     @Override
     public long getIfaceStats(@NonNull String iface, int type) {
         Objects.requireNonNull(iface);
-        long nativeIfaceStats = nativeGetIfaceStat(iface, type);
-        if (nativeIfaceStats == -1) {
-            return nativeIfaceStats;
+        final NetworkStats.Entry entry = nativeGetIfaceStat(iface);
+        final long value = getEntryValueForType(entry, type);
+        if (value == UNSUPPORTED) {
+            return UNSUPPORTED;
         } else {
             // When tethering offload is in use, nativeIfaceStats does not contain usage from
             // offload, add it back here. Note that the included statistics might be stale
             // since polling newest stats from hardware might impact system health and not
             // suitable for TrafficStats API use cases.
-            return nativeIfaceStats + getProviderIfaceStats(iface, type);
+            entry.add(getProviderIfaceStats(iface));
+            return getEntryValueForType(entry, type);
+        }
+    }
+
+    private long getEntryValueForType(@Nullable NetworkStats.Entry entry, int type) {
+        if (entry == null) return UNSUPPORTED;
+        switch (type) {
+            case TrafficStats.TYPE_RX_BYTES:
+                return entry.rxBytes;
+            case TrafficStats.TYPE_TX_BYTES:
+                return entry.txBytes;
+            case TrafficStats.TYPE_RX_PACKETS:
+                return entry.rxPackets;
+            case TrafficStats.TYPE_TX_PACKETS:
+                return entry.txPackets;
+            default:
+                return UNSUPPORTED;
         }
     }
 
     @Override
     public long getTotalStats(int type) {
-        long nativeTotalStats = nativeGetTotalStat(type);
-        if (nativeTotalStats == -1) {
-            return nativeTotalStats;
+        final NetworkStats.Entry entry = nativeGetTotalStat();
+        final long value = getEntryValueForType(entry, type);
+        if (value == UNSUPPORTED) {
+            return UNSUPPORTED;
         } else {
             // Refer to comment in getIfaceStats
-            return nativeTotalStats + getProviderIfaceStats(IFACE_ALL, type);
+            entry.add(getProviderIfaceStats(IFACE_ALL));
+            return getEntryValueForType(entry, type);
         }
     }
 
-    private long getProviderIfaceStats(@Nullable String iface, int type) {
+    private NetworkStats.Entry getProviderIfaceStats(@Nullable String iface) {
         final NetworkStats providerSnapshot = getNetworkStatsFromProviders(STATS_PER_IFACE);
         final HashSet<String> limitIfaces;
         if (iface == IFACE_ALL) {
@@ -2019,19 +2039,7 @@
             limitIfaces = new HashSet<>();
             limitIfaces.add(iface);
         }
-        final NetworkStats.Entry entry = providerSnapshot.getTotal(null, limitIfaces);
-        switch (type) {
-            case TrafficStats.TYPE_RX_BYTES:
-                return entry.rxBytes;
-            case TrafficStats.TYPE_RX_PACKETS:
-                return entry.rxPackets;
-            case TrafficStats.TYPE_TX_BYTES:
-                return entry.txBytes;
-            case TrafficStats.TYPE_TX_PACKETS:
-                return entry.txPackets;
-            default:
-                return 0;
-        }
+        return providerSnapshot.getTotal(null, limitIfaces);
     }
 
     /**
@@ -3398,10 +3406,13 @@
         }
     }
 
-    private static native long nativeGetTotalStat(int type);
-    private static native long nativeGetIfaceStat(String iface, int type);
-    private static native long nativeGetIfIndexStat(int ifindex, int type);
-    private static native long nativeGetUidStat(int uid, int type);
+    // TODO: Read stats by using BpfNetMapsReader.
+    @Nullable
+    private static native NetworkStats.Entry nativeGetTotalStat();
+    @Nullable
+    private static native NetworkStats.Entry nativeGetIfaceStat(String iface);
+    @Nullable
+    private static native NetworkStats.Entry nativeGetUidStat(int uid);
 
     /** Initializes and registers the Perfetto Network Trace data source */
     public static native void nativeInitNetworkTracing();
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 6b47654..3b31ed2 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -291,6 +291,7 @@
 import com.android.net.module.util.BpfUtils;
 import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.DeviceConfigUtils;
+import com.android.net.module.util.HandlerUtils;
 import com.android.net.module.util.InterfaceParams;
 import com.android.net.module.util.LinkPropertiesUtils.CompareOrUpdateResult;
 import com.android.net.module.util.LinkPropertiesUtils.CompareResult;
@@ -315,7 +316,6 @@
 import com.android.server.connectivity.DnsManager.PrivateDnsValidationUpdate;
 import com.android.server.connectivity.DscpPolicyTracker;
 import com.android.server.connectivity.FullScore;
-import com.android.server.connectivity.HandlerUtils;
 import com.android.server.connectivity.InvalidTagException;
 import com.android.server.connectivity.KeepaliveResourceUtil;
 import com.android.server.connectivity.KeepaliveTracker;
@@ -1280,7 +1280,7 @@
         LocalPriorityDump() {}
 
         private void dumpHigh(FileDescriptor fd, PrintWriter pw) {
-            if (!HandlerUtils.runWithScissors(mHandler, () -> {
+            if (!HandlerUtils.runWithScissorsForDump(mHandler, () -> {
                 doDump(fd, pw, new String[]{DIAG_ARG});
                 doDump(fd, pw, new String[]{SHORT_ARG});
             }, DUMPSYS_DEFAULT_TIMEOUT_MS)) {
@@ -1289,7 +1289,7 @@
         }
 
         private void dumpNormal(FileDescriptor fd, PrintWriter pw, String[] args) {
-            if (!HandlerUtils.runWithScissors(mHandler, () -> doDump(fd, pw, args),
+            if (!HandlerUtils.runWithScissorsForDump(mHandler, () -> doDump(fd, pw, args),
                     DUMPSYS_DEFAULT_TIMEOUT_MS)) {
                 pw.println("dumpNormal timeout");
             }
diff --git a/service/src/com/android/server/connectivity/HandlerUtils.java b/service/src/com/android/server/connectivity/HandlerUtils.java
deleted file mode 100644
index 997ecbf..0000000
--- a/service/src/com/android/server/connectivity/HandlerUtils.java
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Copyright (C) 2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.server.connectivity;
-
-import android.annotation.NonNull;
-import android.os.Handler;
-import android.os.Looper;
-import android.os.SystemClock;
-
-/**
- * Helper class for Handler related utilities.
- *
- * @hide
- */
-public class HandlerUtils {
-    // Note: @hide methods copied from android.os.Handler
-    /**
-     * Runs the specified task synchronously.
-     * <p>
-     * If the current thread is the same as the handler thread, then the runnable
-     * runs immediately without being enqueued.  Otherwise, posts the runnable
-     * to the handler and waits for it to complete before returning.
-     * </p><p>
-     * This method is dangerous!  Improper use can result in deadlocks.
-     * Never call this method while any locks are held or use it in a
-     * possibly re-entrant manner.
-     * </p><p>
-     * This method is occasionally useful in situations where a background thread
-     * must synchronously await completion of a task that must run on the
-     * handler's thread.  However, this problem is often a symptom of bad design.
-     * Consider improving the design (if possible) before resorting to this method.
-     * </p><p>
-     * One example of where you might want to use this method is when you just
-     * set up a Handler thread and need to perform some initialization steps on
-     * it before continuing execution.
-     * </p><p>
-     * If timeout occurs then this method returns <code>false</code> but the runnable
-     * will remain posted on the handler and may already be in progress or
-     * complete at a later time.
-     * </p><p>
-     * When using this method, be sure to use {@link Looper#quitSafely} when
-     * quitting the looper.  Otherwise {@link #runWithScissors} may hang indefinitely.
-     * (TODO: We should fix this by making MessageQueue aware of blocking runnables.)
-     * </p>
-     *
-     * @param h The target handler.
-     * @param r The Runnable that will be executed synchronously.
-     * @param timeout The timeout in milliseconds, or 0 to wait indefinitely.
-     *
-     * @return Returns true if the Runnable was successfully executed.
-     *         Returns false on failure, usually because the
-     *         looper processing the message queue is exiting.
-     *
-     * @hide This method is prone to abuse and should probably not be in the API.
-     * If we ever do make it part of the API, we might want to rename it to something
-     * less funny like runUnsafe().
-     */
-    public static boolean runWithScissors(@NonNull Handler h, @NonNull Runnable r, long timeout) {
-        if (r == null) {
-            throw new IllegalArgumentException("runnable must not be null");
-        }
-        if (timeout < 0) {
-            throw new IllegalArgumentException("timeout must be non-negative");
-        }
-
-        if (Looper.myLooper() == h.getLooper()) {
-            r.run();
-            return true;
-        }
-
-        BlockingRunnable br = new BlockingRunnable(r);
-        return br.postAndWait(h, timeout);
-    }
-
-    private static final class BlockingRunnable implements Runnable {
-        private final Runnable mTask;
-        private boolean mDone;
-
-        BlockingRunnable(Runnable task) {
-            mTask = task;
-        }
-
-        @Override
-        public void run() {
-            try {
-                mTask.run();
-            } finally {
-                synchronized (this) {
-                    mDone = true;
-                    notifyAll();
-                }
-            }
-        }
-
-        public boolean postAndWait(Handler handler, long timeout) {
-            if (!handler.post(this)) {
-                return false;
-            }
-
-            synchronized (this) {
-                if (timeout > 0) {
-                    final long expirationTime = SystemClock.uptimeMillis() + timeout;
-                    while (!mDone) {
-                        long delay = expirationTime - SystemClock.uptimeMillis();
-                        if (delay <= 0) {
-                            return false; // timeout
-                        }
-                        try {
-                            wait(delay);
-                        } catch (InterruptedException ex) {
-                        }
-                    }
-                } else {
-                    while (!mDone) {
-                        try {
-                            wait();
-                        } catch (InterruptedException ex) {
-                        }
-                    }
-                }
-            }
-            return true;
-        }
-    }
-}
diff --git a/staticlibs/Android.bp b/staticlibs/Android.bp
index 6f7ea4c..8f018c0 100644
--- a/staticlibs/Android.bp
+++ b/staticlibs/Android.bp
@@ -43,6 +43,7 @@
       "device/com/android/net/module/util/SharedLog.java",
       "device/com/android/net/module/util/SocketUtils.java",
       "device/com/android/net/module/util/FeatureVersions.java",
+      "device/com/android/net/module/util/HandlerUtils.java",
       // This library is used by system modules, for which the system health impact of Kotlin
       // has not yet been evaluated. Annotations may need jarjar'ing.
       // "src_devicecommon/**/*.kt",
diff --git a/staticlibs/device/com/android/net/module/util/HandlerUtils.java b/staticlibs/device/com/android/net/module/util/HandlerUtils.java
new file mode 100644
index 0000000..c620368
--- /dev/null
+++ b/staticlibs/device/com/android/net/module/util/HandlerUtils.java
@@ -0,0 +1,120 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util;
+
+import android.annotation.NonNull;
+import android.os.Handler;
+import android.os.Looper;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * Helper class for Handler related utilities.
+ *
+ * @hide
+ */
+public class HandlerUtils {
+    private HandlerUtils() {}
+
+    /**
+     * Runs the specified task synchronously for dump method.
+     * <p>
+     * If the current thread is the same as the handler thread, then the runnable
+     * runs immediately without being enqueued.  Otherwise, posts the runnable
+     * to the handler and waits for it to complete before returning.
+     * </p><p>
+     * This method is dangerous!  Improper use can result in deadlocks.
+     * Never call this method while any locks are held or use it in a
+     * possibly re-entrant manner.
+     * </p><p>
+     * This method is made to let dump method access members on the handler thread to
+     * avoid concurrent access problems or races.
+     * </p><p>
+     * If timeout occurs then this method returns <code>false</code> but the runnable
+     * will remain posted on the handler and may already be in progress or
+     * complete at a later time.
+     * </p><p>
+     * If the looper is quit with {@link Looper#quit} while the runnable is still queued,
+     * the runnable is dropped and this method blocks until the timeout expires, then
+     * returns <code>false</code>. Prefer {@link Looper#quitSafely} when quitting the looper.
+     * </p><p>
+     * A {@link RuntimeException} thrown by the runnable on the handler thread is caught
+     * there, so the handler thread does not crash, and is rethrown to the caller. If the
+     * calling thread is interrupted while waiting, its interrupt status is restored and an
+     * {@link IllegalStateException} is thrown, unless the runnable has already thrown, in
+     * which case that exception is rethrown instead.
+     * </p>
+     *
+     * @param h The target handler.
+     * @param r The Runnable that will be executed synchronously.
+     * @param timeout The timeout in milliseconds, or 0 to not wait at all.
+     *
+     * @return Returns true if the Runnable was successfully executed.
+     *         Returns false if the Runnable could not be posted (usually because the
+     *         looper processing the message queue is exiting) or did not complete
+     *         within the timeout.
+     * @throws IllegalArgumentException if {@code r} is null or {@code timeout} is negative.
+     *
+     * @hide
+     */
+    public static boolean runWithScissorsForDump(@NonNull Handler h, @NonNull Runnable r,
+                                                 long timeout) {
+        if (r == null) {
+            throw new IllegalArgumentException("runnable must not be null");
+        }
+        if (timeout < 0) {
+            throw new IllegalArgumentException("timeout must be non-negative");
+        }
+        if (Looper.myLooper() == h.getLooper()) {
+            r.run();
+            return true;
+        }
+
+        final CountDownLatch latch = new CountDownLatch(1);
+
+        // Don't crash in the handler if something in the runnable throws an exception,
+        // but try to propagate the exception to the caller.
+        final AtomicReference<RuntimeException> exceptionRef = new AtomicReference<>();
+        final boolean posted = h.post(() -> {
+            try {
+                r.run();
+            } catch (RuntimeException e) {
+                exceptionRef.set(e);
+            }
+            latch.countDown();
+        });
+        // Handler#post returns false when the looper is exiting: the runnable will never
+        // run, so fail fast instead of waiting out the whole timeout.
+        if (!posted) return false;
+
+        try {
+            if (!latch.await(timeout, TimeUnit.MILLISECONDS)) {
+                return false;
+            }
+        } catch (InterruptedException e) {
+            // Restore the interrupt status so that code further up the stack can observe it.
+            Thread.currentThread().interrupt();
+            exceptionRef.compareAndSet(null, new IllegalStateException("Thread interrupted", e));
+        }
+
+        final RuntimeException e = exceptionRef.get();
+        if (e != null) throw e;
+        return true;
+    }
+}
diff --git a/tests/unit/java/com/android/server/HandlerUtilsTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/HandlerUtilsTest.kt
similarity index 90%
rename from tests/unit/java/com/android/server/HandlerUtilsTest.kt
rename to staticlibs/tests/unit/src/com/android/net/module/util/HandlerUtilsTest.kt
index 62bb651..f2c902f 100644
--- a/tests/unit/java/com/android/server/HandlerUtilsTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/HandlerUtilsTest.kt
@@ -14,11 +14,11 @@
  * limitations under the License.
  */
 
-package com.android.server
+package com.android.net.module.util
 
 import android.os.HandlerThread
-import com.android.server.connectivity.HandlerUtils
 import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.DevSdkIgnoreRunner.MonitorThreadLeak
 import kotlin.test.assertEquals
 import kotlin.test.assertTrue
 import org.junit.After
@@ -27,6 +27,8 @@
 
 const val THREAD_BLOCK_TIMEOUT_MS = 1000L
 const val TEST_REPEAT_COUNT = 100
+
+@MonitorThreadLeak
 @RunWith(DevSdkIgnoreRunner::class)
 class HandlerUtilsTest {
     val handlerThread = HandlerThread("HandlerUtilsTestHandlerThread").also {
@@ -39,7 +41,7 @@
         // Repeat the test a fair amount of times to ensure that it does not pass by chance.
         repeat(TEST_REPEAT_COUNT) {
             var result = false
-            HandlerUtils.runWithScissors(handler, {
+            HandlerUtils.runWithScissorsForDump(handler, {
                 assertEquals(Thread.currentThread(), handlerThread)
                 result = true
             }, THREAD_BLOCK_TIMEOUT_MS)
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityDiagnosticsManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityDiagnosticsManagerTest.java
index e0fe929..ceb48d4 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityDiagnosticsManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityDiagnosticsManagerTest.java
@@ -298,17 +298,6 @@
                 },
                 android.Manifest.permission.MODIFY_PHONE_STATE);
 
-        // TODO(b/157779832): This should use android.permission.CHANGE_NETWORK_STATE. However, the
-        // shell does not have CHANGE_NETWORK_STATE, so use CONNECTIVITY_INTERNAL until the shell
-        // permissions are updated.
-        runWithShellPermissionIdentity(
-                () -> mConnectivityManager.requestNetwork(
-                        CELLULAR_NETWORK_REQUEST, testNetworkCallback),
-                android.Manifest.permission.CONNECTIVITY_INTERNAL);
-
-        final Network network = testNetworkCallback.waitForAvailable();
-        assertNotNull(network);
-
         assertTrue("Didn't receive broadcast for ACTION_CARRIER_CONFIG_CHANGED for subId=" + subId,
                 carrierConfigReceiver.waitForCarrierConfigChanged());
 
@@ -324,6 +313,17 @@
 
         Thread.sleep(5_000);
 
+        // TODO(b/157779832): This should use android.permission.CHANGE_NETWORK_STATE. However, the
+        // shell does not have CHANGE_NETWORK_STATE, so use CONNECTIVITY_INTERNAL until the shell
+        // permissions are updated.
+        runWithShellPermissionIdentity(
+                () -> mConnectivityManager.requestNetwork(
+                        CELLULAR_NETWORK_REQUEST, testNetworkCallback),
+                android.Manifest.permission.CONNECTIVITY_INTERNAL);
+
+        final Network network = testNetworkCallback.waitForAvailable();
+        assertNotNull(network);
+
         // TODO(b/217559768): Receiving carrier config change and immediately checking carrier
         //  privileges is racy, as the CP status is updated after receiving the same signal. Move
         //  the CP check after sleep to temporarily reduce the flakiness. This will soon be fixed
diff --git a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
index 496d163..76d30e6 100644
--- a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
+++ b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
@@ -56,9 +56,11 @@
 import com.android.server.connectivity.MockableSystemProperties
 import com.android.server.connectivity.MultinetworkPolicyTracker
 import com.android.server.connectivity.ProxyTracker
+import com.android.testutils.DevSdkIgnoreRunner
 import com.android.testutils.DeviceInfoUtils
 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
 import com.android.testutils.TestableNetworkCallback
+import com.android.testutils.tryTest
 import kotlin.test.assertEquals
 import kotlin.test.assertNotNull
 import kotlin.test.assertTrue
@@ -254,13 +256,18 @@
         na.addCapability(NET_CAPABILITY_INTERNET)
         na.connect()
 
-        testCallback.expectAvailableThenValidatedCallbacks(na.network, TEST_TIMEOUT_MS)
-        val requestedSize = nsInstrumentation.getRequestUrls().size
-        if (requestedSize == 2 || (requestedSize == 1 &&
-                nsInstrumentation.getRequestUrls()[0] == httpsProbeUrl)) {
-            return
+        tryTest {
+            testCallback.expectAvailableThenValidatedCallbacks(na.network, TEST_TIMEOUT_MS)
+            val requestedSize = nsInstrumentation.getRequestUrls().size
+            if (requestedSize == 2 || (requestedSize == 1 &&
+                        nsInstrumentation.getRequestUrls()[0] == httpsProbeUrl)
+            ) {
+                return@tryTest
+            }
+            fail("Unexpected request urls: ${nsInstrumentation.getRequestUrls()}")
+        } cleanup {
+            na.destroy()
         }
-        fail("Unexpected request urls: ${nsInstrumentation.getRequestUrls()}")
     }
 
     @Test
@@ -292,24 +299,32 @@
         val lp = LinkProperties()
         lp.captivePortalApiUrl = Uri.parse(apiUrl)
         val na = NetworkAgentWrapper(TRANSPORT_CELLULAR, lp, null /* ncTemplate */, context)
-        networkStackClient.verifyNetworkMonitorCreated(na.network, TEST_TIMEOUT_MS)
 
-        na.addCapability(NET_CAPABILITY_INTERNET)
-        na.connect()
+        tryTest {
+            networkStackClient.verifyNetworkMonitorCreated(na.network, TEST_TIMEOUT_MS)
 
-        testCb.expectAvailableCallbacks(na.network, validated = false, tmt = TEST_TIMEOUT_MS)
+            na.addCapability(NET_CAPABILITY_INTERNET)
+            na.connect()
 
-        val capportData = testCb.expect<LinkPropertiesChanged>(na, TEST_TIMEOUT_MS) {
-            it.lp.captivePortalData != null
-        }.lp.captivePortalData
-        assertNotNull(capportData)
-        assertTrue(capportData.isCaptive)
-        assertEquals(Uri.parse("https://login.capport.android.com"), capportData.userPortalUrl)
-        assertEquals(Uri.parse("https://venueinfo.capport.android.com"), capportData.venueInfoUrl)
+            testCb.expectAvailableCallbacks(na.network, validated = false, tmt = TEST_TIMEOUT_MS)
 
-        testCb.expectCaps(na, TEST_TIMEOUT_MS) {
-            it.hasCapability(NET_CAPABILITY_CAPTIVE_PORTAL) &&
-                    !it.hasCapability(NET_CAPABILITY_VALIDATED)
+            val capportData = testCb.expect<LinkPropertiesChanged>(na, TEST_TIMEOUT_MS) {
+                it.lp.captivePortalData != null
+            }.lp.captivePortalData
+            assertNotNull(capportData)
+            assertTrue(capportData.isCaptive)
+            assertEquals(Uri.parse("https://login.capport.android.com"), capportData.userPortalUrl)
+            assertEquals(
+                Uri.parse("https://venueinfo.capport.android.com"),
+                capportData.venueInfoUrl
+            )
+
+            testCb.expectCaps(na, TEST_TIMEOUT_MS) {
+                it.hasCapability(NET_CAPABILITY_CAPTIVE_PORTAL) &&
+                        !it.hasCapability(NET_CAPABILITY_VALIDATED)
+            }
+        } cleanup {
+            na.destroy()
         }
     }
 
diff --git a/tests/integration/util/com/android/server/NetworkAgentWrapper.java b/tests/integration/util/com/android/server/NetworkAgentWrapper.java
index ec09f9e..960c6ca 100644
--- a/tests/integration/util/com/android/server/NetworkAgentWrapper.java
+++ b/tests/integration/util/com/android/server/NetworkAgentWrapper.java
@@ -36,6 +36,7 @@
 import static org.junit.Assert.fail;
 
 import android.annotation.NonNull;
+import android.annotation.SuppressLint;
 import android.content.Context;
 import android.net.ConnectivityManager;
 import android.net.LinkProperties;
@@ -51,6 +52,7 @@
 import android.os.ConditionVariable;
 import android.os.HandlerThread;
 import android.os.Message;
+import android.util.CloseGuard;
 import android.util.Log;
 import android.util.Range;
 
@@ -65,11 +67,14 @@
 import java.util.function.Consumer;
 
 public class NetworkAgentWrapper implements TestableNetworkCallback.HasNetwork {
+    private static final long DESTROY_TIMEOUT_MS = 10_000L;
+
     // Note : Please do not add any new instrumentation here. If you need new instrumentation,
     // please add it in CSAgentWrapper and use subclasses of CSTest instead of adding more
     // tools in ConnectivityServiceTest.
     private final NetworkCapabilities mNetworkCapabilities;
     private final HandlerThread mHandlerThread;
+    private final CloseGuard mCloseGuard;
     private final Context mContext;
     private final String mLogTag;
     private final NetworkAgentConfig mNetworkAgentConfig;
@@ -157,6 +162,8 @@
         mLogTag = "Mock-" + typeName;
         mHandlerThread = new HandlerThread(mLogTag);
         mHandlerThread.start();
+        mCloseGuard = new CloseGuard();
+        mCloseGuard.open("destroy");
 
         // extraInfo is set to "" by default in NetworkAgentConfig.
         final String extraInfo = (transport == TRANSPORT_CELLULAR) ? "internet.apn" : "";
@@ -359,6 +366,43 @@
         mNetworkAgent.unregister();
     }
 
+    /**
+     * Stops this wrapper's handler thread, waiting up to DESTROY_TIMEOUT_MS for it to exit, and
+     * closes the CloseGuard opened by the constructor.
+     *
+     * <p>This must always be called; otherwise finalize() flags it via CloseGuard#warnIfOpen.
+     */
+    public void destroy() {
+        mHandlerThread.quitSafely();
+        try {
+            mHandlerThread.join(DESTROY_TIMEOUT_MS);
+        } catch (InterruptedException e) {
+            Log.e(mLogTag, "Interrupted when waiting for handler thread on destroy", e);
+            // Restore the interrupt status so callers can still observe the interruption.
+            Thread.currentThread().interrupt();
+        }
+        mCloseGuard.close();
+    }
+
+    @SuppressLint("Finalize") // Follows the recommended pattern for CloseGuard
+    @Override
+    protected void finalize() throws Throwable {
+        try {
+            // Note that mCloseGuard could be null if the constructor threw.
+            if (mCloseGuard != null) {
+                mCloseGuard.warnIfOpen();
+            }
+            // Do not call destroy() here: it would NPE if the constructor threw before the
+            // fields were assigned, and joining the thread could block the shared finalizer
+            // thread for up to DESTROY_TIMEOUT_MS. Just ask the looper to quit.
+            if (mHandlerThread != null) {
+                mHandlerThread.quitSafely();
+            }
+        } finally {
+            super.finalize();
+        }
+    }
+
     @Override
     public Network getNetwork() {
         return mNetworkAgent.getNetwork();
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 8f5fd7c..c681356 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -589,6 +589,7 @@
     private TestNetworkAgentWrapper mWiFiAgent;
     private TestNetworkAgentWrapper mCellAgent;
     private TestNetworkAgentWrapper mEthernetAgent;
+    private final List<TestNetworkAgentWrapper> mCreatedAgents = new ArrayList<>();
     private MockVpn mMockVpn;
     private Context mContext;
     private NetworkPolicyCallback mPolicyCallback;
@@ -1092,6 +1093,7 @@
                 NetworkCapabilities ncTemplate, NetworkProvider provider,
                 NetworkAgentWrapper.Callbacks callbacks) throws Exception {
             super(transport, linkProperties, ncTemplate, provider, callbacks, mServiceContext);
+            mCreatedAgents.add(this);
 
             // Waits for the NetworkAgent to be registered, which includes the creation of the
             // NetworkMonitor.
@@ -2404,6 +2406,11 @@
         FakeSettingsProvider.clearSettingsProvider();
         ConnectivityResources.setResourcesContextForTest(null);
 
+        for (TestNetworkAgentWrapper agent : mCreatedAgents) {
+            agent.destroy();
+        }
+        mCreatedAgents.clear();
+
         mCsHandlerThread.quitSafely();
         mCsHandlerThread.join();
         mAlarmManagerThread.quitSafely();
diff --git a/thread/flags/Android.bp b/thread/flags/Android.bp
new file mode 100644
index 0000000..225022c
--- /dev/null
+++ b/thread/flags/Android.bp
@@ -0,0 +1,35 @@
+//
+// Copyright (C) 2024 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package {
+    default_applicable_licenses: ["Android-Apache-2.0"],
+}
+
+aconfig_declarations {
+    name: "thread_aconfig_flags",
+    package: "com.android.net.thread.flags",
+    srcs: ["thread_base.aconfig"],
+}
+
+java_aconfig_library {
+    name: "thread_aconfig_flags_lib",
+    aconfig_declarations: "thread_aconfig_flags",
+    min_sdk_version: "30",
+    apex_available: [
+        "//apex_available:platform",
+        "com.android.tethering",
+    ],
+}
diff --git a/thread/flags/thread_base.aconfig b/thread/flags/thread_base.aconfig
index bf1f288..f73ea6b 100644
--- a/thread/flags/thread_base.aconfig
+++ b/thread/flags/thread_base.aconfig
@@ -6,3 +6,10 @@
     description: "Controls whether the Android Thread feature is enabled"
     bug: "301473012"
 }
+
+flag {
+    name: "thread_user_restriction_enabled"
+    namespace: "thread_network"
+    description: "Controls whether user restriction on thread networks is enabled"
+    bug: "307679182"
+}
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 6cd0ac3..cd59e4e 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -333,6 +333,7 @@
                     mLinkProperties.setMtu(TunInterfaceController.MTU);
                     mConnectivityManager.registerNetworkProvider(mNetworkProvider);
                     requestUpstreamNetwork();
+                    requestThreadNetwork();
 
                     initializeOtDaemon();
                 });
@@ -413,9 +414,10 @@
     private void requestThreadNetwork() {
         mConnectivityManager.registerNetworkCallback(
                 new NetworkRequest.Builder()
+                        // clearCapabilities() is needed to remove forbidden capabilities and UID
+                        // requirement.
                         .clearCapabilities()
                         .addTransportType(NetworkCapabilities.TRANSPORT_THREAD)
-                        .removeForbiddenCapability(NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK)
                         .build(),
                 new ThreadNetworkCallback(),
                 mHandler);
@@ -459,8 +461,6 @@
             return;
         }
 
-        requestThreadNetwork();
-
         mNetworkAgent = newNetworkAgent();
         mNetworkAgent.register();
         mNetworkAgent.markConnected();
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index a3cf278..53f2d4f 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -48,17 +48,20 @@
     }
 
     /**
-     * Called by the service initializer.
+     * Called by {@link com.android.server.ConnectivityServiceInitializer}.
      *
      * @see com.android.server.SystemService#onBootPhase
      */
     public void onBootPhase(int phase) {
-        if (phase == SystemService.PHASE_BOOT_COMPLETED) {
+        if (phase == SystemService.PHASE_SYSTEM_SERVICES_READY) {
             mControllerService = ThreadNetworkControllerService.newInstance(mContext);
             mControllerService.initialize();
+        } else if (phase == SystemService.PHASE_BOOT_COMPLETED) {
+            // Country code initialization is delayed to the BOOT_COMPLETED phase because it
+            // calls into the Wi-Fi and Telephony services, whose country code modules are only
+            // ready after PHASE_ACTIVITY_MANAGER_READY and PHASE_THIRD_PARTY_APPS_CAN_START.
             mCountryCode = ThreadNetworkCountryCode.newInstance(mContext, mControllerService);
             mCountryCode.initialize();
-
             mShellCommand = new ThreadNetworkShellCommand(mCountryCode);
         }
     }
diff --git a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
index 362ff39..e02e74d 100644
--- a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
+++ b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
@@ -521,7 +521,7 @@
     }
 
     @Test
-    public void scheduleMigration_withPrivilegedPermission_success() throws Exception {
+    public void scheduleMigration_withPrivilegedPermission_newDatasetApplied() throws Exception {
         grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
         for (ThreadNetworkController controller : getAllControllers()) {
@@ -548,11 +548,32 @@
 
             controller.scheduleMigration(
                     pendingDataset2, mExecutor, newOutcomeReceiver(migrateFuture));
-
             migrateFuture.get();
-            Thread.sleep(35 * 1000);
-            assertThat(getActiveOperationalDataset(controller)).isEqualTo(activeDataset2);
-            assertThat(getPendingOperationalDataset(controller)).isNull();
+
+            SettableFuture<Boolean> dataset2IsApplied = SettableFuture.create();
+            SettableFuture<Boolean> pendingDatasetIsRemoved = SettableFuture.create();
+            OperationalDatasetCallback datasetCallback =
+                    new OperationalDatasetCallback() {
+                        @Override
+                        public void onActiveOperationalDatasetChanged(
+                                ActiveOperationalDataset activeDataset) {
+                            if (activeDataset.equals(activeDataset2)) {
+                                dataset2IsApplied.set(true);
+                            }
+                        }
+
+                        @Override
+                        public void onPendingOperationalDatasetChanged(
+                                PendingOperationalDataset pendingDataset) {
+                            if (pendingDataset == null) {
+                                pendingDatasetIsRemoved.set(true);
+                            }
+                        }
+                    };
+            controller.registerOperationalDatasetCallback(directExecutor(), datasetCallback);
+            assertThat(dataset2IsApplied.get()).isTrue();
+            assertThat(pendingDatasetIsRemoved.get()).isTrue();
+            controller.unregisterOperationalDatasetCallback(datasetCallback);
         }
     }
 
@@ -629,7 +650,8 @@
     }
 
     @Test
-    public void scheduleMigration_secondRequestHasLargerTimestamp_success() throws Exception {
+    public void scheduleMigration_secondRequestHasLargerTimestamp_newDatasetApplied()
+            throws Exception {
         grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
         for (ThreadNetworkController controller : getAllControllers()) {
@@ -669,11 +691,32 @@
             migrateFuture1.get();
             controller.scheduleMigration(
                     pendingDataset2, mExecutor, newOutcomeReceiver(migrateFuture2));
-
             migrateFuture2.get();
-            Thread.sleep(35 * 1000);
-            assertThat(getActiveOperationalDataset(controller)).isEqualTo(activeDataset2);
-            assertThat(getPendingOperationalDataset(controller)).isNull();
+
+            SettableFuture<Boolean> dataset2IsApplied = SettableFuture.create();
+            SettableFuture<Boolean> pendingDatasetIsRemoved = SettableFuture.create();
+            OperationalDatasetCallback datasetCallback =
+                    new OperationalDatasetCallback() {
+                        @Override
+                        public void onActiveOperationalDatasetChanged(
+                                ActiveOperationalDataset activeDataset) {
+                            if (activeDataset.equals(activeDataset2)) {
+                                dataset2IsApplied.set(true);
+                            }
+                        }
+
+                        @Override
+                        public void onPendingOperationalDatasetChanged(
+                                PendingOperationalDataset pendingDataset) {
+                            if (pendingDataset == null) {
+                                pendingDatasetIsRemoved.set(true);
+                            }
+                        }
+                    };
+            controller.registerOperationalDatasetCallback(directExecutor(), datasetCallback);
+            assertThat(dataset2IsApplied.get()).isTrue();
+            assertThat(pendingDatasetIsRemoved.get()).isTrue();
+            controller.unregisterOperationalDatasetCallback(datasetCallback);
         }
     }