Merge "Use an integer for query mode" into main
diff --git a/framework-t/api/system-current.txt b/framework-t/api/system-current.txt
index d346af3..8251f85 100644
--- a/framework-t/api/system-current.txt
+++ b/framework-t/api/system-current.txt
@@ -500,6 +500,7 @@
     method @RequiresPermission(allOf={android.Manifest.permission.ACCESS_NETWORK_STATE, "android.permission.THREAD_NETWORK_PRIVILEGED"}) public void registerOperationalDatasetCallback(@NonNull java.util.concurrent.Executor, @NonNull android.net.thread.ThreadNetworkController.OperationalDatasetCallback);
     method @RequiresPermission(android.Manifest.permission.ACCESS_NETWORK_STATE) public void registerStateCallback(@NonNull java.util.concurrent.Executor, @NonNull android.net.thread.ThreadNetworkController.StateCallback);
     method @RequiresPermission("android.permission.THREAD_NETWORK_PRIVILEGED") public void scheduleMigration(@NonNull android.net.thread.PendingOperationalDataset, @NonNull java.util.concurrent.Executor, @NonNull android.os.OutcomeReceiver<java.lang.Void,android.net.thread.ThreadNetworkException>);
+    method @RequiresPermission("android.permission.THREAD_NETWORK_PRIVILEGED") public void setEnabled(boolean, @NonNull java.util.concurrent.Executor, @NonNull android.os.OutcomeReceiver<java.lang.Void,android.net.thread.ThreadNetworkException>);
     method @RequiresPermission(allOf={android.Manifest.permission.ACCESS_NETWORK_STATE, "android.permission.THREAD_NETWORK_PRIVILEGED"}) public void unregisterOperationalDatasetCallback(@NonNull android.net.thread.ThreadNetworkController.OperationalDatasetCallback);
     method @RequiresPermission(android.Manifest.permission.ACCESS_NETWORK_STATE) public void unregisterStateCallback(@NonNull android.net.thread.ThreadNetworkController.StateCallback);
     field public static final int DEVICE_ROLE_CHILD = 2; // 0x2
@@ -507,6 +508,9 @@
     field public static final int DEVICE_ROLE_LEADER = 4; // 0x4
     field public static final int DEVICE_ROLE_ROUTER = 3; // 0x3
     field public static final int DEVICE_ROLE_STOPPED = 0; // 0x0
+    field public static final int STATE_DISABLED = 0; // 0x0
+    field public static final int STATE_DISABLING = 2; // 0x2
+    field public static final int STATE_ENABLED = 1; // 0x1
     field public static final int THREAD_VERSION_1_3 = 4; // 0x4
   }
 
@@ -518,6 +522,7 @@
   public static interface ThreadNetworkController.StateCallback {
     method public void onDeviceRoleChanged(int);
     method public default void onPartitionIdChanged(long);
+    method public default void onThreadEnableStateChanged(int);
   }
 
   @FlaggedApi("com.android.net.thread.flags.thread_enabled") public class ThreadNetworkException extends java.lang.Exception {
@@ -530,6 +535,7 @@
     field public static final int ERROR_REJECTED_BY_PEER = 8; // 0x8
     field public static final int ERROR_RESOURCE_EXHAUSTED = 10; // 0xa
     field public static final int ERROR_RESPONSE_BAD_FORMAT = 9; // 0x9
+    field public static final int ERROR_THREAD_DISABLED = 12; // 0xc
     field public static final int ERROR_TIMEOUT = 3; // 0x3
     field public static final int ERROR_UNAVAILABLE = 4; // 0x4
     field public static final int ERROR_UNKNOWN = 11; // 0xb
diff --git a/service-t/jni/com_android_server_net_NetworkStatsService.cpp b/service-t/jni/com_android_server_net_NetworkStatsService.cpp
index 48ac993..c999398 100644
--- a/service-t/jni/com_android_server_net_NetworkStatsService.cpp
+++ b/service-t/jni/com_android_server_net_NetworkStatsService.cpp
@@ -52,8 +52,14 @@
         return nullptr;
     }
 
+    // Find the constructor.
+    jmethodID constructorID = env->GetMethodID(gEntryClass, "<init>", "()V");
+    if (constructorID == nullptr) {
+        return nullptr;
+    }
+
     // Create a new instance of the Java class
-    jobject result = env->AllocObject(gEntryClass);
+    jobject result = env->NewObject(gEntryClass, constructorID);
     if (result == nullptr) {
         return nullptr;
     }
diff --git a/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp b/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp
index def5f88..d3e331e 100644
--- a/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp
+++ b/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp
@@ -50,6 +50,15 @@
     return ifaceStatsMap;
 }
 
+Result<IfaceValue> ifindex2name(const uint32_t ifindex) {
+    // Consult the index->name BPF map first.
+    Result<IfaceValue> cached = getIfaceIndexNameMap().readValue(ifindex);
+    if (cached.ok()) {
+        return cached;
+    }
+
+    // Not in the map: fall back to asking the kernel. On failure, report the map's error.
+    IfaceValue fromKernel = {};
+    if (if_indextoname(ifindex, fromKernel.name) == nullptr) {
+        return cached;
+    }
+
+    // Best effort: remember the kernel's answer in the map (write result is not checked).
+    getIfaceIndexNameMap().writeValue(ifindex, fromKernel, BPF_ANY);
+    return fromKernel;
+}
+
 void bpfRegisterIface(const char* iface) {
     if (!iface) return;
     if (strlen(iface) >= sizeof(IfaceValue)) return;
@@ -78,14 +87,14 @@
 
 int bpfGetIfaceStatsInternal(const char* iface, StatsValue* stats,
                              const BpfMapRO<uint32_t, StatsValue>& ifaceStatsMap,
-                             const BpfMapRO<uint32_t, IfaceValue>& ifaceNameMap) {
+                             const IfIndexToNameFunc ifindex2name) {
     *stats = {};
     int64_t unknownIfaceBytesTotal = 0;
     const auto processIfaceStats =
-            [iface, stats, &ifaceNameMap, &unknownIfaceBytesTotal](
+            [iface, stats, ifindex2name, &unknownIfaceBytesTotal](
                     const uint32_t& key,
                     const BpfMapRO<uint32_t, StatsValue>& ifaceStatsMap) -> Result<void> {
-        Result<IfaceValue> ifname = ifaceNameMap.readValue(key);
+        Result<IfaceValue> ifname = ifindex2name(key);
         if (!ifname.ok()) {
             maybeLogUnknownIface(key, ifaceStatsMap, key, &unknownIfaceBytesTotal);
             return Result<void>();
@@ -104,7 +113,7 @@
 }
 
 int bpfGetIfaceStats(const char* iface, StatsValue* stats) {
-    return bpfGetIfaceStatsInternal(iface, stats, getIfaceStatsMap(), getIfaceIndexNameMap());
+    return bpfGetIfaceStatsInternal(iface, stats, getIfaceStatsMap(), ifindex2name);
 }
 
 int bpfGetIfIndexStatsInternal(uint32_t ifindex, StatsValue* stats,
@@ -138,13 +147,13 @@
 
 int parseBpfNetworkStatsDetailInternal(std::vector<stats_line>& lines,
                                        const BpfMapRO<StatsKey, StatsValue>& statsMap,
-                                       const BpfMapRO<uint32_t, IfaceValue>& ifaceMap) {
+                                       const IfIndexToNameFunc ifindex2name) {
     int64_t unknownIfaceBytesTotal = 0;
     const auto processDetailUidStats =
-            [&lines, &unknownIfaceBytesTotal, &ifaceMap](
+            [&lines, &unknownIfaceBytesTotal, &ifindex2name](
                     const StatsKey& key,
                     const BpfMapRO<StatsKey, StatsValue>& statsMap) -> Result<void> {
-        Result<IfaceValue> ifname = ifaceMap.readValue(key.ifaceIndex);
+        Result<IfaceValue> ifname = ifindex2name(key.ifaceIndex);
         if (!ifname.ok()) {
             maybeLogUnknownIface(key.ifaceIndex, statsMap, key, &unknownIfaceBytesTotal);
             return Result<void>();
@@ -212,7 +221,7 @@
     // TODO: the above comment feels like it may be obsolete / out of date,
     // since we no longer swap the map via netd binder rpc - though we do
     // still swap it.
-    int ret = parseBpfNetworkStatsDetailInternal(*lines, *inactiveStatsMap, getIfaceIndexNameMap());
+    int ret = parseBpfNetworkStatsDetailInternal(*lines, *inactiveStatsMap, ifindex2name);
     if (ret) {
         ALOGE("parse detail network stats failed: %s", strerror(errno));
         return ret;
@@ -229,12 +238,12 @@
 
 int parseBpfNetworkStatsDevInternal(std::vector<stats_line>& lines,
                                     const BpfMapRO<uint32_t, StatsValue>& statsMap,
-                                    const BpfMapRO<uint32_t, IfaceValue>& ifaceMap) {
+                                    const IfIndexToNameFunc ifindex2name) {
     int64_t unknownIfaceBytesTotal = 0;
-    const auto processDetailIfaceStats = [&lines, &unknownIfaceBytesTotal, &ifaceMap, &statsMap](
+    const auto processDetailIfaceStats = [&lines, &unknownIfaceBytesTotal, ifindex2name, &statsMap](
                                              const uint32_t& key, const StatsValue& value,
                                              const BpfMapRO<uint32_t, StatsValue>&) {
-        Result<IfaceValue> ifname = ifaceMap.readValue(key);
+        Result<IfaceValue> ifname = ifindex2name(key);
         if (!ifname.ok()) {
             maybeLogUnknownIface(key, statsMap, key, &unknownIfaceBytesTotal);
             return Result<void>();
@@ -259,7 +268,7 @@
 }
 
 int parseBpfNetworkStatsDev(std::vector<stats_line>* lines) {
-    return parseBpfNetworkStatsDevInternal(*lines, getIfaceStatsMap(), getIfaceIndexNameMap());
+    return parseBpfNetworkStatsDevInternal(*lines, getIfaceStatsMap(), ifindex2name);
 }
 
 void groupNetworkStats(std::vector<stats_line>& lines) {
diff --git a/service-t/native/libs/libnetworkstats/BpfNetworkStatsTest.cpp b/service-t/native/libs/libnetworkstats/BpfNetworkStatsTest.cpp
index 57822fc..484c166 100644
--- a/service-t/native/libs/libnetworkstats/BpfNetworkStatsTest.cpp
+++ b/service-t/native/libs/libnetworkstats/BpfNetworkStatsTest.cpp
@@ -77,6 +77,10 @@
     BpfMap<uint32_t, IfaceValue> mFakeIfaceIndexNameMap;
     BpfMap<uint32_t, StatsValue> mFakeIfaceStatsMap;
 
+    IfIndexToNameFunc mIfIndex2Name = [this](const uint32_t ifindex){
+        return mFakeIfaceIndexNameMap.readValue(ifindex);
+    };
+
     void SetUp() {
         ASSERT_EQ(0, setrlimitForTest());
 
@@ -228,7 +232,7 @@
     populateFakeStats(TEST_UID1, 0, IFACE_INDEX1, TEST_COUNTERSET0, value1, mFakeStatsMap);
     populateFakeStats(TEST_UID1, 0, IFACE_INDEX2, TEST_COUNTERSET1, value1, mFakeStatsMap);
     populateFakeStats(TEST_UID2, 0, IFACE_INDEX3, TEST_COUNTERSET1, value1, mFakeStatsMap);
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((unsigned long)3, lines.size());
 }
 
@@ -256,16 +260,15 @@
     EXPECT_RESULT_OK(mFakeIfaceStatsMap.writeValue(ifaceStatsKey, value1, BPF_ANY));
 
     StatsValue result1 = {};
-    ASSERT_EQ(0, bpfGetIfaceStatsInternal(IFACE_NAME1, &result1, mFakeIfaceStatsMap,
-                                          mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0,
+              bpfGetIfaceStatsInternal(IFACE_NAME1, &result1, mFakeIfaceStatsMap, mIfIndex2Name));
     expectStatsEqual(value1, result1);
     StatsValue result2 = {};
-    ASSERT_EQ(0, bpfGetIfaceStatsInternal(IFACE_NAME2, &result2, mFakeIfaceStatsMap,
-                                          mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0,
+              bpfGetIfaceStatsInternal(IFACE_NAME2, &result2, mFakeIfaceStatsMap, mIfIndex2Name));
     expectStatsEqual(value2, result2);
     StatsValue totalResult = {};
-    ASSERT_EQ(0, bpfGetIfaceStatsInternal(NULL, &totalResult, mFakeIfaceStatsMap,
-                                          mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, bpfGetIfaceStatsInternal(NULL, &totalResult, mFakeIfaceStatsMap, mIfIndex2Name));
     StatsValue totalValue = {
             .rxPackets = TEST_PACKET0 * 2 + TEST_PACKET1,
             .rxBytes = TEST_BYTES0 * 2 + TEST_BYTES1,
@@ -304,7 +307,7 @@
                       mFakeStatsMap);
     populateFakeStats(TEST_UID2, TEST_TAG, IFACE_INDEX1, TEST_COUNTERSET0, value1, mFakeStatsMap);
     std::vector<stats_line> lines;
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((unsigned long)7, lines.size());
 }
 
@@ -324,7 +327,7 @@
     populateFakeStats(TEST_UID1, 0, IFACE_INDEX1, TEST_COUNTERSET1, value1, mFakeStatsMap);
     populateFakeStats(TEST_UID2, 0, IFACE_INDEX1, TEST_COUNTERSET0, value1, mFakeStatsMap);
     std::vector<stats_line> lines;
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((unsigned long)4, lines.size());
 }
 
@@ -365,7 +368,7 @@
     ASSERT_EQ(-1, unknownIfaceBytesTotal);
     std::vector<stats_line> lines;
     // TODO: find a way to test the total of unknown Iface Bytes go above limit.
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((unsigned long)1, lines.size());
     expectStatsLineEqual(value1, IFACE_NAME1, TEST_UID1, TEST_COUNTERSET0, 0, lines.front());
 }
@@ -396,8 +399,7 @@
     ifaceStatsKey = IFACE_INDEX4;
     EXPECT_RESULT_OK(mFakeIfaceStatsMap.writeValue(ifaceStatsKey, value2, BPF_ANY));
     std::vector<stats_line> lines;
-    ASSERT_EQ(0,
-              parseBpfNetworkStatsDevInternal(lines, mFakeIfaceStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDevInternal(lines, mFakeIfaceStatsMap, mIfIndex2Name));
     ASSERT_EQ((unsigned long)4, lines.size());
 
     expectStatsLineEqual(value1, IFACE_NAME1, UID_ALL, SET_ALL, TAG_NONE, lines[0]);
@@ -441,13 +443,13 @@
     std::vector<stats_line> lines;
 
     // Test empty stats.
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 0, lines.size());
     lines.clear();
 
     // Test 1 line stats.
     populateFakeStats(TEST_UID1, TEST_TAG, IFACE_INDEX1, TEST_COUNTERSET0, value1, mFakeStatsMap);
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 2, lines.size());  // TEST_TAG != 0 -> 1 entry becomes 2 lines
     expectStatsLineEqual(value1, IFACE_NAME1, TEST_UID1, TEST_COUNTERSET0, 0, lines[0]);
     expectStatsLineEqual(value1, IFACE_NAME1, TEST_UID1, TEST_COUNTERSET0, TEST_TAG, lines[1]);
@@ -459,7 +461,7 @@
     populateFakeStats(TEST_UID1, TEST_TAG + 1, IFACE_INDEX1, TEST_COUNTERSET0, value2,
                       mFakeStatsMap);
     populateFakeStats(TEST_UID2, TEST_TAG, IFACE_INDEX1, TEST_COUNTERSET0, value1, mFakeStatsMap);
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 9, lines.size());
     lines.clear();
 
@@ -467,7 +469,7 @@
     populateFakeStats(TEST_UID1, TEST_TAG, IFACE_INDEX3, TEST_COUNTERSET0, value1, mFakeStatsMap);
     populateFakeStats(TEST_UID2, TEST_TAG, IFACE_INDEX3, TEST_COUNTERSET0, value1, mFakeStatsMap);
 
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 9, lines.size());
 
     // Verify Sorted & Grouped.
@@ -492,8 +494,7 @@
     ifaceStatsKey = IFACE_INDEX3;
     EXPECT_RESULT_OK(mFakeIfaceStatsMap.writeValue(ifaceStatsKey, value1, BPF_ANY));
 
-    ASSERT_EQ(0,
-              parseBpfNetworkStatsDevInternal(lines, mFakeIfaceStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDevInternal(lines, mFakeIfaceStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 2, lines.size());
 
     expectStatsLineEqual(value3, IFACE_NAME1, UID_ALL, SET_ALL, TAG_NONE, lines[0]);
@@ -534,7 +535,7 @@
     // TODO: Mutate counterSet and enlarge TEST_MAP_SIZE if overflow on counterSet is possible.
 
     std::vector<stats_line> lines;
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 12, lines.size());
 
     // Uid 0 first
diff --git a/service-t/native/libs/libnetworkstats/include/netdbpf/BpfNetworkStats.h b/service-t/native/libs/libnetworkstats/include/netdbpf/BpfNetworkStats.h
index 173dee4..59eb195 100644
--- a/service-t/native/libs/libnetworkstats/include/netdbpf/BpfNetworkStats.h
+++ b/service-t/native/libs/libnetworkstats/include/netdbpf/BpfNetworkStats.h
@@ -55,20 +55,25 @@
 bool operator==(const stats_line& lhs, const stats_line& rhs);
 bool operator<(const stats_line& lhs, const stats_line& rhs);
 
+// This mirrors BpfMap.h's:
+//   Result<Value> readValue(const Key key) const
+// for a BpfMap<uint32_t, IfaceValue>
+using IfIndexToNameFunc = std::function<Result<IfaceValue>(const uint32_t)>;
+
 // For test only
 int bpfGetUidStatsInternal(uid_t uid, StatsValue* stats,
                            const BpfMapRO<uint32_t, StatsValue>& appUidStatsMap);
 // For test only
 int bpfGetIfaceStatsInternal(const char* iface, StatsValue* stats,
                              const BpfMapRO<uint32_t, StatsValue>& ifaceStatsMap,
-                             const BpfMapRO<uint32_t, IfaceValue>& ifaceNameMap);
+                             const IfIndexToNameFunc ifindex2name);
 // For test only
 int bpfGetIfIndexStatsInternal(uint32_t ifindex, StatsValue* stats,
                                const BpfMapRO<uint32_t, StatsValue>& ifaceStatsMap);
 // For test only
 int parseBpfNetworkStatsDetailInternal(std::vector<stats_line>& lines,
                                        const BpfMapRO<StatsKey, StatsValue>& statsMap,
-                                       const BpfMapRO<uint32_t, IfaceValue>& ifaceMap);
+                                       const IfIndexToNameFunc ifindex2name);
 // For test only
 int cleanStatsMapInternal(const base::unique_fd& cookieTagMap, const base::unique_fd& tagStatsMap);
 
@@ -98,7 +103,7 @@
 // For test only
 int parseBpfNetworkStatsDevInternal(std::vector<stats_line>& lines,
                                     const BpfMapRO<uint32_t, StatsValue>& statsMap,
-                                    const BpfMapRO<uint32_t, IfaceValue>& ifaceMap);
+                                    const IfIndexToNameFunc ifindex2name);
 
 void bpfRegisterIface(const char* iface);
 int bpfGetUidStats(uid_t uid, StatsValue* stats);
diff --git a/service-t/src/com/android/server/net/BpfInterfaceMapHelper.java b/service-t/src/com/android/server/net/BpfInterfaceMapHelper.java
new file mode 100644
index 0000000..3c95b8e
--- /dev/null
+++ b/service-t/src/com/android/server/net/BpfInterfaceMapHelper.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.server.net;
+
+import android.os.Build;
+import android.system.ErrnoException;
+import android.util.IndentingPrintWriter;
+import android.util.Log;
+
+import androidx.annotation.RequiresApi;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.BpfDump;
+import com.android.net.module.util.BpfMap;
+import com.android.net.module.util.IBpfMap;
+import com.android.net.module.util.Struct.S32;
+
+/**
+ * Read-only helper for the BPF map that maps interface index to interface name.
+ *
+ * <p>This class does not populate the map; it only looks up entries and dumps the map.
+ */
+@RequiresApi(Build.VERSION_CODES.TIRAMISU)
+public class BpfInterfaceMapHelper {
+    private static final String TAG = BpfInterfaceMapHelper.class.getSimpleName();
+    // This is current path but may be changed soon.
+    private static final String IFACE_INDEX_NAME_MAP_PATH =
+            "/sys/fs/bpf/netd_shared/map_netd_iface_index_name_map";
+    // May be null if the map could not be opened; see Dependencies#getInterfaceMap.
+    private final IBpfMap<S32, InterfaceMapValue> mIndexToIfaceBpfMap;
+
+    public BpfInterfaceMapHelper() {
+        this(new Dependencies());
+    }
+
+    @VisibleForTesting
+    public BpfInterfaceMapHelper(Dependencies deps) {
+        mIndexToIfaceBpfMap = deps.getInterfaceMap();
+    }
+
+    /**
+     * Dependencies of BpfInterfaceMapHelper, for injection in tests.
+     */
+    @VisibleForTesting
+    public static class Dependencies {
+        /**
+         * Open the BpfMap that maps interface index to interface name.
+         *
+         * @return the map, or null if it could not be opened.
+         */
+        public IBpfMap<S32, InterfaceMapValue> getInterfaceMap() {
+            try {
+                return new BpfMap<>(IFACE_INDEX_NAME_MAP_PATH,
+                    S32.class, InterfaceMapValue.class);
+            } catch (ErrnoException e) {
+                Log.e(TAG, "Cannot create interface map: " + e);
+                return null;
+            }
+        }
+    }
+
+    /**
+     * Get the interface name for an interface index from the bpf map.
+     *
+     * @param index the interface index to look up.
+     * @return the interface name, or null if the map is unavailable, has no entry for
+     *         {@code index}, or the lookup fails.
+     */
+    public String getIfNameByIndex(final int index) {
+        if (mIndexToIfaceBpfMap == null) {
+            // The map could not be opened at construction time; return null instead of
+            // throwing an NPE, since callers (e.g. stats map dumps) already handle null.
+            Log.e(TAG, "Cannot get if name for index " + index + ": null bpf map");
+            return null;
+        }
+        try {
+            final InterfaceMapValue value = mIndexToIfaceBpfMap.getValue(new S32(index));
+            if (value == null) {
+                Log.e(TAG, "No if name entry for index " + index);
+                return null;
+            }
+            return value.getInterfaceNameString();
+        } catch (ErrnoException e) {
+            Log.e(TAG, "Failed to get entry for index " + index + ": " + e);
+            return null;
+        }
+    }
+
+    /**
+     * Dump BPF map
+     *
+     * @param pw print writer
+     */
+    public void dump(final IndentingPrintWriter pw) {
+        pw.println("BPF map status:");
+        pw.increaseIndent();
+        BpfDump.dumpMapStatus(mIndexToIfaceBpfMap, pw, "IfaceIndexNameMap",
+                IFACE_INDEX_NAME_MAP_PATH);
+        pw.decreaseIndent();
+        pw.println("BPF map content:");
+        pw.increaseIndent();
+        BpfDump.dumpMap(mIndexToIfaceBpfMap, pw, "IfaceIndexNameMap",
+                (key, value) -> "ifaceIndex=" + key.val
+                        + " ifaceName=" + value.getInterfaceNameString());
+        pw.decreaseIndent();
+    }
+}
diff --git a/service-t/src/com/android/server/net/BpfInterfaceMapUpdater.java b/service-t/src/com/android/server/net/BpfInterfaceMapUpdater.java
deleted file mode 100644
index 59de2c4..0000000
--- a/service-t/src/com/android/server/net/BpfInterfaceMapUpdater.java
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Copyright (C) 2022 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.android.server.net;
-
-import android.os.Build;
-import android.content.Context;
-import android.net.INetd;
-import android.os.Handler;
-import android.os.IBinder;
-import android.os.RemoteException;
-import android.os.ServiceSpecificException;
-import android.system.ErrnoException;
-import android.util.IndentingPrintWriter;
-import android.util.Log;
-
-import androidx.annotation.RequiresApi;
-
-import com.android.internal.annotations.VisibleForTesting;
-import com.android.net.module.util.BaseNetdUnsolicitedEventListener;
-import com.android.net.module.util.BpfDump;
-import com.android.net.module.util.BpfMap;
-import com.android.net.module.util.IBpfMap;
-import com.android.net.module.util.InterfaceParams;
-import com.android.net.module.util.Struct.S32;
-
-/**
- * Monitor interface added (without removed) and right interface name and its index to bpf map.
- */
-@RequiresApi(Build.VERSION_CODES.TIRAMISU)
-public class BpfInterfaceMapUpdater {
-    private static final String TAG = BpfInterfaceMapUpdater.class.getSimpleName();
-    // This is current path but may be changed soon.
-    private static final String IFACE_INDEX_NAME_MAP_PATH =
-            "/sys/fs/bpf/netd_shared/map_netd_iface_index_name_map";
-    private final IBpfMap<S32, InterfaceMapValue> mIndexToIfaceBpfMap;
-    private final INetd mNetd;
-    private final Handler mHandler;
-    private final Dependencies mDeps;
-
-    public BpfInterfaceMapUpdater(Context ctx, Handler handler) {
-        this(ctx, handler, new Dependencies());
-    }
-
-    @VisibleForTesting
-    public BpfInterfaceMapUpdater(Context ctx, Handler handler, Dependencies deps) {
-        mDeps = deps;
-        mIndexToIfaceBpfMap = deps.getInterfaceMap();
-        mNetd = deps.getINetd(ctx);
-        mHandler = handler;
-    }
-
-    /**
-     * Dependencies of BpfInerfaceMapUpdater, for injection in tests.
-     */
-    @VisibleForTesting
-    public static class Dependencies {
-        /** Create BpfMap for updating interface and index mapping. */
-        public IBpfMap<S32, InterfaceMapValue> getInterfaceMap() {
-            try {
-                return new BpfMap<>(IFACE_INDEX_NAME_MAP_PATH,
-                    S32.class, InterfaceMapValue.class);
-            } catch (ErrnoException e) {
-                Log.e(TAG, "Cannot create interface map: " + e);
-                return null;
-            }
-        }
-
-        /** Get InterfaceParams for giving interface name. */
-        public InterfaceParams getInterfaceParams(String ifaceName) {
-            return InterfaceParams.getByName(ifaceName);
-        }
-
-        /** Get INetd binder object. */
-        public INetd getINetd(Context ctx) {
-            return INetd.Stub.asInterface((IBinder) ctx.getSystemService(Context.NETD_SERVICE));
-        }
-    }
-
-    /**
-     * Start listening interface update event.
-     * Query current interface names before listening.
-     */
-    public void start() {
-        mHandler.post(() -> {
-            if (mIndexToIfaceBpfMap == null) {
-                Log.wtf(TAG, "Fail to start: Null bpf map");
-                return;
-            }
-
-            try {
-                // TODO: use a NetlinkMonitor and listen for RTM_NEWLINK messages instead.
-                mNetd.registerUnsolicitedEventListener(new InterfaceChangeObserver());
-            } catch (RemoteException e) {
-                Log.wtf(TAG, "Unable to register netd UnsolicitedEventListener, " + e);
-            }
-
-            final String[] ifaces;
-            try {
-                // TODO: use a netlink dump to get the current interface list.
-                ifaces = mNetd.interfaceGetList();
-            } catch (RemoteException | ServiceSpecificException e) {
-                Log.wtf(TAG, "Unable to query interface names by netd, " + e);
-                return;
-            }
-
-            for (String ifaceName : ifaces) {
-                addInterface(ifaceName);
-            }
-        });
-    }
-
-    private void addInterface(String ifaceName) {
-        final InterfaceParams iface = mDeps.getInterfaceParams(ifaceName);
-        if (iface == null) {
-            Log.e(TAG, "Unable to get InterfaceParams for " + ifaceName);
-            return;
-        }
-
-        try {
-            mIndexToIfaceBpfMap.updateEntry(new S32(iface.index), new InterfaceMapValue(ifaceName));
-        } catch (ErrnoException e) {
-            Log.e(TAG, "Unable to update entry for " + ifaceName + ", " + e);
-        }
-    }
-
-    private class InterfaceChangeObserver extends BaseNetdUnsolicitedEventListener {
-        @Override
-        public void onInterfaceAdded(String ifName) {
-            mHandler.post(() -> addInterface(ifName));
-        }
-    }
-
-    /** get interface name by interface index from bpf map */
-    public String getIfNameByIndex(final int index) {
-        try {
-            final InterfaceMapValue value = mIndexToIfaceBpfMap.getValue(new S32(index));
-            if (value == null) {
-                Log.e(TAG, "No if name entry for index " + index);
-                return null;
-            }
-            return value.getInterfaceNameString();
-        } catch (ErrnoException e) {
-            Log.e(TAG, "Failed to get entry for index " + index + ": " + e);
-            return null;
-        }
-    }
-
-    /**
-     * Dump BPF map
-     *
-     * @param pw print writer
-     */
-    public void dump(final IndentingPrintWriter pw) {
-        pw.println("BPF map status:");
-        pw.increaseIndent();
-        BpfDump.dumpMapStatus(mIndexToIfaceBpfMap, pw, "IfaceIndexNameMap",
-                IFACE_INDEX_NAME_MAP_PATH);
-        pw.decreaseIndent();
-        pw.println("BPF map content:");
-        pw.increaseIndent();
-        BpfDump.dumpMap(mIndexToIfaceBpfMap, pw, "IfaceIndexNameMap",
-                (key, value) -> "ifaceIndex=" + key.val
-                        + " ifaceName=" + value.getInterfaceNameString());
-        pw.decreaseIndent();
-    }
-}
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 7b24315..ec10158 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -476,7 +476,7 @@
     private final LocationPermissionChecker mLocationPermissionChecker;
 
     @NonNull
-    private final BpfInterfaceMapUpdater mInterfaceMapUpdater;
+    private final BpfInterfaceMapHelper mInterfaceMapHelper;
 
     @Nullable
     private final SkDestroyListener mSkDestroyListener;
@@ -628,8 +628,7 @@
         mContentObserver = mDeps.makeContentObserver(mHandler, mSettings,
                 mNetworkStatsSubscriptionsMonitor);
         mLocationPermissionChecker = mDeps.makeLocationPermissionChecker(mContext);
-        mInterfaceMapUpdater = mDeps.makeBpfInterfaceMapUpdater(mContext, mHandler);
-        mInterfaceMapUpdater.start();
+        mInterfaceMapHelper = mDeps.makeBpfInterfaceMapHelper();
         mUidCounterSetMap = mDeps.getUidCounterSetMap();
         mCookieTagMap = mDeps.getCookieTagMap();
         mStatsMapA = mDeps.getStatsMapA();
@@ -798,11 +797,10 @@
             return new LocationPermissionChecker(context);
         }
 
-        /** Create BpfInterfaceMapUpdater to update bpf interface map. */
+        /** Create BpfInterfaceMapHelper to read the bpf interface map. */
         @NonNull
-        public BpfInterfaceMapUpdater makeBpfInterfaceMapUpdater(
-                @NonNull Context ctx, @NonNull Handler handler) {
-            return new BpfInterfaceMapUpdater(ctx, handler);
+        public BpfInterfaceMapHelper makeBpfInterfaceMapHelper() {
+            return new BpfInterfaceMapHelper();
         }
 
         /** Get counter sets map for each UID. */
@@ -2889,9 +2887,9 @@
             }
 
             pw.println();
-            pw.println("InterfaceMapUpdater:");
+            pw.println("InterfaceMapHelper:");
             pw.increaseIndent();
-            mInterfaceMapUpdater.dump(pw);
+            mInterfaceMapHelper.dump(pw);
             pw.decreaseIndent();
 
             pw.println();
@@ -3038,7 +3036,7 @@
         BpfDump.dumpMap(statsMap, pw, mapName,
                 "ifaceIndex ifaceName tag_hex uid_int cnt_set rxBytes rxPackets txBytes txPackets",
                 (key, value) -> {
-                    final String ifName = mInterfaceMapUpdater.getIfNameByIndex(key.ifaceIndex);
+                    final String ifName = mInterfaceMapHelper.getIfNameByIndex(key.ifaceIndex);
                     return key.ifaceIndex + " "
                             + (ifName != null ? ifName : "unknown") + " "
                             + "0x" + Long.toHexString(key.tag) + " "
@@ -3056,7 +3054,7 @@
         BpfDump.dumpMap(mIfaceStatsMap, pw, "mIfaceStatsMap",
                 "ifaceIndex ifaceName rxBytes rxPackets txBytes txPackets",
                 (key, value) -> {
-                    final String ifName = mInterfaceMapUpdater.getIfNameByIndex(key.val);
+                    final String ifName = mInterfaceMapHelper.getIfNameByIndex(key.val);
                     return key.val + " "
                             + (ifName != null ? ifName : "unknown") + " "
                             + value.rxBytes + " "
diff --git a/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java b/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java
index 7c2be2c..f735e46 100644
--- a/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java
+++ b/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java
@@ -29,6 +29,7 @@
 import static android.system.OsConstants.SO_RCVBUF;
 import static android.system.OsConstants.SO_RCVTIMEO;
 import static android.system.OsConstants.SO_SNDTIMEO;
+
 import static com.android.net.module.util.netlink.NetlinkConstants.hexify;
 import static com.android.net.module.util.netlink.NetlinkConstants.NLMSG_DONE;
 import static com.android.net.module.util.netlink.NetlinkConstants.RTNL_FAMILY_IP6MR;
@@ -49,6 +50,7 @@
 import java.io.IOException;
 import java.io.InterruptedIOException;
 import java.net.Inet6Address;
+import java.net.InetAddress;
 import java.net.SocketException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
@@ -56,7 +58,6 @@
 import java.util.List;
 import java.util.Objects;
 import java.util.function.Consumer;
-import java.util.stream.Collectors;
 
 /**
  * Utilities for netlink related class that may not be able to fit into a specific class.
@@ -177,19 +178,19 @@
     }
 
     /**
-     * Send an RTM_NEWADDR message to kernel to add or update an IPv6 address.
+     * Send an RTM_NEWADDR message to kernel to add or update an IP address.
      *
      * @param ifIndex interface index.
-     * @param ip IPv6 address to be added.
-     * @param prefixlen IPv6 address prefix length.
-     * @param flags IPv6 address flags.
-     * @param scope IPv6 address scope.
-     * @param preferred The preferred lifetime of IPv6 address.
-     * @param valid The valid lifetime of IPv6 address.
+     * @param ip IP address to be added.
+     * @param prefixlen IP address prefix length.
+     * @param flags IP address flags.
+     * @param scope IP address scope.
+     * @param preferred The preferred lifetime of IP address.
+     * @param valid The valid lifetime of IP address.
      */
-    public static boolean sendRtmNewAddressRequest(int ifIndex, @NonNull final Inet6Address ip,
+    public static boolean sendRtmNewAddressRequest(int ifIndex, @NonNull final InetAddress ip,
             short prefixlen, int flags, byte scope, long preferred, long valid) {
-        Objects.requireNonNull(ip, "IPv6 address to be added should not be null.");
+        Objects.requireNonNull(ip, "IP address to be added should not be null.");
         final byte[] msg = RtNetlinkAddressMessage.newRtmNewAddressMessage(1 /* seqNo*/, ip,
                 prefixlen, flags, scope, ifIndex, preferred, valid);
         try {
diff --git a/staticlibs/device/com/android/net/module/util/netlink/RtNetlinkAddressMessage.java b/staticlibs/device/com/android/net/module/util/netlink/RtNetlinkAddressMessage.java
index cbe0ab0..4846df7 100644
--- a/staticlibs/device/com/android/net/module/util/netlink/RtNetlinkAddressMessage.java
+++ b/staticlibs/device/com/android/net/module/util/netlink/RtNetlinkAddressMessage.java
@@ -16,6 +16,7 @@
 
 package com.android.net.module.util.netlink;
 
+import static com.android.net.module.util.Inet4AddressUtils.getBroadcastAddress;
 import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_ACK;
 import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REPLACE;
 import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REQUEST;
@@ -28,6 +29,7 @@
 
 import com.android.net.module.util.HexDump;
 
+import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.nio.ByteBuffer;
@@ -48,6 +50,8 @@
  */
 public class RtNetlinkAddressMessage extends NetlinkMessage {
     public static final short IFA_ADDRESS        = 1;
+    public static final short IFA_LOCAL          = 2;
+    public static final short IFA_BROADCAST      = 4;
     public static final short IFA_CACHEINFO      = 6;
     public static final short IFA_FLAGS          = 8;
 
@@ -71,6 +75,7 @@
         mIfacacheInfo = structIfacacheInfo;
         mFlags = flags;
     }
+
     private RtNetlinkAddressMessage(@NonNull StructNlMsgHdr header) {
         this(header, null, null, null, 0);
     }
@@ -158,6 +163,24 @@
         // still be packed to ByteBuffer even if the flag is 0.
         final StructNlAttr flags = new StructNlAttr(IFA_FLAGS, mFlags);
         flags.pack(byteBuffer);
+
+        // Add the required IFA_LOCAL and IFA_BROADCAST attributes for IPv4 addresses. The IFA_LOCAL
+        // attribute represents the local address, which is equivalent to IFA_ADDRESS on a normally
+        // configured broadcast interface; however, for PPP interfaces, IFA_ADDRESS indicates the
+        // destination address and the local address is provided in the IFA_LOCAL attribute. If the
+        // IFA_LOCAL attribute is not present in the RTM_NEWADDR message, the kernel replies with an
+        // error netlink message with invalid parameters. IFA_BROADCAST is also required, otherwise
+        // the broadcast on the interface is 0.0.0.0. See include/uapi/linux/if_addr.h for details.
+        // For IPv6 addresses, the IFA_ADDRESS attribute applies and introduces no ambiguity.
+        if (mIpAddress instanceof Inet4Address) {
+            final StructNlAttr localAddress = new StructNlAttr(IFA_LOCAL, mIpAddress);
+            localAddress.pack(byteBuffer);
+
+            final Inet4Address broadcast =
+                    getBroadcastAddress((Inet4Address) mIpAddress, mIfaddrmsg.prefixLen);
+            final StructNlAttr broadcastAddress = new StructNlAttr(IFA_BROADCAST, broadcast);
+            broadcastAddress.pack(byteBuffer);
+        }
     }
 
     /**
@@ -184,7 +207,7 @@
                 0 /* tstamp */);
         msg.mFlags = flags;
 
-        final byte[] bytes = new byte[msg.getRequiredSpace()];
+        final byte[] bytes = new byte[msg.getRequiredSpace(family)];
         nlmsghdr.nlmsg_len = bytes.length;
         final ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
         byteBuffer.order(ByteOrder.nativeOrder());
@@ -237,7 +260,7 @@
     // RtNetlinkAddressMessage, e.g. RTM_DELADDR sent from user space to kernel to delete an
     // IP address only requires IFA_ADDRESS attribute. The caller should check if these attributes
     // are necessary to carry when constructing a RtNetlinkAddressMessage.
-    private int getRequiredSpace() {
+    private int getRequiredSpace(int family) {
         int spaceRequired = StructNlMsgHdr.STRUCT_SIZE + StructIfaddrMsg.STRUCT_SIZE;
         // IFA_ADDRESS attr
         spaceRequired += NetlinkConstants.alignedLengthOf(
@@ -247,6 +270,14 @@
                 StructNlAttr.NLA_HEADERLEN + StructIfacacheInfo.STRUCT_SIZE);
         // IFA_FLAGS "u32" attr
         spaceRequired += StructNlAttr.NLA_HEADERLEN + 4;
+        if (family == OsConstants.AF_INET) {
+            // IFA_LOCAL attr
+            spaceRequired += NetlinkConstants.alignedLengthOf(
+                    StructNlAttr.NLA_HEADERLEN + mIpAddress.getAddress().length);
+            // IFA_BROADCAST attr
+            spaceRequired += NetlinkConstants.alignedLengthOf(
+                    StructNlAttr.NLA_HEADERLEN + mIpAddress.getAddress().length);
+        }
         return spaceRequired;
     }
 
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/netlink/RtNetlinkAddressMessageTest.java b/staticlibs/tests/unit/src/com/android/net/module/util/netlink/RtNetlinkAddressMessageTest.java
index 01126d2..1d08525 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/netlink/RtNetlinkAddressMessageTest.java
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/netlink/RtNetlinkAddressMessageTest.java
@@ -42,6 +42,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
@@ -179,6 +180,57 @@
     }
 
     @Test
+    public void testCreateRtmNewAddressMessage_IPv4Address() {
+        // Hexadecimal representation of our created packet.
+        final String expectedNewAddressHex =
+                // struct nlmsghdr
+                "4c000000"      // length = 76
+                + "1400"        // type = 20 (RTM_NEWADDR)
+                + "0501"        // flags = NLM_F_ACK | NLM_F_REQUEST | NLM_F_REPLACE
+                + "01000000"    // seqno = 1
+                + "00000000"    // pid = 0 (send to kernel)
+                // struct IfaddrMsg
+                + "02"          // family = inet
+                + "18"          // prefix len = 24
+                + "00"          // flags = 0
+                + "00"          // scope = RT_SCOPE_UNIVERSE
+                + "14000000"    // ifindex = 20
+                // struct nlattr: IFA_ADDRESS
+                + "0800"        // len
+                + "0100"        // type
+                + "C0A80491"    // IPv4 address = 192.168.4.145
+                // struct nlattr: IFA_CACHEINFO
+                + "1400"        // len
+                + "0600"        // type
+                + "C0A80000"    // preferred = 43200s
+                + "C0A80000"    // valid = 43200s
+                + "00000000"    // cstamp
+                + "00000000"    // tstamp
+                // struct nlattr: IFA_FLAGS
+                + "0800"        // len
+                + "0800"        // type
+                + "00000000"    // flags = 0
+                // struct nlattr: IFA_LOCAL
+                + "0800"        // len
+                + "0200"        // type
+                + "C0A80491"    // local address = 192.168.4.145
+                // struct nlattr: IFA_BROADCAST
+                + "0800"        // len
+                + "0400"        // type
+                + "C0A804FF";   // broadcast address = 192.168.4.255
+        final byte[] expectedNewAddress =
+                HexEncoding.decode(expectedNewAddressHex.toCharArray(), false);
+
+        final Inet4Address ipAddress =
+                (Inet4Address) InetAddresses.parseNumericAddress("192.168.4.145");
+        final byte[] bytes = RtNetlinkAddressMessage.newRtmNewAddressMessage(1 /* seqno */,
+                ipAddress, (short) 24 /* prefix len */, 0 /* flags */,
+                (byte) RT_SCOPE_UNIVERSE /* scope */, 20 /* ifindex */,
+                (long) 0xA8C0 /* preferred */, (long) 0xA8C0 /* valid */);
+        assertArrayEquals(expectedNewAddress, bytes);
+    }
+
+    @Test
     public void testCreateRtmDelAddressMessage() {
         // Hexadecimal representation of our created packet.
         final String expectedDelAddressHex =
diff --git a/tests/unit/java/com/android/server/net/BpfInterfaceMapHelperTest.java b/tests/unit/java/com/android/server/net/BpfInterfaceMapHelperTest.java
new file mode 100644
index 0000000..7b3bea3
--- /dev/null
+++ b/tests/unit/java/com/android/server/net/BpfInterfaceMapHelperTest.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.net;
+
+import static android.system.OsConstants.EPERM;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.spy;
+
+import android.os.Build;
+import android.system.ErrnoException;
+import android.util.IndentingPrintWriter;
+
+import androidx.test.filters.SmallTest;
+
+import com.android.net.module.util.IBpfMap;
+import com.android.net.module.util.Struct.S32;
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.TestBpfMap;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.io.PrintWriter;
+import java.io.StringWriter;
+
+/** Unit tests for {@link BpfInterfaceMapHelper}. */
+@SmallTest
+@RunWith(DevSdkIgnoreRunner.class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+public final class BpfInterfaceMapHelperTest {
+    private static final int TEST_INDEX = 1;
+    private static final int TEST_INDEX2 = 2;
+    private static final String TEST_INTERFACE_NAME = "test1";
+    private static final String TEST_INTERFACE_NAME2 = "test2";
+
+    private BpfInterfaceMapHelper mHelper;
+    private final IBpfMap<S32, InterfaceMapValue> mBpfMap =
+            spy(new TestBpfMap<>(S32.class, InterfaceMapValue.class));
+
+    /** Dependencies that hand the helper the spied {@link TestBpfMap} instead of the real map. */
+    private class TestDependencies extends BpfInterfaceMapHelper.Dependencies {
+        @Override
+        public IBpfMap<S32, InterfaceMapValue> getInterfaceMap() {
+            return mBpfMap;
+        }
+    }
+
+    @Before
+    public void setUp() throws Exception {
+        mHelper = new BpfInterfaceMapHelper(new TestDependencies());
+    }
+
+    @Test
+    public void testGetIfNameByIndex() throws Exception {
+        mBpfMap.updateEntry(new S32(TEST_INDEX), new InterfaceMapValue(TEST_INTERFACE_NAME));
+        assertEquals(TEST_INTERFACE_NAME, mHelper.getIfNameByIndex(TEST_INDEX));
+    }
+
+    @Test
+    public void testGetIfNameByIndexNoEntry() {
+        assertNull(mHelper.getIfNameByIndex(TEST_INDEX));
+    }
+
+    @Test
+    public void testGetIfNameByIndexException() throws Exception {
+        doThrow(new ErrnoException("", EPERM)).when(mBpfMap).getValue(new S32(TEST_INDEX));
+        assertNull(mHelper.getIfNameByIndex(TEST_INDEX));
+    }
+
+    private void assertDumpContains(final String dump, final String message) {
+        assertTrue(String.format("dump(%s) does not contain '%s'", dump, message),
+                dump.contains(message));
+    }
+
+    private String getDump() {
+        final StringWriter sw = new StringWriter();
+        mHelper.dump(new IndentingPrintWriter(new PrintWriter(sw), " "));
+        return sw.toString();
+    }
+
+    @Test
+    public void testDump() throws ErrnoException {
+        mBpfMap.updateEntry(new S32(TEST_INDEX), new InterfaceMapValue(TEST_INTERFACE_NAME));
+        mBpfMap.updateEntry(new S32(TEST_INDEX2), new InterfaceMapValue(TEST_INTERFACE_NAME2));
+
+        final String dump = getDump();
+        assertDumpContains(dump, "IfaceIndexNameMap: OK");
+        assertDumpContains(dump, "ifaceIndex=1 ifaceName=test1");
+        assertDumpContains(dump, "ifaceIndex=2 ifaceName=test2");
+    }
+}
diff --git a/tests/unit/java/com/android/server/net/BpfInterfaceMapUpdaterTest.java b/tests/unit/java/com/android/server/net/BpfInterfaceMapUpdaterTest.java
deleted file mode 100644
index c730856..0000000
--- a/tests/unit/java/com/android/server/net/BpfInterfaceMapUpdaterTest.java
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Copyright (C) 2022 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.server.net;
-
-import static android.system.OsConstants.EPERM;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
-import static org.mockito.Mockito.when;
-
-import android.content.Context;
-import android.net.INetd;
-import android.net.MacAddress;
-import android.os.Build;
-import android.os.Handler;
-import android.os.test.TestLooper;
-import android.system.ErrnoException;
-import android.util.IndentingPrintWriter;
-
-import androidx.test.filters.SmallTest;
-
-import com.android.net.module.util.BaseNetdUnsolicitedEventListener;
-import com.android.net.module.util.IBpfMap;
-import com.android.net.module.util.InterfaceParams;
-import com.android.net.module.util.Struct.S32;
-import com.android.testutils.DevSdkIgnoreRule;
-import com.android.testutils.DevSdkIgnoreRunner;
-import com.android.testutils.TestBpfMap;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.ArgumentCaptor;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.io.PrintWriter;
-import java.io.StringWriter;
-
-@SmallTest
-@RunWith(DevSdkIgnoreRunner.class)
-@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
-public final class BpfInterfaceMapUpdaterTest {
-    private static final int TEST_INDEX = 1;
-    private static final int TEST_INDEX2 = 2;
-    private static final String TEST_INTERFACE_NAME = "test1";
-    private static final String TEST_INTERFACE_NAME2 = "test2";
-
-    private final TestLooper mLooper = new TestLooper();
-    private BaseNetdUnsolicitedEventListener mListener;
-    private BpfInterfaceMapUpdater mUpdater;
-    private IBpfMap<S32, InterfaceMapValue> mBpfMap =
-            spy(new TestBpfMap<>(S32.class, InterfaceMapValue.class));
-    @Mock private INetd mNetd;
-    @Mock private Context mContext;
-
-    private class TestDependencies extends BpfInterfaceMapUpdater.Dependencies {
-        @Override
-        public IBpfMap<S32, InterfaceMapValue> getInterfaceMap() {
-            return mBpfMap;
-        }
-
-        @Override
-        public InterfaceParams getInterfaceParams(String ifaceName) {
-            if (ifaceName.equals(TEST_INTERFACE_NAME)) {
-                return new InterfaceParams(TEST_INTERFACE_NAME, TEST_INDEX,
-                        MacAddress.ALL_ZEROS_ADDRESS);
-            } else if (ifaceName.equals(TEST_INTERFACE_NAME2)) {
-                return new InterfaceParams(TEST_INTERFACE_NAME2, TEST_INDEX2,
-                        MacAddress.ALL_ZEROS_ADDRESS);
-            }
-
-            return null;
-        }
-
-        @Override
-        public INetd getINetd(Context ctx) {
-            return mNetd;
-        }
-    }
-
-    @Before
-    public void setUp() throws Exception {
-        MockitoAnnotations.initMocks(this);
-        when(mNetd.interfaceGetList()).thenReturn(new String[] {TEST_INTERFACE_NAME});
-        mUpdater = new BpfInterfaceMapUpdater(mContext, new Handler(mLooper.getLooper()),
-                new TestDependencies());
-    }
-
-    private void verifyStartUpdater() throws Exception {
-        mUpdater.start();
-        mLooper.dispatchAll();
-        final ArgumentCaptor<BaseNetdUnsolicitedEventListener> listenerCaptor =
-                ArgumentCaptor.forClass(BaseNetdUnsolicitedEventListener.class);
-        verify(mNetd).registerUnsolicitedEventListener(listenerCaptor.capture());
-        mListener = listenerCaptor.getValue();
-        verify(mBpfMap).updateEntry(eq(new S32(TEST_INDEX)),
-                eq(new InterfaceMapValue(TEST_INTERFACE_NAME)));
-    }
-
-    @Test
-    public void testUpdateInterfaceMap() throws Exception {
-        verifyStartUpdater();
-
-        mListener.onInterfaceAdded(TEST_INTERFACE_NAME2);
-        mLooper.dispatchAll();
-        verify(mBpfMap).updateEntry(eq(new S32(TEST_INDEX2)),
-                eq(new InterfaceMapValue(TEST_INTERFACE_NAME2)));
-
-        // Check that when onInterfaceRemoved is called, nothing happens.
-        mListener.onInterfaceRemoved(TEST_INTERFACE_NAME);
-        mLooper.dispatchAll();
-        verifyNoMoreInteractions(mBpfMap);
-    }
-
-    @Test
-    public void testGetIfNameByIndex() throws Exception {
-        mBpfMap.updateEntry(new S32(TEST_INDEX), new InterfaceMapValue(TEST_INTERFACE_NAME));
-        assertEquals(TEST_INTERFACE_NAME, mUpdater.getIfNameByIndex(TEST_INDEX));
-    }
-
-    @Test
-    public void testGetIfNameByIndexNoEntry() {
-        assertNull(mUpdater.getIfNameByIndex(TEST_INDEX));
-    }
-
-    @Test
-    public void testGetIfNameByIndexException() throws Exception {
-        doThrow(new ErrnoException("", EPERM)).when(mBpfMap).getValue(new S32(TEST_INDEX));
-        assertNull(mUpdater.getIfNameByIndex(TEST_INDEX));
-    }
-
-    private void assertDumpContains(final String dump, final String message) {
-        assertTrue(String.format("dump(%s) does not contain '%s'", dump, message),
-                dump.contains(message));
-    }
-
-    private String getDump() {
-        final StringWriter sw = new StringWriter();
-        mUpdater.dump(new IndentingPrintWriter(new PrintWriter(sw), " "));
-        return sw.toString();
-    }
-
-    @Test
-    public void testDump() throws ErrnoException {
-        mBpfMap.updateEntry(new S32(TEST_INDEX), new InterfaceMapValue(TEST_INTERFACE_NAME));
-        mBpfMap.updateEntry(new S32(TEST_INDEX2), new InterfaceMapValue(TEST_INTERFACE_NAME2));
-
-        final String dump = getDump();
-        assertDumpContains(dump, "IfaceIndexNameMap: OK");
-        assertDumpContains(dump, "ifaceIndex=1 ifaceName=test1");
-        assertDumpContains(dump, "ifaceIndex=2 ifaceName=test2");
-    }
-}
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
index a5fee5b..3ed51bc 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -254,7 +254,7 @@
     private @Mock AlarmManager mAlarmManager;
     @Mock
     private NetworkStatsSubscriptionsMonitor mNetworkStatsSubscriptionsMonitor;
-    private @Mock BpfInterfaceMapUpdater mBpfInterfaceMapUpdater;
+    private @Mock BpfInterfaceMapHelper mBpfInterfaceMapHelper;
     private HandlerThread mHandlerThread;
     @Mock
     private LocationPermissionChecker mLocationPermissionChecker;
@@ -519,9 +519,8 @@
         }
 
         @Override
-        public BpfInterfaceMapUpdater makeBpfInterfaceMapUpdater(
-                @NonNull Context ctx, @NonNull Handler handler) {
-            return mBpfInterfaceMapUpdater;
+        public BpfInterfaceMapHelper makeBpfInterfaceMapHelper() {
+            return mBpfInterfaceMapHelper;
         }
 
         @Override
@@ -2764,13 +2763,13 @@
 
     @Test
     public void testDumpStatsMap() throws ErrnoException {
-        doReturn("wlan0").when(mBpfInterfaceMapUpdater).getIfNameByIndex(10 /* index */);
+        doReturn("wlan0").when(mBpfInterfaceMapHelper).getIfNameByIndex(10 /* index */);
         doTestDumpStatsMap("wlan0");
     }
 
     @Test
     public void testDumpStatsMapUnknownInterface() throws ErrnoException {
-        doReturn(null).when(mBpfInterfaceMapUpdater).getIfNameByIndex(10 /* index */);
+        doReturn(null).when(mBpfInterfaceMapHelper).getIfNameByIndex(10 /* index */);
         doTestDumpStatsMap("unknown");
     }
 
@@ -2785,13 +2784,13 @@
 
     @Test
     public void testDumpIfaceStatsMap() throws Exception {
-        doReturn("wlan0").when(mBpfInterfaceMapUpdater).getIfNameByIndex(10 /* index */);
+        doReturn("wlan0").when(mBpfInterfaceMapHelper).getIfNameByIndex(10 /* index */);
         doTestDumpIfaceStatsMap("wlan0");
     }
 
     @Test
     public void testDumpIfaceStatsMapUnknownInterface() throws Exception {
-        doReturn(null).when(mBpfInterfaceMapUpdater).getIfNameByIndex(10 /* index */);
+        doReturn(null).when(mBpfInterfaceMapHelper).getIfNameByIndex(10 /* index */);
         doTestDumpIfaceStatsMap("unknown");
     }
 
diff --git a/thread/framework/java/android/net/thread/IStateCallback.aidl b/thread/framework/java/android/net/thread/IStateCallback.aidl
index d7cbda9..9d0a571 100644
--- a/thread/framework/java/android/net/thread/IStateCallback.aidl
+++ b/thread/framework/java/android/net/thread/IStateCallback.aidl
@@ -22,4 +22,5 @@
 oneway interface IStateCallback {
     void onDeviceRoleChanged(int deviceRole);
     void onPartitionIdChanged(long partitionId);
+    void onThreadEnableStateChanged(int enabledState);
 }
diff --git a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
index a9da8d6..485e25d 100644
--- a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
+++ b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
@@ -42,4 +42,6 @@
 
     int getThreadVersion();
     void createRandomizedDataset(String networkName, IActiveOperationalDatasetReceiver receiver);
+
+    void setEnabled(boolean enabled, in IOperationReceiver receiver);
 }
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkController.java b/thread/framework/java/android/net/thread/ThreadNetworkController.java
index 7242ed7..db761a3 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkController.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkController.java
@@ -68,6 +68,15 @@
     /** The device is a Thread Leader. */
     public static final int DEVICE_ROLE_LEADER = 4;
 
+    /** The Thread radio is disabled. */
+    public static final int STATE_DISABLED = 0;
+
+    /** The Thread radio is enabled. */
+    public static final int STATE_ENABLED = 1;
+
+    /** The Thread radio is being disabled. */
+    public static final int STATE_DISABLING = 2;
+
     /** @hide */
     @Retention(RetentionPolicy.SOURCE)
     @IntDef({
@@ -79,6 +88,13 @@
     })
     public @interface DeviceRole {}
 
+    /** @hide */
+    @Retention(RetentionPolicy.SOURCE)
+    @IntDef(
+            prefix = {"STATE_"},
+            value = {STATE_DISABLED, STATE_ENABLED, STATE_DISABLING})
+    public @interface EnabledState {}
+
     /** Thread standard version 1.3. */
     public static final int THREAD_VERSION_1_3 = 4;
 
@@ -106,6 +122,40 @@
         mControllerService = controllerService;
     }
 
+    /**
+     * Enables/Disables the radio of this ThreadNetworkController. The requested enabled state is
+     * persisted and survives device reboots.
+     *
+     * <p>When Thread is in {@link #STATE_DISABLED}, {@link ThreadNetworkController} APIs which
+     * require the Thread radio will fail with error code {@link
+     * ThreadNetworkException#ERROR_THREAD_DISABLED}. When Thread is in {@link #STATE_DISABLING},
+     * {@link ThreadNetworkController} APIs that return a {@link ThreadNetworkException} will fail
+     * with error code {@link ThreadNetworkException#ERROR_BUSY}.
+     *
+     * <p>On success, {@link OutcomeReceiver#onResult} of {@code receiver} is called. It indicates
+     * the operation has completed, but there may be subsequent updates to the enabled state, so
+     * callers of this method should use {@link #registerStateCallback} to subscribe to Thread
+     * enabled state changes.
+     *
+     * <p>On failure, {@link OutcomeReceiver#onError} of {@code receiver} will be invoked with a
+     * specific error code (one of the {@code ThreadNetworkException.ERROR_*} constants).
+     *
+     * @param enabled {@code true} to enable Thread, {@code false} to disable it
+     * @param executor the executor to execute {@code receiver}
+     * @param receiver the receiver to receive result of this operation
+     */
+    @RequiresPermission("android.permission.THREAD_NETWORK_PRIVILEGED")
+    public void setEnabled(
+            boolean enabled,
+            @NonNull @CallbackExecutor Executor executor,
+            @NonNull OutcomeReceiver<Void, ThreadNetworkException> receiver) {
+        try {
+            mControllerService.setEnabled(enabled, new OperationReceiverProxy(executor, receiver));
+        } catch (RemoteException e) {
+            throw e.rethrowFromSystemServer();
+        }
+    }
+
     /** Returns the Thread version this device is operating on. */
     @ThreadVersion
     public int getThreadVersion() {
@@ -170,6 +220,16 @@
          * @param partitionId the new Thread partition ID
          */
         default void onPartitionIdChanged(long partitionId) {}
+
+        /**
+         * The Thread enabled state has changed.
+         *
+         * <p>The Thread enabled state can be set with {@link ThreadNetworkController#setEnabled},
+         * and it may also be updated by airplane mode or admin control.
+         *
+         * @param enabledState the new Thread enabled state
+         */
+        default void onThreadEnableStateChanged(@EnabledState int enabledState) {}
     }
 
     private static final class StateCallbackProxy extends IStateCallback.Stub {
@@ -200,6 +260,16 @@
                 Binder.restoreCallingIdentity(identity);
             }
         }
+
+        @Override
+        public void onThreadEnableStateChanged(@EnabledState int enabledState) {
+            final long identity = Binder.clearCallingIdentity();
+            try {
+                mExecutor.execute(() -> mCallback.onThreadEnableStateChanged(enabledState));
+            } finally {
+                Binder.restoreCallingIdentity(identity);
+            }
+        }
     }
 
     /**
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkException.java b/thread/framework/java/android/net/thread/ThreadNetworkException.java
index af0a84b..23ed53e 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkException.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkException.java
@@ -48,6 +48,7 @@
         ERROR_RESPONSE_BAD_FORMAT,
         ERROR_RESOURCE_EXHAUSTED,
         ERROR_UNKNOWN,
+        ERROR_THREAD_DISABLED,
     })
     public @interface ErrorCode {}
 
@@ -129,6 +130,13 @@
      */
     public static final int ERROR_UNKNOWN = 11;
 
+    /**
+     * The operation failed because the Thread radio is disabled by {@link
+     * ThreadNetworkController#setEnabled}, airplane mode or device admin. The caller should retry
+     * only after Thread is enabled.
+     */
+    public static final int ERROR_THREAD_DISABLED = 12;
+
     private final int mErrorCode;
 
     /** Creates a new {@link ThreadNetworkException} object with given error code and message. */
diff --git a/thread/scripts/make-pretty.sh b/thread/scripts/make-pretty.sh
index e4bd459..c176bfa 100755
--- a/thread/scripts/make-pretty.sh
+++ b/thread/scripts/make-pretty.sh
@@ -3,5 +3,7 @@
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 
 GOOGLE_JAVA_FORMAT=$SCRIPT_DIR/../../../../../prebuilts/tools/common/google-java-format/google-java-format
+ANDROID_BP_FORMAT=$SCRIPT_DIR/../../../../../prebuilts/build-tools/linux-x86/bin/bpfmt
 
 $GOOGLE_JAVA_FORMAT --aosp -i $(find $SCRIPT_DIR/../ -name "*.java")
+$ANDROID_BP_FORMAT -w $(find $SCRIPT_DIR/../ -name "*.bp")
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 1c51c42..7b9f290 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -27,6 +27,9 @@
 import static android.net.thread.ActiveOperationalDataset.MESH_LOCAL_PREFIX_FIRST_BYTE;
 import static android.net.thread.ActiveOperationalDataset.SecurityPolicy.DEFAULT_ROTATION_TIME_HOURS;
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_DETACHED;
+import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
+import static android.net.thread.ThreadNetworkController.STATE_DISABLING;
+import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
 import static android.net.thread.ThreadNetworkController.THREAD_VERSION_1_3;
 import static android.net.thread.ThreadNetworkException.ERROR_ABORTED;
 import static android.net.thread.ThreadNetworkException.ERROR_BUSY;
@@ -35,6 +38,7 @@
 import static android.net.thread.ThreadNetworkException.ERROR_REJECTED_BY_PEER;
 import static android.net.thread.ThreadNetworkException.ERROR_RESOURCE_EXHAUSTED;
 import static android.net.thread.ThreadNetworkException.ERROR_RESPONSE_BAD_FORMAT;
+import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
 import static android.net.thread.ThreadNetworkException.ERROR_TIMEOUT;
 import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_CHANNEL;
 import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
@@ -48,7 +52,11 @@
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_REASSEMBLY_TIMEOUT;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_REJECTED;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_RESPONSE_TIMEOUT;
+import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_THREAD_DISABLED;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_UNSUPPORTED_CHANNEL;
+import static com.android.server.thread.openthread.IOtDaemon.OT_STATE_DISABLED;
+import static com.android.server.thread.openthread.IOtDaemon.OT_STATE_DISABLING;
+import static com.android.server.thread.openthread.IOtDaemon.OT_STATE_ENABLED;
 import static com.android.server.thread.openthread.IOtDaemon.TUN_IF_NAME;
 
 import android.Manifest.permission;
@@ -160,6 +168,7 @@
     private UpstreamNetworkCallback mUpstreamNetworkCallback;
     private TestNetworkSpecifier mUpstreamTestNetworkSpecifier;
     private final HashMap<Network, String> mNetworkToInterface;
+    private final ThreadPersistentSettings mPersistentSettings;
 
     private BorderRouterConfigurationParcel mBorderRouterConfig;
 
@@ -171,7 +180,8 @@
             Supplier<IOtDaemon> otDaemonSupplier,
             ConnectivityManager connectivityManager,
             TunInterfaceController tunIfController,
-            InfraInterfaceController infraIfController) {
+            InfraInterfaceController infraIfController,
+            ThreadPersistentSettings persistentSettings) {
         mContext = context;
         mHandler = handler;
         mNetworkProvider = networkProvider;
@@ -182,9 +192,11 @@
         mUpstreamNetworkRequest = newUpstreamNetworkRequest();
         mNetworkToInterface = new HashMap<Network, String>();
         mBorderRouterConfig = new BorderRouterConfigurationParcel();
+        mPersistentSettings = persistentSettings;
     }
 
-    public static ThreadNetworkControllerService newInstance(Context context) {
+    public static ThreadNetworkControllerService newInstance(
+            Context context, ThreadPersistentSettings persistentSettings) {
         HandlerThread handlerThread = new HandlerThread("ThreadHandlerThread");
         handlerThread.start();
         NetworkProvider networkProvider =
@@ -197,7 +209,8 @@
                 () -> IOtDaemon.Stub.asInterface(ServiceManagerWrapper.waitForService("ot_daemon")),
                 context.getSystemService(ConnectivityManager.class),
                 new TunInterfaceController(TUN_IF_NAME),
-                new InfraInterfaceController());
+                new InfraInterfaceController(),
+                persistentSettings);
     }
 
     private static Inet6Address bytesToInet6Address(byte[] addressBytes) {
@@ -273,7 +286,9 @@
         if (otDaemon == null) {
             throw new RemoteException("Internal error: failed to start OT daemon");
         }
-        otDaemon.initialize(mTunIfController.getTunFd());
+        otDaemon.initialize(
+                mTunIfController.getTunFd(),
+                mPersistentSettings.get(ThreadPersistentSettings.THREAD_ENABLED));
         otDaemon.registerStateCallback(mOtDaemonCallbackProxy, -1);
         otDaemon.asBinder().linkToDeath(() -> mHandler.post(this::onOtDaemonDied), 0);
         mOtDaemon = otDaemon;
@@ -308,6 +323,31 @@
                 });
     }
 
+    /**
+     * Enables or disables Thread and persists the requested state.
+     *
+     * <p>The work is posted to the handler thread; the outcome is reported to {@code receiver}.
+     */
+    public void setEnabled(boolean isEnabled, @NonNull IOperationReceiver receiver) {
+        enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
+
+        mHandler.post(() -> setEnabledInternal(isEnabled, new OperationReceiverWrapper(receiver)));
+    }
+
+    private void setEnabledInternal(
+            boolean isEnabled, @NonNull OperationReceiverWrapper receiver) {
+        // Persist the desired state even if the otDaemon call below fails, so that the desired
+        // state can be restored after a reboot.
+        mPersistentSettings.put(ThreadPersistentSettings.THREAD_ENABLED.key, isEnabled);
+        try {
+            getOtDaemon().setThreadEnabled(isEnabled, newOtStatusReceiver(receiver));
+        } catch (RemoteException e) {
+            Log.e(TAG, "otDaemon.setThreadEnabled failed", e);
+            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        }
+    }
+
     private void requestUpstreamNetwork() {
         if (mUpstreamNetworkCallback != null) {
             throw new AssertionError("The upstream network request is already there.");
@@ -658,6 +693,8 @@
                 return ERROR_REJECTED_BY_PEER;
             case OT_ERROR_UNSUPPORTED_CHANNEL:
                 return ERROR_UNSUPPORTED_CHANNEL;
+            case OT_ERROR_THREAD_DISABLED:
+                return ERROR_THREAD_DISABLED;
             default:
                 return ERROR_INTERNAL_ERROR;
         }
@@ -1001,6 +1038,15 @@
             }
         }
 
+        private void notifyThreadEnabledUpdated(IStateCallback callback, int enabledState) {
+            try {
+                callback.onThreadEnableStateChanged(enabledState);
+                Log.i(TAG, "onThreadEnableStateChanged " + enabledState);
+            } catch (RemoteException ignored) {
+                // The callback could not be delivered (presumably the client died); skip it.
+            }
+        }
+
         public void unregisterStateCallback(IStateCallback callback) {
             checkOnHandlerThread();
             if (!mStateCallbacks.containsKey(callback)) {
@@ -1065,6 +1111,41 @@
         }
 
         @Override
+        public void onThreadEnabledChanged(int state) {
+            mHandler.post(() -> onThreadEnabledChangedInternal(state));
+        }
+
+        /** Converts the ot-daemon enabled state and forwards it to every state callback. */
+        private void onThreadEnabledChangedInternal(int state) {
+            checkOnHandlerThread();
+            final int enabledState;
+            try {
+                enabledState = otStateToAndroidState(state);
+            } catch (IllegalArgumentException e) {
+                // An unexpected value from ot-daemon should not throw on the handler thread,
+                // which could bring down the system server; drop the update instead.
+                Log.e(TAG, "Dropping Thread enabled state update", e);
+                return;
+            }
+            for (IStateCallback callback : mStateCallbacks.keySet()) {
+                notifyThreadEnabledUpdated(callback, enabledState);
+            }
+        }
+
+        private static int otStateToAndroidState(int state) {
+            switch (state) {
+                case OT_STATE_ENABLED:
+                    return STATE_ENABLED;
+                case OT_STATE_DISABLED:
+                    return STATE_DISABLED;
+                case OT_STATE_DISABLING:
+                    return STATE_DISABLING;
+                default:
+                    throw new IllegalArgumentException("Unknown ot state " + state);
+            }
+        }
+
+        @Override
         public void onStateChanged(OtDaemonState newState, long listenerId) {
             mHandler.post(() -> onStateChangedInternal(newState, listenerId));
         }
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
index a1484a4..ffa7b44 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
@@ -458,7 +458,7 @@
      *   <li>5. Location country code - Country code retrieved from LocationManager passive location
      *       provider.
      *   <li>6. OEM country code - Country code retrieved from the system property
-     *       `ro.boot.threadnetwork.country_code`.
+     *       `threadnetwork.country_code`.
      *   <li>7. Default country code `WW`.
      * </ul>
      *
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index 53f2d4f..5cf27f7 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -18,16 +18,21 @@
 
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 
+import static com.android.net.module.util.DeviceConfigUtils.TETHERING_MODULE_NAME;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.content.ApexEnvironment;
 import android.content.Context;
 import android.net.thread.IThreadNetworkController;
 import android.net.thread.IThreadNetworkManager;
 import android.os.Binder;
 import android.os.ParcelFileDescriptor;
+import android.util.AtomicFile;
 
 import com.android.server.SystemService;
 
+import java.io.File;
 import java.io.FileDescriptor;
 import java.io.PrintWriter;
 import java.util.Collections;
@@ -40,11 +45,18 @@
     private final Context mContext;
     @Nullable private ThreadNetworkCountryCode mCountryCode;
     @Nullable private ThreadNetworkControllerService mControllerService;
+    private final ThreadPersistentSettings mPersistentSettings;
     @Nullable private ThreadNetworkShellCommand mShellCommand;
 
     /** Creates a new {@link ThreadNetworkService} object. */
     public ThreadNetworkService(Context context) {
         mContext = context;
+        mPersistentSettings =
+                new ThreadPersistentSettings(
+                        new AtomicFile(
+                                new File(
+                                        getOrCreateThreadnetworkDir(),
+                                        ThreadPersistentSettings.FILE_NAME)));
     }
 
     /**
@@ -54,7 +66,9 @@
      */
     public void onBootPhase(int phase) {
         if (phase == SystemService.PHASE_SYSTEM_SERVICES_READY) {
-            mControllerService = ThreadNetworkControllerService.newInstance(mContext);
+            mPersistentSettings.initialize();
+            mControllerService =
+                    ThreadNetworkControllerService.newInstance(mContext, mPersistentSettings);
             mControllerService.initialize();
         } else if (phase == SystemService.PHASE_BOOT_COMPLETED) {
             // Country code initialization is delayed to the BOOT_COMPLETED phase because it will
@@ -109,4 +123,19 @@
 
         pw.println();
     }
+
+    /** Gets or creates the "thread" dir under the tethering apex device protected data dir. */
+    private static File getOrCreateThreadnetworkDir() {
+        final File apexDataDir =
+                ApexEnvironment.getApexEnvironment(TETHERING_MODULE_NAME)
+                        .getDeviceProtectedDataDir();
+        final File threadnetworkDir = new File(apexDataDir, "thread");
+
+        // Check isDirectory() rather than exists(): a regular file at this path is unusable.
+        if (threadnetworkDir.isDirectory() || threadnetworkDir.mkdirs()) {
+            return threadnetworkDir;
+        }
+        throw new IllegalStateException(
+                "Cannot write into thread network data directory: " + threadnetworkDir);
+    }
 }
diff --git a/thread/tests/cts/Android.bp b/thread/tests/cts/Android.bp
index 3cf31e5..2f38bfd 100644
--- a/thread/tests/cts/Android.bp
+++ b/thread/tests/cts/Android.bp
@@ -45,7 +45,7 @@
     libs: [
         "android.test.base",
         "android.test.runner",
-        "framework-connectivity-module-api-stubs-including-flagged"
+        "framework-connectivity-module-api-stubs-including-flagged",
     ],
     // Test coverage system runs on different devices. Need to
     // compile for all architectures.
diff --git a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
index 7a6c9aa..aab4b2e 100644
--- a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
+++ b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
@@ -16,11 +16,19 @@
 
 package android.net.thread.cts;
 
+import static android.Manifest.permission.ACCESS_NETWORK_STATE;
+import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_CHILD;
+import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_LEADER;
+import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_ROUTER;
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_STOPPED;
+import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
+import static android.net.thread.ThreadNetworkController.STATE_DISABLING;
+import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
 import static android.net.thread.ThreadNetworkController.THREAD_VERSION_1_3;
 import static android.net.thread.ThreadNetworkException.ERROR_ABORTED;
 import static android.net.thread.ThreadNetworkException.ERROR_FAILED_PRECONDITION;
 import static android.net.thread.ThreadNetworkException.ERROR_REJECTED_BY_PEER;
+import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
 
 import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
 
@@ -29,12 +37,13 @@
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.fail;
 import static org.junit.Assume.assumeNotNull;
 
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 
-import android.Manifest.permission;
 import android.content.Context;
 import android.net.ConnectivityManager;
 import android.net.Network;
@@ -54,11 +63,11 @@
 import androidx.test.core.app.ApplicationProvider;
 import androidx.test.filters.LargeTest;
 
+import com.android.net.module.util.ArrayTrackRecord;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 import com.android.testutils.DevSdkIgnoreRunner;
-
-import com.google.common.util.concurrent.SettableFuture;
+import com.android.testutils.FunctionalUtils.ThrowingRunnable;
 
 import org.junit.After;
 import org.junit.Before;
@@ -68,9 +77,11 @@
 
 import java.time.Duration;
 import java.time.Instant;
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -81,52 +92,726 @@
 @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) // Thread is available on only U+
 public class ThreadNetworkControllerTest {
     private static final int JOIN_TIMEOUT_MILLIS = 30 * 1000;
+    private static final int LEAVE_TIMEOUT_MILLIS = 2_000;
+    private static final int MIGRATION_TIMEOUT_MILLIS = 40 * 1_000;
     private static final int NETWORK_CALLBACK_TIMEOUT_MILLIS = 10 * 1000;
-    private static final int CALLBACK_TIMEOUT_MILLIS = 1000;
-    private static final String PERMISSION_THREAD_NETWORK_PRIVILEGED =
+    private static final int CALLBACK_TIMEOUT_MILLIS = 1_000;
+    private static final int ENABLED_TIMEOUT_MILLIS = 2_000;
+    private static final String THREAD_NETWORK_PRIVILEGED =
             "android.permission.THREAD_NETWORK_PRIVILEGED";
 
     @Rule public DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
 
     private final Context mContext = ApplicationProvider.getApplicationContext();
     private ExecutorService mExecutor;
-    private ThreadNetworkManager mManager;
+    private ThreadNetworkController mController;
 
     private Set<String> mGrantedPermissions;
 
     @Before
-    public void setUp() {
-        mExecutor = Executors.newSingleThreadExecutor();
-        mManager = mContext.getSystemService(ThreadNetworkManager.class);
+    public void setUp() throws Exception {
         mGrantedPermissions = new HashSet<String>();
+        mExecutor = Executors.newSingleThreadExecutor();
+        ThreadNetworkManager manager = mContext.getSystemService(ThreadNetworkManager.class);
+        if (manager != null) {
+            mController = manager.getAllThreadNetworkControllers().get(0);
+        }
 
         // TODO: we will also need it in tearDown(), it's better to have a Rule to skip
         // tests if a feature is not available.
-        assumeNotNull(mManager);
+        assumeNotNull(mController);
+
+        setEnabledAndWait(mController, true);
     }
 
     @After
     public void tearDown() throws Exception {
-        if (mManager != null) {
-            leaveAndWait();
+        try {
+            if (mController != null) {
+                // The enabled state is persisted, so restore it: a test that disables Thread
+                // must not leave Thread disabled on the device once the tests have finished.
+                setEnabledAndWait(mController, true);
+                grantPermissions(THREAD_NETWORK_PRIVILEGED);
+                CompletableFuture<Void> future = new CompletableFuture<>();
+                // Unlike future::complete, newOutcomeReceiver() also completes the future on
+                // error, so a failed leave() is reported instead of showing up as a timeout.
+                mController.leave(mExecutor, newOutcomeReceiver(future));
+                future.get(LEAVE_TIMEOUT_MILLIS, MILLISECONDS);
+            }
+        } finally {
+            dropAllPermissions();
+        }
+    }
+
+    @Test
+    public void getThreadVersion_returnsAtLeastThreadVersion1P3() {
+        assertThat(mController.getThreadVersion()).isAtLeast(THREAD_VERSION_1_3);
+    }
+
+    @Test
+    public void registerStateCallback_permissionsGranted_returnsCurrentStates() throws Exception {
+        CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+        StateCallback callback = deviceRole::complete;
+
+        try {
+            runAsShell(
+                    ACCESS_NETWORK_STATE,
+                    () -> mController.registerStateCallback(mExecutor, callback));
+
+            assertThat(deviceRole.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS))
+                    .isEqualTo(DEVICE_ROLE_STOPPED);
+        } finally {
+            runAsShell(ACCESS_NETWORK_STATE, () -> mController.unregisterStateCallback(callback));
+        }
+    }
+
+    @Test
+    public void registerStateCallback_returnsUpdatedEnabledStates() throws Exception {
+        CompletableFuture<Void> setFuture1 = new CompletableFuture<>();
+        CompletableFuture<Void> setFuture2 = new CompletableFuture<>();
+        EnabledStateListener listener = new EnabledStateListener(mController);
+
+        try {
+            runAsShell(
+                    THREAD_NETWORK_PRIVILEGED,
+                    () -> {
+                        mController.setEnabled(false, mExecutor, newOutcomeReceiver(setFuture1));
+                    });
+            setFuture1.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+
+            runAsShell(
+                    THREAD_NETWORK_PRIVILEGED,
+                    () -> {
+                        mController.setEnabled(true, mExecutor, newOutcomeReceiver(setFuture2));
+                    });
+            setFuture2.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+
+            listener.expectThreadEnabledState(STATE_ENABLED);
+            listener.expectThreadEnabledState(STATE_DISABLING);
+            listener.expectThreadEnabledState(STATE_DISABLED);
+            listener.expectThreadEnabledState(STATE_ENABLED);
+        } finally {
+            listener.unregisterStateCallback();
+        }
+    }
+
+    @Test
+    public void registerStateCallback_noPermissions_throwsSecurityException() throws Exception {
+        dropAllPermissions();
+
+        assertThrows(
+                SecurityException.class,
+                () -> mController.registerStateCallback(mExecutor, role -> {}));
+    }
+
+    @Test
+    public void registerStateCallback_alreadyRegistered_throwsIllegalArgumentException()
+            throws Exception {
+        grantPermissions(ACCESS_NETWORK_STATE);
+        CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+        StateCallback callback = role -> deviceRole.complete(role);
+
+        mController.registerStateCallback(mExecutor, callback);
+
+        try {
+            assertThrows(
+                    IllegalArgumentException.class,
+                    () -> mController.registerStateCallback(mExecutor, callback));
+        } finally {
+            // Unregister so that the callback does not leak into later tests.
+            mController.unregisterStateCallback(callback);
+        }
+    }
+
+    @Test
+    public void unregisterStateCallback_noPermissions_throwsSecurityException() throws Exception {
+        CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+        StateCallback callback = role -> deviceRole.complete(role);
+        runAsShell(
+                ACCESS_NETWORK_STATE, () -> mController.registerStateCallback(mExecutor, callback));
+
+        try {
             dropAllPermissions();
+            assertThrows(
+                    SecurityException.class, () -> mController.unregisterStateCallback(callback));
+        } finally {
+            runAsShell(ACCESS_NETWORK_STATE, () -> mController.unregisterStateCallback(callback));
         }
     }
 
-    private List<ThreadNetworkController> getAllControllers() {
-        return mManager.getAllThreadNetworkControllers();
+    @Test
+    public void unregisterStateCallback_callbackRegistered_success() throws Exception {
+        grantPermissions(ACCESS_NETWORK_STATE);
+        CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+        StateCallback callback = role -> deviceRole.complete(role);
+
+        assertDoesNotThrow(() -> mController.registerStateCallback(mExecutor, callback));
+        mController.unregisterStateCallback(callback);
     }
 
-    private void leaveAndWait() throws Exception {
-        grantPermissions(PERMISSION_THREAD_NETWORK_PRIVILEGED);
+    @Test
+    public void unregisterStateCallback_callbackNotRegistered_throwsIllegalArgumentException()
+            throws Exception {
+        CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+        StateCallback callback = role -> deviceRole.complete(role);
 
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Void> future = SettableFuture.create();
-            controller.leave(mExecutor, future::set);
-            future.get();
+        assertThrows(
+                IllegalArgumentException.class,
+                () -> mController.unregisterStateCallback(callback));
+    }
+
+    @Test
+    public void unregisterStateCallback_alreadyUnregistered_throwsIllegalArgumentException()
+            throws Exception {
+        grantPermissions(ACCESS_NETWORK_STATE);
+        CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+        StateCallback callback = deviceRole::complete;
+        mController.registerStateCallback(mExecutor, callback);
+        mController.unregisterStateCallback(callback);
+
+        assertThrows(
+                IllegalArgumentException.class,
+                () -> mController.unregisterStateCallback(callback));
+    }
+
+    @Test
+    public void registerOperationalDatasetCallback_permissionsGranted_returnsCurrentStates()
+            throws Exception {
+        grantPermissions(ACCESS_NETWORK_STATE, THREAD_NETWORK_PRIVILEGED);
+        CompletableFuture<ActiveOperationalDataset> activeFuture = new CompletableFuture<>();
+        CompletableFuture<PendingOperationalDataset> pendingFuture = new CompletableFuture<>();
+        var callback = newDatasetCallback(activeFuture, pendingFuture);
+
+        try {
+            mController.registerOperationalDatasetCallback(mExecutor, callback);
+
+            assertThat(activeFuture.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isNull();
+            assertThat(pendingFuture.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isNull();
+        } finally {
+            mController.unregisterOperationalDatasetCallback(callback);
         }
     }
 
+    @Test
+    public void registerOperationalDatasetCallback_noPermissions_throwsSecurityException()
+            throws Exception {
+        dropAllPermissions();
+        CompletableFuture<ActiveOperationalDataset> activeFuture = new CompletableFuture<>();
+        CompletableFuture<PendingOperationalDataset> pendingFuture = new CompletableFuture<>();
+        var callback = newDatasetCallback(activeFuture, pendingFuture);
+
+        assertThrows(
+                SecurityException.class,
+                () -> mController.registerOperationalDatasetCallback(mExecutor, callback));
+    }
+
+    @Test
+    public void unregisterOperationalDatasetCallback_callbackRegistered_success() throws Exception {
+        grantPermissions(ACCESS_NETWORK_STATE, THREAD_NETWORK_PRIVILEGED);
+        CompletableFuture<ActiveOperationalDataset> activeFuture = new CompletableFuture<>();
+        CompletableFuture<PendingOperationalDataset> pendingFuture = new CompletableFuture<>();
+        var callback = newDatasetCallback(activeFuture, pendingFuture);
+        mController.registerOperationalDatasetCallback(mExecutor, callback);
+
+        assertDoesNotThrow(() -> mController.unregisterOperationalDatasetCallback(callback));
+    }
+
+    @Test
+    public void unregisterOperationalDatasetCallback_noPermissions_throwsSecurityException()
+            throws Exception {
+        CompletableFuture<ActiveOperationalDataset> activeFuture = new CompletableFuture<>();
+        CompletableFuture<PendingOperationalDataset> pendingFuture = new CompletableFuture<>();
+        var callback = newDatasetCallback(activeFuture, pendingFuture);
+        runAsShell(
+                ACCESS_NETWORK_STATE,
+                THREAD_NETWORK_PRIVILEGED,
+                () -> mController.registerOperationalDatasetCallback(mExecutor, callback));
+
+        try {
+            dropAllPermissions();
+            assertThrows(
+                    SecurityException.class,
+                    () -> mController.unregisterOperationalDatasetCallback(callback));
+        } finally {
+            runAsShell(
+                    ACCESS_NETWORK_STATE,
+                    THREAD_NETWORK_PRIVILEGED,
+                    () -> mController.unregisterOperationalDatasetCallback(callback));
+        }
+    }
+
+    private static <V> OutcomeReceiver<V, ThreadNetworkException> newOutcomeReceiver(
+            CompletableFuture<V> future) {
+        return new OutcomeReceiver<V, ThreadNetworkException>() {
+            @Override
+            public void onResult(V result) {
+                future.complete(result);
+            }
+
+            @Override
+            public void onError(ThreadNetworkException e) {
+                future.completeExceptionally(e);
+            }
+        };
+    }
+
+    @Test
+    public void join_withPrivilegedPermission_success() throws Exception {
+        ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", mController);
+        CompletableFuture<Void> joinFuture = new CompletableFuture<>();
+
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () -> mController.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture)));
+        joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
+
+        assertThat(isAttached(mController)).isTrue();
+        assertThat(getActiveOperationalDataset(mController)).isEqualTo(activeDataset);
+    }
+
+    @Test
+    public void join_withoutPrivilegedPermission_throwsSecurityException() throws Exception {
+        dropAllPermissions();
+        ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", mController);
+
+        assertThrows(
+                SecurityException.class, () -> mController.join(activeDataset, mExecutor, v -> {}));
+    }
+
+    @Test
+    public void join_threadDisabled_failsWithErrorThreadDisabled() throws Exception {
+        setEnabledAndWait(mController, false);
+        ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", mController);
+        CompletableFuture<Void> joinFuture = new CompletableFuture<>();
+
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () -> mController.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture)));
+
+        var thrown =
+                assertThrows(
+                        ExecutionException.class,
+                        () -> joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS));
+        var threadException = (ThreadNetworkException) thrown.getCause();
+        assertThat(threadException.getErrorCode()).isEqualTo(ERROR_THREAD_DISABLED);
+    }
+
+    @Test
+    public void join_concurrentRequests_firstOneIsAborted() throws Exception {
+        final byte[] KEY_1 = new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+        final byte[] KEY_2 = new byte[] {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
+        ActiveOperationalDataset activeDataset1 =
+                new ActiveOperationalDataset.Builder(newRandomizedDataset("TestNet", mController))
+                        .setNetworkKey(KEY_1)
+                        .build();
+        ActiveOperationalDataset activeDataset2 =
+                new ActiveOperationalDataset.Builder(activeDataset1).setNetworkKey(KEY_2).build();
+        CompletableFuture<Void> joinFuture1 = new CompletableFuture<>();
+        CompletableFuture<Void> joinFuture2 = new CompletableFuture<>();
+
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () -> {
+                    mController.join(activeDataset1, mExecutor, newOutcomeReceiver(joinFuture1));
+                    mController.join(activeDataset2, mExecutor, newOutcomeReceiver(joinFuture2));
+                });
+
+        var thrown =
+                assertThrows(
+                        ExecutionException.class,
+                        () -> joinFuture1.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS));
+        var threadException = (ThreadNetworkException) thrown.getCause();
+        assertThat(threadException.getErrorCode()).isEqualTo(ERROR_ABORTED);
+        joinFuture2.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
+        assertThat(isAttached(mController)).isTrue();
+        assertThat(getActiveOperationalDataset(mController)).isEqualTo(activeDataset2);
+    }
+
+    @Test
+    public void leave_withPrivilegedPermission_success() throws Exception {
+        CompletableFuture<Void> leaveFuture = new CompletableFuture<>();
+        joinRandomizedDatasetAndWait(mController);
+
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () -> mController.leave(mExecutor, newOutcomeReceiver(leaveFuture)));
+        leaveFuture.get(LEAVE_TIMEOUT_MILLIS, MILLISECONDS);
+
+        assertThat(getDeviceRole(mController)).isEqualTo(DEVICE_ROLE_STOPPED);
+    }
+
+    @Test
+    public void leave_withoutPrivilegedPermission_throwsSecurityException() {
+        dropAllPermissions();
+
+        assertThrows(SecurityException.class, () -> mController.leave(mExecutor, v -> {}));
+    }
+
+    @Test
+    public void leave_threadDisabled_success() throws Exception {
+        setEnabledAndWait(mController, false);
+        CompletableFuture<Void> leaveFuture = new CompletableFuture<>();
+
+        leave(mController, newOutcomeReceiver(leaveFuture));
+        leaveFuture.get(LEAVE_TIMEOUT_MILLIS, MILLISECONDS);
+
+        assertThat(getDeviceRole(mController)).isEqualTo(DEVICE_ROLE_STOPPED);
+    }
+
+    @Test
+    public void leave_concurrentRequests_bothSuccess() throws Exception {
+        joinRandomizedDatasetAndWait(mController);
+        CompletableFuture<Void> firstLeave = new CompletableFuture<>();
+        CompletableFuture<Void> secondLeave = new CompletableFuture<>();
+
+        // Both requests are issued back-to-back inside one runAsShell call.
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () -> {
+                    mController.leave(mExecutor, newOutcomeReceiver(firstLeave));
+                    mController.leave(mExecutor, newOutcomeReceiver(secondLeave));
+                });
+
+        for (CompletableFuture<Void> pending : List.of(firstLeave, secondLeave)) {
+            pending.get(LEAVE_TIMEOUT_MILLIS, MILLISECONDS);
+        }
+        assertThat(getDeviceRole(mController)).isEqualTo(DEVICE_ROLE_STOPPED);
+    }
+
+    @Test
+    public void scheduleMigration_withPrivilegedPermission_newDatasetApplied() throws Exception {
+        grantPermissions(ACCESS_NETWORK_STATE, THREAD_NETWORK_PRIVILEGED);
+        ActiveOperationalDataset initialDataset =
+                new ActiveOperationalDataset.Builder(newRandomizedDataset("TestNet", mController))
+                        .setActiveTimestamp(new OperationalDatasetTimestamp(1L, 0, false))
+                        .setExtendedPanId(new byte[] {1, 1, 1, 1, 1, 1, 1, 1})
+                        .build();
+        // The migration target differs only by a newer active timestamp and a new name.
+        ActiveOperationalDataset migratedDataset =
+                new ActiveOperationalDataset.Builder(initialDataset)
+                        .setActiveTimestamp(new OperationalDatasetTimestamp(2L, 0, false))
+                        .setNetworkName("ThreadNet2")
+                        .build();
+        PendingOperationalDataset pendingMigration =
+                new PendingOperationalDataset(
+                        migratedDataset,
+                        OperationalDatasetTimestamp.fromInstant(Instant.now()),
+                        Duration.ofSeconds(30));
+
+        CompletableFuture<Void> joined = new CompletableFuture<>();
+        mController.join(initialDataset, mExecutor, newOutcomeReceiver(joined));
+        joined.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
+
+        CompletableFuture<Void> scheduled = new CompletableFuture<>();
+        mController.scheduleMigration(pendingMigration, mExecutor, newOutcomeReceiver(scheduled));
+        scheduled.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
+
+        // Watch until the target becomes active and the pending dataset is cleared.
+        CompletableFuture<Boolean> migrationApplied = new CompletableFuture<>();
+        CompletableFuture<Boolean> pendingCleared = new CompletableFuture<>();
+        OperationalDatasetCallback watcher =
+                new OperationalDatasetCallback() {
+                    @Override
+                    public void onActiveOperationalDatasetChanged(
+                            ActiveOperationalDataset activeDataset) {
+                        if (activeDataset.equals(migratedDataset)) {
+                            migrationApplied.complete(true);
+                        }
+                    }
+
+                    @Override
+                    public void onPendingOperationalDatasetChanged(
+                            PendingOperationalDataset pendingDataset) {
+                        if (pendingDataset == null) {
+                            pendingCleared.complete(true);
+                        }
+                    }
+                };
+        mController.registerOperationalDatasetCallback(directExecutor(), watcher);
+        try {
+            assertThat(migrationApplied.get(MIGRATION_TIMEOUT_MILLIS, MILLISECONDS)).isTrue();
+            assertThat(pendingCleared.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isTrue();
+        } finally {
+            mController.unregisterOperationalDatasetCallback(watcher);
+        }
+    }
+
+    @Test
+    public void scheduleMigration_whenNotAttached_failWithPreconditionError() throws Exception {
+        grantPermissions(ACCESS_NETWORK_STATE, THREAD_NETWORK_PRIVILEGED);
+        PendingOperationalDataset pendingDataset =
+                new PendingOperationalDataset(
+                        newRandomizedDataset("TestNet", mController),
+                        OperationalDatasetTimestamp.fromInstant(Instant.now()),
+                        Duration.ofSeconds(30));
+        CompletableFuture<Void> migrateFuture = new CompletableFuture<>();
+
+        mController.scheduleMigration(pendingDataset, mExecutor, newOutcomeReceiver(migrateFuture));
+
+        // Bounded wait: if the outcome callback never fires, get() throws TimeoutException and
+        // assertThrows fails the test instead of hanging the runner forever.
+        ExecutionException failure =
+                assertThrows(
+                        ExecutionException.class,
+                        () -> migrateFuture.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS));
+        ThreadNetworkException thrown = (ThreadNetworkException) failure.getCause();
+        assertThat(thrown.getErrorCode()).isEqualTo(ERROR_FAILED_PRECONDITION);
+    }
+
+    @Test
+    public void scheduleMigration_secondRequestHasSmallerTimestamp_rejectedByLeader()
+            throws Exception {
+        grantPermissions(ACCESS_NETWORK_STATE, THREAD_NETWORK_PRIVILEGED);
+        final ActiveOperationalDataset activeDataset =
+                new ActiveOperationalDataset.Builder(newRandomizedDataset("testNet", mController))
+                        .setActiveTimestamp(new OperationalDatasetTimestamp(1L, 0, false))
+                        .build();
+        ActiveOperationalDataset activeDataset1 =
+                new ActiveOperationalDataset.Builder(activeDataset)
+                        .setActiveTimestamp(new OperationalDatasetTimestamp(2L, 0, false))
+                        .setNetworkName("testNet1")
+                        .build();
+        PendingOperationalDataset pendingDataset1 =
+                new PendingOperationalDataset(
+                        activeDataset1,
+                        new OperationalDatasetTimestamp(100, 0, false),
+                        Duration.ofSeconds(30));
+        ActiveOperationalDataset activeDataset2 =
+                new ActiveOperationalDataset.Builder(activeDataset)
+                        .setActiveTimestamp(new OperationalDatasetTimestamp(3L, 0, false))
+                        .setNetworkName("testNet2")
+                        .build();
+        // Older pending timestamp (20) than pendingDataset1 (100): expected to be rejected.
+        PendingOperationalDataset pendingDataset2 =
+                new PendingOperationalDataset(
+                        activeDataset2,
+                        new OperationalDatasetTimestamp(20, 0, false),
+                        Duration.ofSeconds(30));
+        CompletableFuture<Void> joinFuture = new CompletableFuture<>();
+        CompletableFuture<Void> migrateFuture1 = new CompletableFuture<>();
+        CompletableFuture<Void> migrateFuture2 = new CompletableFuture<>();
+        mController.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
+        joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
+
+        mController.scheduleMigration(
+                pendingDataset1, mExecutor, newOutcomeReceiver(migrateFuture1));
+        migrateFuture1.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
+        mController.scheduleMigration(
+                pendingDataset2, mExecutor, newOutcomeReceiver(migrateFuture2));
+
+        // Bounded wait: a lost callback must fail the test rather than hang it.
+        ExecutionException failure =
+                assertThrows(
+                        ExecutionException.class,
+                        () -> migrateFuture2.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS));
+        ThreadNetworkException thrown = (ThreadNetworkException) failure.getCause();
+        assertThat(thrown.getErrorCode()).isEqualTo(ERROR_REJECTED_BY_PEER);
+    }
+
+    @Test
+    public void scheduleMigration_secondRequestHasLargerTimestamp_newDatasetApplied()
+            throws Exception {
+        grantPermissions(ACCESS_NETWORK_STATE, THREAD_NETWORK_PRIVILEGED);
+        final ActiveOperationalDataset baseDataset =
+                new ActiveOperationalDataset.Builder(newRandomizedDataset("validName", mController))
+                        .setActiveTimestamp(new OperationalDatasetTimestamp(1L, 0, false))
+                        .build();
+        // Two successive migrations off the same base dataset; the second one carries the newer
+        // pending timestamp (200 > 100) and is the one expected to end up active.
+        ActiveOperationalDataset firstTarget =
+                new ActiveOperationalDataset.Builder(baseDataset)
+                        .setActiveTimestamp(new OperationalDatasetTimestamp(2L, 0, false))
+                        .setNetworkName("testNet1")
+                        .build();
+        PendingOperationalDataset firstMigration =
+                new PendingOperationalDataset(
+                        firstTarget,
+                        new OperationalDatasetTimestamp(100, 0, false),
+                        Duration.ofSeconds(30));
+        ActiveOperationalDataset secondTarget =
+                new ActiveOperationalDataset.Builder(baseDataset)
+                        .setActiveTimestamp(new OperationalDatasetTimestamp(3L, 0, false))
+                        .setNetworkName("testNet2")
+                        .build();
+        PendingOperationalDataset secondMigration =
+                new PendingOperationalDataset(
+                        secondTarget,
+                        new OperationalDatasetTimestamp(200, 0, false),
+                        Duration.ofSeconds(30));
+
+        CompletableFuture<Void> joined = new CompletableFuture<>();
+        mController.join(baseDataset, mExecutor, newOutcomeReceiver(joined));
+        joined.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
+
+        CompletableFuture<Void> firstScheduled = new CompletableFuture<>();
+        mController.scheduleMigration(
+                firstMigration, mExecutor, newOutcomeReceiver(firstScheduled));
+        firstScheduled.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
+        CompletableFuture<Void> secondScheduled = new CompletableFuture<>();
+        mController.scheduleMigration(
+                secondMigration, mExecutor, newOutcomeReceiver(secondScheduled));
+        secondScheduled.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
+
+        CompletableFuture<Boolean> secondTargetActive = new CompletableFuture<>();
+        CompletableFuture<Boolean> pendingCleared = new CompletableFuture<>();
+        OperationalDatasetCallback watcher =
+                new OperationalDatasetCallback() {
+                    @Override
+                    public void onActiveOperationalDatasetChanged(
+                            ActiveOperationalDataset activeDataset) {
+                        if (activeDataset.equals(secondTarget)) {
+                            secondTargetActive.complete(true);
+                        }
+                    }
+
+                    @Override
+                    public void onPendingOperationalDatasetChanged(
+                            PendingOperationalDataset pendingDataset) {
+                        if (pendingDataset == null) {
+                            pendingCleared.complete(true);
+                        }
+                    }
+                };
+        mController.registerOperationalDatasetCallback(directExecutor(), watcher);
+        try {
+            assertThat(secondTargetActive.get(MIGRATION_TIMEOUT_MILLIS, MILLISECONDS)).isTrue();
+            assertThat(pendingCleared.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isTrue();
+        } finally {
+            mController.unregisterOperationalDatasetCallback(watcher);
+        }
+    }
+
+    @Test
+    public void scheduleMigration_threadDisabled_failsWithErrorThreadDisabled() throws Exception {
+        ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", mController);
+        PendingOperationalDataset pendingDataset =
+                new PendingOperationalDataset(
+                        activeDataset,
+                        OperationalDatasetTimestamp.fromInstant(Instant.now()),
+                        Duration.ofSeconds(30));
+        joinRandomizedDatasetAndWait(mController);
+        CompletableFuture<Void> migrationFuture = new CompletableFuture<>();
+
+        setEnabledAndWait(mController, false);
+
+        scheduleMigration(mController, pendingDataset, newOutcomeReceiver(migrationFuture));
+
+        // Bounded wait: a lost callback must fail the test rather than hang it.
+        ExecutionException failure =
+                assertThrows(
+                        ExecutionException.class,
+                        () -> migrationFuture.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS));
+        ThreadNetworkException thrown = (ThreadNetworkException) failure.getCause();
+        assertThat(thrown.getErrorCode()).isEqualTo(ERROR_THREAD_DISABLED);
+    }
+
+    @Test
+    public void createRandomizedDataset_wrongNetworkNameLength_throwsIllegalArgumentException() {
+        // Both an empty name and a 17-byte name must be rejected synchronously.
+        for (String invalidName : new String[] {"", "ANetNameIs17Bytes"}) {
+            assertThrows(
+                    IllegalArgumentException.class,
+                    () ->
+                            mController.createRandomizedDataset(
+                                    invalidName, mExecutor, dataset -> {}));
+        }
+    }
+
+    @Test
+    public void createRandomizedDataset_validNetworkName_success() throws Exception {
+        ActiveOperationalDataset randomized = newRandomizedDataset("validName", mController);
+
+        assertThat(randomized.getNetworkName()).isEqualTo("validName");
+        assertThat(randomized.getPanId()).isLessThan(0xffff);
+        assertThat(randomized.getChannelMask().size()).isAtLeast(1);
+        // Expected field lengths: 8-byte extended PAN ID, 16-byte network key and PSKc.
+        assertThat(randomized.getExtendedPanId()).hasLength(8);
+        assertThat(randomized.getNetworkKey()).hasLength(16);
+        assertThat(randomized.getPskc()).hasLength(16);
+        // The mesh-local prefix must be a /64 whose first byte is 0xfd.
+        var meshLocalPrefix = randomized.getMeshLocalPrefix();
+        assertThat(meshLocalPrefix.getPrefixLength()).isEqualTo(64);
+        assertThat(meshLocalPrefix.getRawAddress()[0]).isEqualTo((byte) 0xfd);
+    }
+
+    @Test
+    public void setEnabled_permissionsGranted_succeeds() throws Exception {
+        // Disable first, then re-enable. Each request must complete, and the new state must be
+        // reported through the state callback.
+        for (boolean enabled : new boolean[] {false, true}) {
+            CompletableFuture<Void> setFuture = new CompletableFuture<>();
+            runAsShell(
+                    THREAD_NETWORK_PRIVILEGED,
+                    () ->
+                            mController.setEnabled(
+                                    enabled, mExecutor, newOutcomeReceiver(setFuture)));
+            setFuture.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+            waitForEnabledState(mController, booleanToEnabledState(enabled));
+        }
+    }
+
+    @Test
+    public void setEnabled_noPermissions_throwsSecurityException() throws Exception {
+        // Make the "no permissions" precondition explicit instead of relying on whatever
+        // permissions the fixture happens to hold (matches the other negative permission tests).
+        dropAllPermissions();
+        CompletableFuture<Void> setFuture = new CompletableFuture<>();
+
+        assertThrows(
+                SecurityException.class,
+                () -> mController.setEnabled(false, mExecutor, newOutcomeReceiver(setFuture)));
+    }
+
+    @Test
+    public void setEnabled_disable_leavesThreadNetwork() throws Exception {
+        joinRandomizedDatasetAndWait(mController);
+        setEnabledAndWait(mController, false);
+        assertThat(getDeviceRole(mController)).isEqualTo(DEVICE_ROLE_STOPPED);
+    }
+
+    @Test
+    public void setEnabled_toggleAfterJoin_joinsThreadNetworkAgain() throws Exception {
+        joinRandomizedDatasetAndWait(mController);
+
+        setEnabledAndWait(mController, false);
+        assertThat(getDeviceRole(mController)).isEqualTo(DEVICE_ROLE_STOPPED);
+        setEnabledAndWait(mController, true);
+
+        runAsShell(ACCESS_NETWORK_STATE, () -> waitForAttachedState(mController));
+    }
+
+    @Test
+    public void setEnabled_enableFollowedByDisable_allSucceed() throws Exception {
+        joinRandomizedDatasetAndWait(mController);
+        CompletableFuture<Void> setFuture1 = new CompletableFuture<>();
+        CompletableFuture<Void> setFuture2 = new CompletableFuture<>();
+        EnabledStateListener listener = new EnabledStateListener(mController);
+        // Unregister in `finally` so a failed expectation or timeout does not leak the state
+        // callback into subsequent tests.
+        try {
+            listener.expectThreadEnabledState(STATE_ENABLED);
+
+            runAsShell(
+                    THREAD_NETWORK_PRIVILEGED,
+                    () -> {
+                        mController.setEnabled(true, mExecutor, newOutcomeReceiver(setFuture1));
+                        mController.setEnabled(false, mExecutor, newOutcomeReceiver(setFuture2));
+                    });
+            setFuture1.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+            setFuture2.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+
+            listener.expectThreadEnabledState(STATE_DISABLING);
+            listener.expectThreadEnabledState(STATE_DISABLED);
+            assertThat(getDeviceRole(mController)).isEqualTo(DEVICE_ROLE_STOPPED);
+        } finally {
+            listener.unregisterStateCallback();
+        }
+    }
+
+    // TODO(b/322437869): add a test case verifying that, while Thread is in the DISABLING state,
+    // any command (join/leave/scheduleMigration/setEnabled) fails with ERROR_BUSY. This is not
+    // currently tested because DISABLING is very short-lived, so there is no way to guarantee that
+    // the command is sent before the state changes to DISABLED.
+
+    @Test
+    public void threadNetworkCallback_deviceAttached_threadNetworkIsAvailable() throws Exception {
+        CompletableFuture<Network> networkFuture = new CompletableFuture<>();
+        ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
+        NetworkRequest networkRequest =
+                new NetworkRequest.Builder()
+                        .addTransportType(NetworkCapabilities.TRANSPORT_THREAD)
+                        .build();
+        ConnectivityManager.NetworkCallback networkCallback =
+                new ConnectivityManager.NetworkCallback() {
+                    @Override
+                    public void onAvailable(Network network) {
+                        networkFuture.complete(network);
+                    }
+                };
+
+        joinRandomizedDatasetAndWait(mController);
+        runAsShell(
+                ACCESS_NETWORK_STATE,
+                () -> cm.registerNetworkCallback(networkRequest, networkCallback));
+        try {
+            assertThat(isAttached(mController)).isTrue();
+            assertThat(networkFuture.get(NETWORK_CALLBACK_TIMEOUT_MILLIS, MILLISECONDS))
+                    .isNotNull();
+        } finally {
+            // Otherwise the registered callback outlives this test; release it even on failure.
+            cm.unregisterNetworkCallback(networkCallback);
+        }
+    }
+
     private void grantPermissions(String... permissions) {
         for (String permission : permissions) {
             mGrantedPermissions.add(permission);
@@ -142,8 +815,8 @@
 
     private static ActiveOperationalDataset newRandomizedDataset(
             String networkName, ThreadNetworkController controller) throws Exception {
+        // Asks the controller for a randomized dataset named `networkName` and waits up to
+        // CALLBACK_TIMEOUT_MILLIS for it to be delivered to the callback.
-        SettableFuture<ActiveOperationalDataset> future = SettableFuture.create();
-        controller.createRandomizedDataset(networkName, directExecutor(), future::set);
+        CompletableFuture<ActiveOperationalDataset> future = new CompletableFuture<>();
+        controller.createRandomizedDataset(networkName, directExecutor(), future::complete);
         return future.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
     }
 
@@ -152,642 +825,189 @@
     }
 
     private static int getDeviceRole(ThreadNetworkController controller) throws Exception {
+        // Returns the first role reported to a temporarily registered StateCallback, which is
+        // expected to report the current role on registration. Register and unregister both run
+        // inside runAsShell(ACCESS_NETWORK_STATE, ...); unregistration happens even on timeout.
-        SettableFuture<Integer> future = SettableFuture.create();
-        StateCallback callback = future::set;
-        controller.registerStateCallback(directExecutor(), callback);
-        int role = future.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
-        controller.unregisterStateCallback(callback);
-        return role;
+        CompletableFuture<Integer> future = new CompletableFuture<>();
+        StateCallback callback = future::complete;
+        runAsShell(
+                ACCESS_NETWORK_STATE,
+                () -> controller.registerStateCallback(directExecutor(), callback));
+        try {
+            return future.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
+        } finally {
+            runAsShell(ACCESS_NETWORK_STATE, () -> controller.unregisterStateCallback(callback));
+        }
+    }
+
+    private static int waitForAttachedState(ThreadNetworkController controller) throws Exception {
+        // Child, router and leader are the roles this test treats as "attached".
+        return waitForStateAnyOf(
+                controller, List.of(DEVICE_ROLE_CHILD, DEVICE_ROLE_ROUTER, DEVICE_ROLE_LEADER));
     }
 
     private static int waitForStateAnyOf(
             ThreadNetworkController controller, List<Integer> deviceRoles) throws Exception {
-        SettableFuture<Integer> future = SettableFuture.create();
+        CompletableFuture<Integer> future = new CompletableFuture<>();
         StateCallback callback =
                 newRole -> {
                     if (deviceRoles.contains(newRole)) {
-                        future.set(newRole);
+                        future.complete(newRole);
                     }
                 };
         controller.registerStateCallback(directExecutor(), callback);
-        int role = future.get();
+        int role = future.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
         controller.unregisterStateCallback(callback);
         return role;
     }
 
+    /**
+     * Blocks until {@code controller} reports the enabled state {@code state} (e.g. STATE_ENABLED
+     * or STATE_DISABLED) or ENABLED_TIMEOUT_MILLIS elapses. The temporary StateCallback is always
+     * unregistered, even when the wait times out.
+     */
+    private static void waitForEnabledState(ThreadNetworkController controller, int state)
+            throws Exception {
+        CompletableFuture<Integer> future = new CompletableFuture<>();
+        StateCallback callback =
+                new ThreadNetworkController.StateCallback() {
+                    @Override
+                    public void onDeviceRoleChanged(int r) {}
+
+                    @Override
+                    public void onThreadEnableStateChanged(int enabled) {
+                        if (enabled == state) {
+                            future.complete(enabled);
+                        }
+                    }
+                };
+        runAsShell(
+                ACCESS_NETWORK_STATE,
+                () -> controller.registerStateCallback(directExecutor(), callback));
+        try {
+            future.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+        } finally {
+            runAsShell(ACCESS_NETWORK_STATE, () -> controller.unregisterStateCallback(callback));
+        }
+    }
+
+    /** Invokes {@code controller.leave()} inside runAsShell(THREAD_NETWORK_PRIVILEGED, ...). */
+    private void leave(
+            ThreadNetworkController controller,
+            OutcomeReceiver<Void, ThreadNetworkException> receiver) {
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () -> {
+                    controller.leave(mExecutor, receiver);
+                });
+    }
+
+    /**
+     * Invokes {@code controller.scheduleMigration()} for {@code pendingDataset} inside
+     * runAsShell(THREAD_NETWORK_PRIVILEGED, ...); the outcome is delivered to {@code receiver}.
+     */
+    private void scheduleMigration(
+            ThreadNetworkController controller,
+            PendingOperationalDataset pendingDataset,
+            OutcomeReceiver<Void, ThreadNetworkException> receiver) {
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () -> {
+                    controller.scheduleMigration(pendingDataset, mExecutor, receiver);
+                });
+    }
+
+    /**
+     * Records, in order, every enabled state delivered to a StateCallback that is registered on
+     * construction (on mExecutor, with ACCESS_NETWORK_STATE). Callers must call
+     * {@link #unregisterStateCallback()} when done.
+     */
+    private class EnabledStateListener {
+        private final ArrayTrackRecord<Integer> mEnabledStates = new ArrayTrackRecord<>();
+        private final ArrayTrackRecord<Integer>.ReadHead mReadHead = mEnabledStates.newReadHead();
+        private final ThreadNetworkController mController;
+        private final StateCallback mCallback =
+                new ThreadNetworkController.StateCallback() {
+                    @Override
+                    public void onDeviceRoleChanged(int r) {}
+
+                    @Override
+                    public void onThreadEnableStateChanged(int enabled) {
+                        mEnabledStates.add(enabled);
+                    }
+                };
+
+        EnabledStateListener(ThreadNetworkController controller) {
+            this.mController = controller;
+            runAsShell(
+                    ACCESS_NETWORK_STATE,
+                    () -> controller.registerStateCallback(mExecutor, mCallback));
+        }
+
+        /**
+         * Asserts that a state equal to {@code enabled} is polled from the record within
+         * ENABLED_TIMEOUT_MILLIS.
+         */
+        public void expectThreadEnabledState(int enabled) {
+            assertNotNull(mReadHead.poll(ENABLED_TIMEOUT_MILLIS, e -> (e == enabled)));
+        }
+
+        /** Unregisters the callback registered by the constructor. */
+        public void unregisterStateCallback() {
+            runAsShell(ACCESS_NETWORK_STATE, () -> mController.unregisterStateCallback(mCallback));
+        }
+    }
+
+    /** Maps {@code true}/{@code false} to STATE_ENABLED/STATE_DISABLED. */
+    private int booleanToEnabledState(boolean enabled) {
+        if (enabled) {
+            return STATE_ENABLED;
+        }
+        return STATE_DISABLED;
+    }
+
+    /**
+     * Calls setEnabled({@code enabled}) as a privileged caller, then waits for the request to
+     * complete and for the matching enabled state to be reported.
+     */
+    private void setEnabledAndWait(ThreadNetworkController controller, boolean enabled)
+            throws Exception {
+        CompletableFuture<Void> requestDone = new CompletableFuture<>();
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () -> controller.setEnabled(enabled, mExecutor, newOutcomeReceiver(requestDone)));
+        requestDone.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+
+        // Completing the request is not enough; also wait for the state callback to agree.
+        int expectedState = enabled ? STATE_ENABLED : STATE_DISABLED;
+        waitForEnabledState(controller, expectedState);
+    }
+
+    /**
+     * Creates a randomized dataset named "TestNet" and asks {@code controller} to join it as a
+     * privileged caller. Returns a future completed by the join's outcome receiver; it does not
+     * wait for the join to finish.
+     */
+    private CompletableFuture<Void> joinRandomizedDataset(ThreadNetworkController controller)
+            throws Exception {
+        ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
+        CompletableFuture<Void> joinFuture = new CompletableFuture<>();
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () -> controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture)));
+        return joinFuture;
+    }
+
+    /** Joins a randomized dataset, waits for the join to complete and asserts attachment. */
+    private void joinRandomizedDatasetAndWait(ThreadNetworkController controller) throws Exception {
+        joinRandomizedDataset(controller).get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
+        assertThat(isAttached(controller)).isTrue();
+    }
+
     private static ActiveOperationalDataset getActiveOperationalDataset(
             ThreadNetworkController controller) throws Exception {
+        // Returns the first active dataset reported to a temporarily registered
+        // OperationalDatasetCallback. Register/unregister run inside runAsShell with
+        // ACCESS_NETWORK_STATE and THREAD_NETWORK_PRIVILEGED; unregistration happens even on
+        // timeout.
-        SettableFuture<ActiveOperationalDataset> future = SettableFuture.create();
-        OperationalDatasetCallback callback = future::set;
-        controller.registerOperationalDatasetCallback(directExecutor(), callback);
-        ActiveOperationalDataset dataset = future.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
-        controller.unregisterOperationalDatasetCallback(callback);
-        return dataset;
+        CompletableFuture<ActiveOperationalDataset> future = new CompletableFuture<>();
+        OperationalDatasetCallback callback = future::complete;
+        runAsShell(
+                ACCESS_NETWORK_STATE,
+                THREAD_NETWORK_PRIVILEGED,
+                () -> controller.registerOperationalDatasetCallback(directExecutor(), callback));
+        try {
+            return future.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
+        } finally {
+            runAsShell(
+                    ACCESS_NETWORK_STATE,
+                    THREAD_NETWORK_PRIVILEGED,
+                    () -> controller.unregisterOperationalDatasetCallback(callback));
+        }
     }
 
     private static PendingOperationalDataset getPendingOperationalDataset(
             ThreadNetworkController controller) throws Exception {
-        SettableFuture<ActiveOperationalDataset> activeFuture = SettableFuture.create();
-        SettableFuture<PendingOperationalDataset> pendingFuture = SettableFuture.create();
+        CompletableFuture<ActiveOperationalDataset> activeFuture = new CompletableFuture<>();
+        CompletableFuture<PendingOperationalDataset> pendingFuture = new CompletableFuture<>();
         controller.registerOperationalDatasetCallback(
                 directExecutor(), newDatasetCallback(activeFuture, pendingFuture));
-        return pendingFuture.get();
+        return pendingFuture.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
     }
 
     private static OperationalDatasetCallback newDatasetCallback(
-            SettableFuture<ActiveOperationalDataset> activeFuture,
-            SettableFuture<PendingOperationalDataset> pendingFuture) {
+            CompletableFuture<ActiveOperationalDataset> activeFuture,
+            CompletableFuture<PendingOperationalDataset> pendingFuture) {
+        // Builds a callback that completes `activeFuture` with active-dataset updates and
+        // `pendingFuture` with pending-dataset updates. CompletableFuture.complete() only takes
+        // effect once, so each future holds the first value delivered to it.
         return new OperationalDatasetCallback() {
             @Override
             public void onActiveOperationalDatasetChanged(
                     ActiveOperationalDataset activeOpDataset) {
-                activeFuture.set(activeOpDataset);
+                activeFuture.complete(activeOpDataset);
             }
 
             @Override
             public void onPendingOperationalDatasetChanged(
                     PendingOperationalDataset pendingOpDataset) {
-                pendingFuture.set(pendingOpDataset);
+                pendingFuture.complete(pendingOpDataset);
             }
         };
     }
 
-    @Test
-    public void getThreadVersion_returnsAtLeastThreadVersion1P3() {
-        for (ThreadNetworkController controller : getAllControllers()) {
-            assertThat(controller.getThreadVersion()).isAtLeast(THREAD_VERSION_1_3);
+    private static void assertDoesNotThrow(ThrowingRunnable runnable) {
+        try {
+            runnable.run();
+        } catch (Throwable e) {
+            // Keep the original throwable as the cause so its stack trace is not lost.
+            throw new AssertionError("Should not have thrown " + e, e);
         }
     }
-
-    @Test
-    public void registerStateCallback_permissionsGranted_returnsCurrentStates() throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE);
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = deviceRole::set;
-
-            try {
-                controller.registerStateCallback(mExecutor, callback);
-
-                assertThat(deviceRole.get()).isEqualTo(DEVICE_ROLE_STOPPED);
-            } finally {
-                controller.unregisterStateCallback(callback);
-            }
-        }
-    }
-
-    @Test
-    public void registerStateCallback_noPermissions_throwsSecurityException() throws Exception {
-        dropAllPermissions();
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            assertThrows(
-                    SecurityException.class,
-                    () -> controller.registerStateCallback(mExecutor, role -> {}));
-        }
-    }
-
-    @Test
-    public void registerStateCallback_alreadyRegistered_throwsIllegalArgumentException()
-            throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE);
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = role -> deviceRole.set(role);
-            controller.registerStateCallback(mExecutor, callback);
-
-            assertThrows(
-                    IllegalArgumentException.class,
-                    () -> controller.registerStateCallback(mExecutor, callback));
-        }
-    }
-
-    @Test
-    public void unregisterStateCallback_noPermissions_throwsSecurityException() throws Exception {
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = role -> deviceRole.set(role);
-            grantPermissions(permission.ACCESS_NETWORK_STATE);
-            controller.registerStateCallback(mExecutor, callback);
-
-            try {
-                dropAllPermissions();
-                assertThrows(
-                        SecurityException.class,
-                        () -> controller.unregisterStateCallback(callback));
-            } finally {
-                grantPermissions(permission.ACCESS_NETWORK_STATE);
-                controller.unregisterStateCallback(callback);
-            }
-        }
-    }
-
-    @Test
-    public void unregisterStateCallback_callbackRegistered_success() throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE);
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = role -> deviceRole.set(role);
-            controller.registerStateCallback(mExecutor, callback);
-
-            controller.unregisterStateCallback(callback);
-        }
-    }
-
-    @Test
-    public void unregisterStateCallback_callbackNotRegistered_throwsIllegalArgumentException()
-            throws Exception {
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = role -> deviceRole.set(role);
-
-            assertThrows(
-                    IllegalArgumentException.class,
-                    () -> controller.unregisterStateCallback(callback));
-        }
-    }
-
-    @Test
-    public void unregisterStateCallback_alreadyUnregistered_throwsIllegalArgumentException()
-            throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE);
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = deviceRole::set;
-            controller.registerStateCallback(mExecutor, callback);
-            controller.unregisterStateCallback(callback);
-
-            assertThrows(
-                    IllegalArgumentException.class,
-                    () -> controller.unregisterStateCallback(callback));
-        }
-    }
-
-    @Test
-    public void registerOperationalDatasetCallback_permissionsGranted_returnsCurrentStates()
-            throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<ActiveOperationalDataset> activeFuture = SettableFuture.create();
-            SettableFuture<PendingOperationalDataset> pendingFuture = SettableFuture.create();
-            var callback = newDatasetCallback(activeFuture, pendingFuture);
-
-            try {
-                controller.registerOperationalDatasetCallback(mExecutor, callback);
-
-                assertThat(activeFuture.get()).isNull();
-                assertThat(pendingFuture.get()).isNull();
-            } finally {
-                controller.unregisterOperationalDatasetCallback(callback);
-            }
-        }
-    }
-
-    @Test
-    public void registerOperationalDatasetCallback_noPermissions_throwsSecurityException()
-            throws Exception {
-        dropAllPermissions();
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<ActiveOperationalDataset> activeFuture = SettableFuture.create();
-            SettableFuture<PendingOperationalDataset> pendingFuture = SettableFuture.create();
-            var callback = newDatasetCallback(activeFuture, pendingFuture);
-
-            assertThrows(
-                    SecurityException.class,
-                    () -> controller.registerOperationalDatasetCallback(mExecutor, callback));
-        }
-    }
-
-    @Test
-    public void unregisterOperationalDatasetCallback_callbackRegistered_success() throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<ActiveOperationalDataset> activeFuture = SettableFuture.create();
-            SettableFuture<PendingOperationalDataset> pendingFuture = SettableFuture.create();
-            var callback = newDatasetCallback(activeFuture, pendingFuture);
-            controller.registerOperationalDatasetCallback(mExecutor, callback);
-
-            controller.unregisterOperationalDatasetCallback(callback);
-        }
-    }
-
-    @Test
-    public void unregisterOperationalDatasetCallback_noPermissions_throwsSecurityException()
-            throws Exception {
-        dropAllPermissions();
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<ActiveOperationalDataset> activeFuture = SettableFuture.create();
-            SettableFuture<PendingOperationalDataset> pendingFuture = SettableFuture.create();
-            var callback = newDatasetCallback(activeFuture, pendingFuture);
-            grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
-            controller.registerOperationalDatasetCallback(mExecutor, callback);
-
-            try {
-                dropAllPermissions();
-                assertThrows(
-                        SecurityException.class,
-                        () -> controller.unregisterOperationalDatasetCallback(callback));
-            } finally {
-                grantPermissions(
-                        permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
-                controller.unregisterOperationalDatasetCallback(callback);
-            }
-        }
-    }
-
-    private static <V> OutcomeReceiver<V, ThreadNetworkException> newOutcomeReceiver(
-            SettableFuture<V> future) {
-        return new OutcomeReceiver<V, ThreadNetworkException>() {
-            @Override
-            public void onResult(V result) {
-                future.set(result);
-            }
-
-            @Override
-            public void onError(ThreadNetworkException e) {
-                future.setException(e);
-            }
-        };
-    }
-
-    @Test
-    public void join_withPrivilegedPermission_success() throws Exception {
-        grantPermissions(PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-
-            controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
-
-            grantPermissions(permission.ACCESS_NETWORK_STATE);
-            assertThat(isAttached(controller)).isTrue();
-            assertThat(getActiveOperationalDataset(controller)).isEqualTo(activeDataset);
-        }
-    }
-
-    @Test
-    public void join_withoutPrivilegedPermission_throwsSecurityException() throws Exception {
-        dropAllPermissions();
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
-
-            assertThrows(
-                    SecurityException.class,
-                    () -> controller.join(activeDataset, mExecutor, v -> {}));
-        }
-    }
-
-    @Test
-    public void join_concurrentRequests_firstOneIsAborted() throws Exception {
-        grantPermissions(PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
-        final byte[] KEY_1 = new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
-        final byte[] KEY_2 = new byte[] {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
-        for (ThreadNetworkController controller : getAllControllers()) {
-            ActiveOperationalDataset activeDataset1 =
-                    new ActiveOperationalDataset.Builder(
-                                    newRandomizedDataset("TestNet", controller))
-                            .setNetworkKey(KEY_1)
-                            .build();
-            ActiveOperationalDataset activeDataset2 =
-                    new ActiveOperationalDataset.Builder(activeDataset1)
-                            .setNetworkKey(KEY_2)
-                            .build();
-            SettableFuture<Void> joinFuture1 = SettableFuture.create();
-            SettableFuture<Void> joinFuture2 = SettableFuture.create();
-
-            controller.join(activeDataset1, mExecutor, newOutcomeReceiver(joinFuture1));
-            controller.join(activeDataset2, mExecutor, newOutcomeReceiver(joinFuture2));
-
-            ThreadNetworkException thrown =
-                    (ThreadNetworkException)
-                            assertThrows(ExecutionException.class, joinFuture1::get).getCause();
-            assertThat(thrown.getErrorCode()).isEqualTo(ERROR_ABORTED);
-            joinFuture2.get();
-            grantPermissions(permission.ACCESS_NETWORK_STATE);
-            assertThat(isAttached(controller)).isTrue();
-            assertThat(getActiveOperationalDataset(controller)).isEqualTo(activeDataset2);
-        }
-    }
-
-    @Test
-    public void leave_withPrivilegedPermission_success() throws Exception {
-        grantPermissions(PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-            SettableFuture<Void> leaveFuture = SettableFuture.create();
-            controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
-
-            controller.leave(mExecutor, newOutcomeReceiver(leaveFuture));
-            leaveFuture.get();
-
-            grantPermissions(permission.ACCESS_NETWORK_STATE);
-            assertThat(getDeviceRole(controller)).isEqualTo(DEVICE_ROLE_STOPPED);
-        }
-    }
-
-    @Test
-    public void leave_withoutPrivilegedPermission_throwsSecurityException() {
-        dropAllPermissions();
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            assertThrows(SecurityException.class, () -> controller.leave(mExecutor, v -> {}));
-        }
-    }
-
-    @Test
-    public void leave_concurrentRequests_bothSuccess() throws Exception {
-        grantPermissions(PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-            SettableFuture<Void> leaveFuture1 = SettableFuture.create();
-            SettableFuture<Void> leaveFuture2 = SettableFuture.create();
-            controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
-
-            controller.leave(mExecutor, newOutcomeReceiver(leaveFuture1));
-            controller.leave(mExecutor, newOutcomeReceiver(leaveFuture2));
-
-            leaveFuture1.get();
-            leaveFuture2.get();
-            grantPermissions(permission.ACCESS_NETWORK_STATE);
-            assertThat(getDeviceRole(controller)).isEqualTo(DEVICE_ROLE_STOPPED);
-        }
-    }
-
-    @Test
-    public void scheduleMigration_withPrivilegedPermission_newDatasetApplied() throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            ActiveOperationalDataset activeDataset1 =
-                    new ActiveOperationalDataset.Builder(
-                                    newRandomizedDataset("TestNet", controller))
-                            .setActiveTimestamp(new OperationalDatasetTimestamp(1L, 0, false))
-                            .setExtendedPanId(new byte[] {1, 1, 1, 1, 1, 1, 1, 1})
-                            .build();
-            ActiveOperationalDataset activeDataset2 =
-                    new ActiveOperationalDataset.Builder(activeDataset1)
-                            .setActiveTimestamp(new OperationalDatasetTimestamp(2L, 0, false))
-                            .setNetworkName("ThreadNet2")
-                            .build();
-            PendingOperationalDataset pendingDataset2 =
-                    new PendingOperationalDataset(
-                            activeDataset2,
-                            OperationalDatasetTimestamp.fromInstant(Instant.now()),
-                            Duration.ofSeconds(30));
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-            SettableFuture<Void> migrateFuture = SettableFuture.create();
-            controller.join(activeDataset1, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
-
-            controller.scheduleMigration(
-                    pendingDataset2, mExecutor, newOutcomeReceiver(migrateFuture));
-            migrateFuture.get();
-
-            SettableFuture<Boolean> dataset2IsApplied = SettableFuture.create();
-            SettableFuture<Boolean> pendingDatasetIsRemoved = SettableFuture.create();
-            OperationalDatasetCallback datasetCallback =
-                    new OperationalDatasetCallback() {
-                        @Override
-                        public void onActiveOperationalDatasetChanged(
-                                ActiveOperationalDataset activeDataset) {
-                            if (activeDataset.equals(activeDataset2)) {
-                                dataset2IsApplied.set(true);
-                            }
-                        }
-
-                        @Override
-                        public void onPendingOperationalDatasetChanged(
-                                PendingOperationalDataset pendingDataset) {
-                            if (pendingDataset == null) {
-                                pendingDatasetIsRemoved.set(true);
-                            }
-                        }
-                    };
-            controller.registerOperationalDatasetCallback(directExecutor(), datasetCallback);
-            assertThat(dataset2IsApplied.get()).isTrue();
-            assertThat(pendingDatasetIsRemoved.get()).isTrue();
-            controller.unregisterOperationalDatasetCallback(datasetCallback);
-        }
-    }
-
-    @Test
-    public void scheduleMigration_whenNotAttached_failWithPreconditionError() throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            PendingOperationalDataset pendingDataset =
-                    new PendingOperationalDataset(
-                            newRandomizedDataset("TestNet", controller),
-                            OperationalDatasetTimestamp.fromInstant(Instant.now()),
-                            Duration.ofSeconds(30));
-            SettableFuture<Void> migrateFuture = SettableFuture.create();
-
-            controller.scheduleMigration(
-                    pendingDataset, mExecutor, newOutcomeReceiver(migrateFuture));
-
-            ThreadNetworkException thrown =
-                    (ThreadNetworkException)
-                            assertThrows(ExecutionException.class, migrateFuture::get).getCause();
-            assertThat(thrown.getErrorCode()).isEqualTo(ERROR_FAILED_PRECONDITION);
-        }
-    }
-
-    @Test
-    public void scheduleMigration_secondRequestHasSmallerTimestamp_rejectedByLeader()
-            throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            final ActiveOperationalDataset activeDataset =
-                    new ActiveOperationalDataset.Builder(
-                                    newRandomizedDataset("testNet", controller))
-                            .setActiveTimestamp(new OperationalDatasetTimestamp(1L, 0, false))
-                            .build();
-            ActiveOperationalDataset activeDataset1 =
-                    new ActiveOperationalDataset.Builder(activeDataset)
-                            .setActiveTimestamp(new OperationalDatasetTimestamp(2L, 0, false))
-                            .setNetworkName("testNet1")
-                            .build();
-            PendingOperationalDataset pendingDataset1 =
-                    new PendingOperationalDataset(
-                            activeDataset1,
-                            new OperationalDatasetTimestamp(100, 0, false),
-                            Duration.ofSeconds(30));
-            ActiveOperationalDataset activeDataset2 =
-                    new ActiveOperationalDataset.Builder(activeDataset)
-                            .setActiveTimestamp(new OperationalDatasetTimestamp(3L, 0, false))
-                            .setNetworkName("testNet2")
-                            .build();
-            PendingOperationalDataset pendingDataset2 =
-                    new PendingOperationalDataset(
-                            activeDataset2,
-                            new OperationalDatasetTimestamp(20, 0, false),
-                            Duration.ofSeconds(30));
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-            SettableFuture<Void> migrateFuture1 = SettableFuture.create();
-            SettableFuture<Void> migrateFuture2 = SettableFuture.create();
-            controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
-
-            controller.scheduleMigration(
-                    pendingDataset1, mExecutor, newOutcomeReceiver(migrateFuture1));
-            migrateFuture1.get();
-            controller.scheduleMigration(
-                    pendingDataset2, mExecutor, newOutcomeReceiver(migrateFuture2));
-
-            ThreadNetworkException thrown =
-                    (ThreadNetworkException)
-                            assertThrows(ExecutionException.class, migrateFuture2::get).getCause();
-            assertThat(thrown.getErrorCode()).isEqualTo(ERROR_REJECTED_BY_PEER);
-        }
-    }
-
-    @Test
-    public void scheduleMigration_secondRequestHasLargerTimestamp_newDatasetApplied()
-            throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
-        for (ThreadNetworkController controller : getAllControllers()) {
-            final ActiveOperationalDataset activeDataset =
-                    new ActiveOperationalDataset.Builder(
-                                    newRandomizedDataset("validName", controller))
-                            .setActiveTimestamp(new OperationalDatasetTimestamp(1L, 0, false))
-                            .build();
-            ActiveOperationalDataset activeDataset1 =
-                    new ActiveOperationalDataset.Builder(activeDataset)
-                            .setActiveTimestamp(new OperationalDatasetTimestamp(2L, 0, false))
-                            .setNetworkName("testNet1")
-                            .build();
-            PendingOperationalDataset pendingDataset1 =
-                    new PendingOperationalDataset(
-                            activeDataset1,
-                            new OperationalDatasetTimestamp(100, 0, false),
-                            Duration.ofSeconds(30));
-            ActiveOperationalDataset activeDataset2 =
-                    new ActiveOperationalDataset.Builder(activeDataset)
-                            .setActiveTimestamp(new OperationalDatasetTimestamp(3L, 0, false))
-                            .setNetworkName("testNet2")
-                            .build();
-            PendingOperationalDataset pendingDataset2 =
-                    new PendingOperationalDataset(
-                            activeDataset2,
-                            new OperationalDatasetTimestamp(200, 0, false),
-                            Duration.ofSeconds(30));
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-            SettableFuture<Void> migrateFuture1 = SettableFuture.create();
-            SettableFuture<Void> migrateFuture2 = SettableFuture.create();
-            controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
-
-            controller.scheduleMigration(
-                    pendingDataset1, mExecutor, newOutcomeReceiver(migrateFuture1));
-            migrateFuture1.get();
-            controller.scheduleMigration(
-                    pendingDataset2, mExecutor, newOutcomeReceiver(migrateFuture2));
-            migrateFuture2.get();
-
-            SettableFuture<Boolean> dataset2IsApplied = SettableFuture.create();
-            SettableFuture<Boolean> pendingDatasetIsRemoved = SettableFuture.create();
-            OperationalDatasetCallback datasetCallback =
-                    new OperationalDatasetCallback() {
-                        @Override
-                        public void onActiveOperationalDatasetChanged(
-                                ActiveOperationalDataset activeDataset) {
-                            if (activeDataset.equals(activeDataset2)) {
-                                dataset2IsApplied.set(true);
-                            }
-                        }
-
-                        @Override
-                        public void onPendingOperationalDatasetChanged(
-                                PendingOperationalDataset pendingDataset) {
-                            if (pendingDataset == null) {
-                                pendingDatasetIsRemoved.set(true);
-                            }
-                        }
-                    };
-            controller.registerOperationalDatasetCallback(directExecutor(), datasetCallback);
-            assertThat(dataset2IsApplied.get()).isTrue();
-            assertThat(pendingDatasetIsRemoved.get()).isTrue();
-            controller.unregisterOperationalDatasetCallback(datasetCallback);
-        }
-    }
-
-    @Test
-    public void createRandomizedDataset_wrongNetworkNameLength_throwsIllegalArgumentException() {
-        for (ThreadNetworkController controller : getAllControllers()) {
-            assertThrows(
-                    IllegalArgumentException.class,
-                    () -> controller.createRandomizedDataset("", mExecutor, dataset -> {}));
-
-            assertThrows(
-                    IllegalArgumentException.class,
-                    () ->
-                            controller.createRandomizedDataset(
-                                    "ANetNameIs17Bytes", mExecutor, dataset -> {}));
-        }
-    }
-
-    @Test
-    public void createRandomizedDataset_validNetworkName_success() throws Exception {
-        for (ThreadNetworkController controller : getAllControllers()) {
-            ActiveOperationalDataset dataset = newRandomizedDataset("validName", controller);
-
-            assertThat(dataset.getNetworkName()).isEqualTo("validName");
-            assertThat(dataset.getPanId()).isLessThan(0xffff);
-            assertThat(dataset.getChannelMask().size()).isAtLeast(1);
-            assertThat(dataset.getExtendedPanId()).hasLength(8);
-            assertThat(dataset.getNetworkKey()).hasLength(16);
-            assertThat(dataset.getPskc()).hasLength(16);
-            assertThat(dataset.getMeshLocalPrefix().getPrefixLength()).isEqualTo(64);
-            assertThat(dataset.getMeshLocalPrefix().getRawAddress()[0]).isEqualTo((byte) 0xfd);
-        }
-    }
-
-    @Test
-    public void threadNetworkCallback_deviceAttached_threadNetworkIsAvailable() throws Exception {
-        ThreadNetworkController controller = mManager.getAllThreadNetworkControllers().get(0);
-        ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
-        SettableFuture<Void> joinFuture = SettableFuture.create();
-        SettableFuture<Network> networkFuture = SettableFuture.create();
-        ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
-        NetworkRequest networkRequest =
-                new NetworkRequest.Builder()
-                        .addTransportType(NetworkCapabilities.TRANSPORT_THREAD)
-                        .build();
-        ConnectivityManager.NetworkCallback networkCallback =
-                new ConnectivityManager.NetworkCallback() {
-                    @Override
-                    public void onAvailable(Network network) {
-                        networkFuture.set(network);
-                    }
-                };
-
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                () -> controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture)));
-        runAsShell(
-                permission.ACCESS_NETWORK_STATE,
-                () -> cm.registerNetworkCallback(networkRequest, networkCallback));
-
-        joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
-        runAsShell(
-                permission.ACCESS_NETWORK_STATE, () -> assertThat(isAttached(controller)).isTrue());
-        assertThat(networkFuture.get(NETWORK_CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isNotNull();
-    }
 }
diff --git a/thread/tests/integration/Android.bp b/thread/tests/integration/Android.bp
index 405fb76..633389f 100644
--- a/thread/tests/integration/Android.bp
+++ b/thread/tests/integration/Android.bp
@@ -43,7 +43,7 @@
     manifest: "AndroidManifest.xml",
     defaults: [
         "framework-connectivity-test-defaults",
-        "ThreadNetworkIntegrationTestsDefaults"
+        "ThreadNetworkIntegrationTestsDefaults",
     ],
     test_suites: [
         "general-tests",
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index 44a8ab7..1d83abc 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -24,6 +24,7 @@
 import static com.google.common.io.BaseEncoding.base16;
 import static com.google.common.truth.Truth.assertThat;
 
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
@@ -85,6 +86,7 @@
     @Mock private TunInterfaceController mMockTunIfController;
     @Mock private ParcelFileDescriptor mMockTunFd;
     @Mock private InfraInterfaceController mMockInfraIfController;
+    @Mock private ThreadPersistentSettings mMockPersistentSettings;
     private Context mContext;
     private TestLooper mTestLooper;
     private FakeOtDaemon mFakeOtDaemon;
@@ -104,6 +106,8 @@
 
         when(mMockTunIfController.getTunFd()).thenReturn(mMockTunFd);
 
+        when(mMockPersistentSettings.get(any())).thenReturn(true);
+
         mService =
                 new ThreadNetworkControllerService(
                         ApplicationProvider.getApplicationContext(),
@@ -112,7 +116,8 @@
                         () -> mFakeOtDaemon,
                         mMockConnectivityManager,
                         mMockTunIfController,
-                        mMockInfraIfController);
+                        mMockInfraIfController,
+                        mMockPersistentSettings);
         mService.setTestNetworkAgent(mMockNetworkAgent);
     }