Merge "Support adding an IPv4 address via RTM_NEWADDR netlink message." into main
diff --git a/framework-t/api/system-current.txt b/framework-t/api/system-current.txt
index d346af3..8251f85 100644
--- a/framework-t/api/system-current.txt
+++ b/framework-t/api/system-current.txt
@@ -500,6 +500,7 @@
     method @RequiresPermission(allOf={android.Manifest.permission.ACCESS_NETWORK_STATE, "android.permission.THREAD_NETWORK_PRIVILEGED"}) public void registerOperationalDatasetCallback(@NonNull java.util.concurrent.Executor, @NonNull android.net.thread.ThreadNetworkController.OperationalDatasetCallback);
     method @RequiresPermission(android.Manifest.permission.ACCESS_NETWORK_STATE) public void registerStateCallback(@NonNull java.util.concurrent.Executor, @NonNull android.net.thread.ThreadNetworkController.StateCallback);
     method @RequiresPermission("android.permission.THREAD_NETWORK_PRIVILEGED") public void scheduleMigration(@NonNull android.net.thread.PendingOperationalDataset, @NonNull java.util.concurrent.Executor, @NonNull android.os.OutcomeReceiver<java.lang.Void,android.net.thread.ThreadNetworkException>);
+    method @RequiresPermission("android.permission.THREAD_NETWORK_PRIVILEGED") public void setEnabled(boolean, @NonNull java.util.concurrent.Executor, @NonNull android.os.OutcomeReceiver<java.lang.Void,android.net.thread.ThreadNetworkException>);
     method @RequiresPermission(allOf={android.Manifest.permission.ACCESS_NETWORK_STATE, "android.permission.THREAD_NETWORK_PRIVILEGED"}) public void unregisterOperationalDatasetCallback(@NonNull android.net.thread.ThreadNetworkController.OperationalDatasetCallback);
     method @RequiresPermission(android.Manifest.permission.ACCESS_NETWORK_STATE) public void unregisterStateCallback(@NonNull android.net.thread.ThreadNetworkController.StateCallback);
     field public static final int DEVICE_ROLE_CHILD = 2; // 0x2
@@ -507,6 +508,9 @@
     field public static final int DEVICE_ROLE_LEADER = 4; // 0x4
     field public static final int DEVICE_ROLE_ROUTER = 3; // 0x3
     field public static final int DEVICE_ROLE_STOPPED = 0; // 0x0
+    field public static final int STATE_DISABLED = 0; // 0x0
+    field public static final int STATE_DISABLING = 2; // 0x2
+    field public static final int STATE_ENABLED = 1; // 0x1
     field public static final int THREAD_VERSION_1_3 = 4; // 0x4
   }
 
@@ -518,6 +522,7 @@
   public static interface ThreadNetworkController.StateCallback {
     method public void onDeviceRoleChanged(int);
     method public default void onPartitionIdChanged(long);
+    method public default void onThreadEnableStateChanged(int);
   }
 
   @FlaggedApi("com.android.net.thread.flags.thread_enabled") public class ThreadNetworkException extends java.lang.Exception {
@@ -530,6 +535,7 @@
     field public static final int ERROR_REJECTED_BY_PEER = 8; // 0x8
     field public static final int ERROR_RESOURCE_EXHAUSTED = 10; // 0xa
     field public static final int ERROR_RESPONSE_BAD_FORMAT = 9; // 0x9
+    field public static final int ERROR_THREAD_DISABLED = 12; // 0xc
     field public static final int ERROR_TIMEOUT = 3; // 0x3
     field public static final int ERROR_UNAVAILABLE = 4; // 0x4
     field public static final int ERROR_UNKNOWN = 11; // 0xb
diff --git a/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp b/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp
index def5f88..d3e331e 100644
--- a/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp
+++ b/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp
@@ -50,6 +50,15 @@
     return ifaceStatsMap;
 }
 
+Result<IfaceValue> ifindex2name(const uint32_t ifindex) {
+    Result<IfaceValue> v = getIfaceIndexNameMap().readValue(ifindex);
+    if (v.ok()) return v;
+    IfaceValue iv = {};
+    if (!if_indextoname(ifindex, iv.name)) return v;
+    getIfaceIndexNameMap().writeValue(ifindex, iv, BPF_ANY);
+    return iv;
+}
+
 void bpfRegisterIface(const char* iface) {
     if (!iface) return;
     if (strlen(iface) >= sizeof(IfaceValue)) return;
@@ -78,14 +87,14 @@
 
 int bpfGetIfaceStatsInternal(const char* iface, StatsValue* stats,
                              const BpfMapRO<uint32_t, StatsValue>& ifaceStatsMap,
-                             const BpfMapRO<uint32_t, IfaceValue>& ifaceNameMap) {
+                             const IfIndexToNameFunc ifindex2name) {
     *stats = {};
     int64_t unknownIfaceBytesTotal = 0;
     const auto processIfaceStats =
-            [iface, stats, &ifaceNameMap, &unknownIfaceBytesTotal](
+            [iface, stats, ifindex2name, &unknownIfaceBytesTotal](
                     const uint32_t& key,
                     const BpfMapRO<uint32_t, StatsValue>& ifaceStatsMap) -> Result<void> {
-        Result<IfaceValue> ifname = ifaceNameMap.readValue(key);
+        Result<IfaceValue> ifname = ifindex2name(key);
         if (!ifname.ok()) {
             maybeLogUnknownIface(key, ifaceStatsMap, key, &unknownIfaceBytesTotal);
             return Result<void>();
@@ -104,7 +113,7 @@
 }
 
 int bpfGetIfaceStats(const char* iface, StatsValue* stats) {
-    return bpfGetIfaceStatsInternal(iface, stats, getIfaceStatsMap(), getIfaceIndexNameMap());
+    return bpfGetIfaceStatsInternal(iface, stats, getIfaceStatsMap(), ifindex2name);
 }
 
 int bpfGetIfIndexStatsInternal(uint32_t ifindex, StatsValue* stats,
@@ -138,13 +147,13 @@
 
 int parseBpfNetworkStatsDetailInternal(std::vector<stats_line>& lines,
                                        const BpfMapRO<StatsKey, StatsValue>& statsMap,
-                                       const BpfMapRO<uint32_t, IfaceValue>& ifaceMap) {
+                                       const IfIndexToNameFunc ifindex2name) {
     int64_t unknownIfaceBytesTotal = 0;
     const auto processDetailUidStats =
-            [&lines, &unknownIfaceBytesTotal, &ifaceMap](
+            [&lines, &unknownIfaceBytesTotal, &ifindex2name](
                     const StatsKey& key,
                     const BpfMapRO<StatsKey, StatsValue>& statsMap) -> Result<void> {
-        Result<IfaceValue> ifname = ifaceMap.readValue(key.ifaceIndex);
+        Result<IfaceValue> ifname = ifindex2name(key.ifaceIndex);
         if (!ifname.ok()) {
             maybeLogUnknownIface(key.ifaceIndex, statsMap, key, &unknownIfaceBytesTotal);
             return Result<void>();
@@ -212,7 +221,7 @@
     // TODO: the above comment feels like it may be obsolete / out of date,
     // since we no longer swap the map via netd binder rpc - though we do
     // still swap it.
-    int ret = parseBpfNetworkStatsDetailInternal(*lines, *inactiveStatsMap, getIfaceIndexNameMap());
+    int ret = parseBpfNetworkStatsDetailInternal(*lines, *inactiveStatsMap, ifindex2name);
     if (ret) {
         ALOGE("parse detail network stats failed: %s", strerror(errno));
         return ret;
@@ -229,12 +238,12 @@
 
 int parseBpfNetworkStatsDevInternal(std::vector<stats_line>& lines,
                                     const BpfMapRO<uint32_t, StatsValue>& statsMap,
-                                    const BpfMapRO<uint32_t, IfaceValue>& ifaceMap) {
+                                    const IfIndexToNameFunc ifindex2name) {
     int64_t unknownIfaceBytesTotal = 0;
-    const auto processDetailIfaceStats = [&lines, &unknownIfaceBytesTotal, &ifaceMap, &statsMap](
+    const auto processDetailIfaceStats = [&lines, &unknownIfaceBytesTotal, ifindex2name, &statsMap](
                                              const uint32_t& key, const StatsValue& value,
                                              const BpfMapRO<uint32_t, StatsValue>&) {
-        Result<IfaceValue> ifname = ifaceMap.readValue(key);
+        Result<IfaceValue> ifname = ifindex2name(key);
         if (!ifname.ok()) {
             maybeLogUnknownIface(key, statsMap, key, &unknownIfaceBytesTotal);
             return Result<void>();
@@ -259,7 +268,7 @@
 }
 
 int parseBpfNetworkStatsDev(std::vector<stats_line>* lines) {
-    return parseBpfNetworkStatsDevInternal(*lines, getIfaceStatsMap(), getIfaceIndexNameMap());
+    return parseBpfNetworkStatsDevInternal(*lines, getIfaceStatsMap(), ifindex2name);
 }
 
 void groupNetworkStats(std::vector<stats_line>& lines) {
diff --git a/service-t/native/libs/libnetworkstats/BpfNetworkStatsTest.cpp b/service-t/native/libs/libnetworkstats/BpfNetworkStatsTest.cpp
index 57822fc..484c166 100644
--- a/service-t/native/libs/libnetworkstats/BpfNetworkStatsTest.cpp
+++ b/service-t/native/libs/libnetworkstats/BpfNetworkStatsTest.cpp
@@ -77,6 +77,10 @@
     BpfMap<uint32_t, IfaceValue> mFakeIfaceIndexNameMap;
     BpfMap<uint32_t, StatsValue> mFakeIfaceStatsMap;
 
+    IfIndexToNameFunc mIfIndex2Name = [this](const uint32_t ifindex){
+        return mFakeIfaceIndexNameMap.readValue(ifindex);
+    };
+
     void SetUp() {
         ASSERT_EQ(0, setrlimitForTest());
 
@@ -228,7 +232,7 @@
     populateFakeStats(TEST_UID1, 0, IFACE_INDEX1, TEST_COUNTERSET0, value1, mFakeStatsMap);
     populateFakeStats(TEST_UID1, 0, IFACE_INDEX2, TEST_COUNTERSET1, value1, mFakeStatsMap);
     populateFakeStats(TEST_UID2, 0, IFACE_INDEX3, TEST_COUNTERSET1, value1, mFakeStatsMap);
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((unsigned long)3, lines.size());
 }
 
@@ -256,16 +260,15 @@
     EXPECT_RESULT_OK(mFakeIfaceStatsMap.writeValue(ifaceStatsKey, value1, BPF_ANY));
 
     StatsValue result1 = {};
-    ASSERT_EQ(0, bpfGetIfaceStatsInternal(IFACE_NAME1, &result1, mFakeIfaceStatsMap,
-                                          mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0,
+              bpfGetIfaceStatsInternal(IFACE_NAME1, &result1, mFakeIfaceStatsMap, mIfIndex2Name));
     expectStatsEqual(value1, result1);
     StatsValue result2 = {};
-    ASSERT_EQ(0, bpfGetIfaceStatsInternal(IFACE_NAME2, &result2, mFakeIfaceStatsMap,
-                                          mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0,
+              bpfGetIfaceStatsInternal(IFACE_NAME2, &result2, mFakeIfaceStatsMap, mIfIndex2Name));
     expectStatsEqual(value2, result2);
     StatsValue totalResult = {};
-    ASSERT_EQ(0, bpfGetIfaceStatsInternal(NULL, &totalResult, mFakeIfaceStatsMap,
-                                          mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, bpfGetIfaceStatsInternal(NULL, &totalResult, mFakeIfaceStatsMap, mIfIndex2Name));
     StatsValue totalValue = {
             .rxPackets = TEST_PACKET0 * 2 + TEST_PACKET1,
             .rxBytes = TEST_BYTES0 * 2 + TEST_BYTES1,
@@ -304,7 +307,7 @@
                       mFakeStatsMap);
     populateFakeStats(TEST_UID2, TEST_TAG, IFACE_INDEX1, TEST_COUNTERSET0, value1, mFakeStatsMap);
     std::vector<stats_line> lines;
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((unsigned long)7, lines.size());
 }
 
@@ -324,7 +327,7 @@
     populateFakeStats(TEST_UID1, 0, IFACE_INDEX1, TEST_COUNTERSET1, value1, mFakeStatsMap);
     populateFakeStats(TEST_UID2, 0, IFACE_INDEX1, TEST_COUNTERSET0, value1, mFakeStatsMap);
     std::vector<stats_line> lines;
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((unsigned long)4, lines.size());
 }
 
@@ -365,7 +368,7 @@
     ASSERT_EQ(-1, unknownIfaceBytesTotal);
     std::vector<stats_line> lines;
     // TODO: find a way to test the total of unknown Iface Bytes go above limit.
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((unsigned long)1, lines.size());
     expectStatsLineEqual(value1, IFACE_NAME1, TEST_UID1, TEST_COUNTERSET0, 0, lines.front());
 }
@@ -396,8 +399,7 @@
     ifaceStatsKey = IFACE_INDEX4;
     EXPECT_RESULT_OK(mFakeIfaceStatsMap.writeValue(ifaceStatsKey, value2, BPF_ANY));
     std::vector<stats_line> lines;
-    ASSERT_EQ(0,
-              parseBpfNetworkStatsDevInternal(lines, mFakeIfaceStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDevInternal(lines, mFakeIfaceStatsMap, mIfIndex2Name));
     ASSERT_EQ((unsigned long)4, lines.size());
 
     expectStatsLineEqual(value1, IFACE_NAME1, UID_ALL, SET_ALL, TAG_NONE, lines[0]);
@@ -441,13 +443,13 @@
     std::vector<stats_line> lines;
 
     // Test empty stats.
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 0, lines.size());
     lines.clear();
 
     // Test 1 line stats.
     populateFakeStats(TEST_UID1, TEST_TAG, IFACE_INDEX1, TEST_COUNTERSET0, value1, mFakeStatsMap);
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 2, lines.size());  // TEST_TAG != 0 -> 1 entry becomes 2 lines
     expectStatsLineEqual(value1, IFACE_NAME1, TEST_UID1, TEST_COUNTERSET0, 0, lines[0]);
     expectStatsLineEqual(value1, IFACE_NAME1, TEST_UID1, TEST_COUNTERSET0, TEST_TAG, lines[1]);
@@ -459,7 +461,7 @@
     populateFakeStats(TEST_UID1, TEST_TAG + 1, IFACE_INDEX1, TEST_COUNTERSET0, value2,
                       mFakeStatsMap);
     populateFakeStats(TEST_UID2, TEST_TAG, IFACE_INDEX1, TEST_COUNTERSET0, value1, mFakeStatsMap);
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 9, lines.size());
     lines.clear();
 
@@ -467,7 +469,7 @@
     populateFakeStats(TEST_UID1, TEST_TAG, IFACE_INDEX3, TEST_COUNTERSET0, value1, mFakeStatsMap);
     populateFakeStats(TEST_UID2, TEST_TAG, IFACE_INDEX3, TEST_COUNTERSET0, value1, mFakeStatsMap);
 
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 9, lines.size());
 
     // Verify Sorted & Grouped.
@@ -492,8 +494,7 @@
     ifaceStatsKey = IFACE_INDEX3;
     EXPECT_RESULT_OK(mFakeIfaceStatsMap.writeValue(ifaceStatsKey, value1, BPF_ANY));
 
-    ASSERT_EQ(0,
-              parseBpfNetworkStatsDevInternal(lines, mFakeIfaceStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDevInternal(lines, mFakeIfaceStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 2, lines.size());
 
     expectStatsLineEqual(value3, IFACE_NAME1, UID_ALL, SET_ALL, TAG_NONE, lines[0]);
@@ -534,7 +535,7 @@
     // TODO: Mutate counterSet and enlarge TEST_MAP_SIZE if overflow on counterSet is possible.
 
     std::vector<stats_line> lines;
-    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mFakeIfaceIndexNameMap));
+    ASSERT_EQ(0, parseBpfNetworkStatsDetailInternal(lines, mFakeStatsMap, mIfIndex2Name));
     ASSERT_EQ((size_t) 12, lines.size());
 
     // Uid 0 first
diff --git a/service-t/native/libs/libnetworkstats/include/netdbpf/BpfNetworkStats.h b/service-t/native/libs/libnetworkstats/include/netdbpf/BpfNetworkStats.h
index 173dee4..59eb195 100644
--- a/service-t/native/libs/libnetworkstats/include/netdbpf/BpfNetworkStats.h
+++ b/service-t/native/libs/libnetworkstats/include/netdbpf/BpfNetworkStats.h
@@ -55,20 +55,25 @@
 bool operator==(const stats_line& lhs, const stats_line& rhs);
 bool operator<(const stats_line& lhs, const stats_line& rhs);
 
+// This mirrors BpfMap.h's:
+//   Result<Value> readValue(const Key key) const
+// for a BpfMap<uint32_t, IfaceValue>
+using IfIndexToNameFunc = std::function<Result<IfaceValue>(const uint32_t)>;
+
 // For test only
 int bpfGetUidStatsInternal(uid_t uid, StatsValue* stats,
                            const BpfMapRO<uint32_t, StatsValue>& appUidStatsMap);
 // For test only
 int bpfGetIfaceStatsInternal(const char* iface, StatsValue* stats,
                              const BpfMapRO<uint32_t, StatsValue>& ifaceStatsMap,
-                             const BpfMapRO<uint32_t, IfaceValue>& ifaceNameMap);
+                             const IfIndexToNameFunc ifindex2name);
 // For test only
 int bpfGetIfIndexStatsInternal(uint32_t ifindex, StatsValue* stats,
                                const BpfMapRO<uint32_t, StatsValue>& ifaceStatsMap);
 // For test only
 int parseBpfNetworkStatsDetailInternal(std::vector<stats_line>& lines,
                                        const BpfMapRO<StatsKey, StatsValue>& statsMap,
-                                       const BpfMapRO<uint32_t, IfaceValue>& ifaceMap);
+                                       const IfIndexToNameFunc ifindex2name);
 // For test only
 int cleanStatsMapInternal(const base::unique_fd& cookieTagMap, const base::unique_fd& tagStatsMap);
 
@@ -98,7 +103,7 @@
 // For test only
 int parseBpfNetworkStatsDevInternal(std::vector<stats_line>& lines,
                                     const BpfMapRO<uint32_t, StatsValue>& statsMap,
-                                    const BpfMapRO<uint32_t, IfaceValue>& ifaceMap);
+                                    const IfIndexToNameFunc ifindex2name);
 
 void bpfRegisterIface(const char* iface);
 int bpfGetUidStats(uid_t uid, StatsValue* stats);
diff --git a/service-t/src/com/android/server/net/BpfInterfaceMapHelper.java b/service-t/src/com/android/server/net/BpfInterfaceMapHelper.java
new file mode 100644
index 0000000..3c95b8e
--- /dev/null
+++ b/service-t/src/com/android/server/net/BpfInterfaceMapHelper.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.server.net;
+
+import android.os.Build;
+import android.system.ErrnoException;
+import android.util.IndentingPrintWriter;
+import android.util.Log;
+
+import androidx.annotation.RequiresApi;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.BpfDump;
+import com.android.net.module.util.BpfMap;
+import com.android.net.module.util.IBpfMap;
+import com.android.net.module.util.Struct.S32;
+
+/**
+ * Helper that looks up interface names by interface index in the bpf map and dumps that map.
+ */
+@RequiresApi(Build.VERSION_CODES.TIRAMISU)
+public class BpfInterfaceMapHelper {
+    private static final String TAG = BpfInterfaceMapHelper.class.getSimpleName();
+    // This is current path but may be changed soon.
+    private static final String IFACE_INDEX_NAME_MAP_PATH =
+            "/sys/fs/bpf/netd_shared/map_netd_iface_index_name_map";
+    private final IBpfMap<S32, InterfaceMapValue> mIndexToIfaceBpfMap;
+
+    public BpfInterfaceMapHelper() {
+        this(new Dependencies());
+    }
+
+    @VisibleForTesting
+    public BpfInterfaceMapHelper(Dependencies deps) {
+        mIndexToIfaceBpfMap = deps.getInterfaceMap();
+    }
+
+    /**
+     * Dependencies of BpfInterfaceMapHelper, for injection in tests.
+     */
+    @VisibleForTesting
+    public static class Dependencies {
+        /** Create BpfMap for accessing the interface index to interface name mapping. */
+        public IBpfMap<S32, InterfaceMapValue> getInterfaceMap() {
+            try {
+                return new BpfMap<>(IFACE_INDEX_NAME_MAP_PATH,
+                    S32.class, InterfaceMapValue.class);
+            } catch (ErrnoException e) {
+                Log.e(TAG, "Cannot create interface map: " + e);
+                return null;
+            }
+        }
+    }
+
+    /** get interface name by interface index from bpf map */
+    public String getIfNameByIndex(final int index) {
+        try {
+            final InterfaceMapValue value = mIndexToIfaceBpfMap.getValue(new S32(index));
+            if (value == null) {
+                Log.e(TAG, "No if name entry for index " + index);
+                return null;
+            }
+            return value.getInterfaceNameString();
+        } catch (ErrnoException e) {
+            Log.e(TAG, "Failed to get entry for index " + index + ": " + e);
+            return null;
+        }
+    }
+
+    /**
+     * Dump BPF map
+     *
+     * @param pw print writer
+     */
+    public void dump(final IndentingPrintWriter pw) {
+        pw.println("BPF map status:");
+        pw.increaseIndent();
+        BpfDump.dumpMapStatus(mIndexToIfaceBpfMap, pw, "IfaceIndexNameMap",
+                IFACE_INDEX_NAME_MAP_PATH);
+        pw.decreaseIndent();
+        pw.println("BPF map content:");
+        pw.increaseIndent();
+        BpfDump.dumpMap(mIndexToIfaceBpfMap, pw, "IfaceIndexNameMap",
+                (key, value) -> "ifaceIndex=" + key.val
+                        + " ifaceName=" + value.getInterfaceNameString());
+        pw.decreaseIndent();
+    }
+}
diff --git a/service-t/src/com/android/server/net/BpfInterfaceMapUpdater.java b/service-t/src/com/android/server/net/BpfInterfaceMapUpdater.java
deleted file mode 100644
index 59de2c4..0000000
--- a/service-t/src/com/android/server/net/BpfInterfaceMapUpdater.java
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Copyright (C) 2022 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.android.server.net;
-
-import android.os.Build;
-import android.content.Context;
-import android.net.INetd;
-import android.os.Handler;
-import android.os.IBinder;
-import android.os.RemoteException;
-import android.os.ServiceSpecificException;
-import android.system.ErrnoException;
-import android.util.IndentingPrintWriter;
-import android.util.Log;
-
-import androidx.annotation.RequiresApi;
-
-import com.android.internal.annotations.VisibleForTesting;
-import com.android.net.module.util.BaseNetdUnsolicitedEventListener;
-import com.android.net.module.util.BpfDump;
-import com.android.net.module.util.BpfMap;
-import com.android.net.module.util.IBpfMap;
-import com.android.net.module.util.InterfaceParams;
-import com.android.net.module.util.Struct.S32;
-
-/**
- * Monitor interface added (without removed) and right interface name and its index to bpf map.
- */
-@RequiresApi(Build.VERSION_CODES.TIRAMISU)
-public class BpfInterfaceMapUpdater {
-    private static final String TAG = BpfInterfaceMapUpdater.class.getSimpleName();
-    // This is current path but may be changed soon.
-    private static final String IFACE_INDEX_NAME_MAP_PATH =
-            "/sys/fs/bpf/netd_shared/map_netd_iface_index_name_map";
-    private final IBpfMap<S32, InterfaceMapValue> mIndexToIfaceBpfMap;
-    private final INetd mNetd;
-    private final Handler mHandler;
-    private final Dependencies mDeps;
-
-    public BpfInterfaceMapUpdater(Context ctx, Handler handler) {
-        this(ctx, handler, new Dependencies());
-    }
-
-    @VisibleForTesting
-    public BpfInterfaceMapUpdater(Context ctx, Handler handler, Dependencies deps) {
-        mDeps = deps;
-        mIndexToIfaceBpfMap = deps.getInterfaceMap();
-        mNetd = deps.getINetd(ctx);
-        mHandler = handler;
-    }
-
-    /**
-     * Dependencies of BpfInerfaceMapUpdater, for injection in tests.
-     */
-    @VisibleForTesting
-    public static class Dependencies {
-        /** Create BpfMap for updating interface and index mapping. */
-        public IBpfMap<S32, InterfaceMapValue> getInterfaceMap() {
-            try {
-                return new BpfMap<>(IFACE_INDEX_NAME_MAP_PATH,
-                    S32.class, InterfaceMapValue.class);
-            } catch (ErrnoException e) {
-                Log.e(TAG, "Cannot create interface map: " + e);
-                return null;
-            }
-        }
-
-        /** Get InterfaceParams for giving interface name. */
-        public InterfaceParams getInterfaceParams(String ifaceName) {
-            return InterfaceParams.getByName(ifaceName);
-        }
-
-        /** Get INetd binder object. */
-        public INetd getINetd(Context ctx) {
-            return INetd.Stub.asInterface((IBinder) ctx.getSystemService(Context.NETD_SERVICE));
-        }
-    }
-
-    /**
-     * Start listening interface update event.
-     * Query current interface names before listening.
-     */
-    public void start() {
-        mHandler.post(() -> {
-            if (mIndexToIfaceBpfMap == null) {
-                Log.wtf(TAG, "Fail to start: Null bpf map");
-                return;
-            }
-
-            try {
-                // TODO: use a NetlinkMonitor and listen for RTM_NEWLINK messages instead.
-                mNetd.registerUnsolicitedEventListener(new InterfaceChangeObserver());
-            } catch (RemoteException e) {
-                Log.wtf(TAG, "Unable to register netd UnsolicitedEventListener, " + e);
-            }
-
-            final String[] ifaces;
-            try {
-                // TODO: use a netlink dump to get the current interface list.
-                ifaces = mNetd.interfaceGetList();
-            } catch (RemoteException | ServiceSpecificException e) {
-                Log.wtf(TAG, "Unable to query interface names by netd, " + e);
-                return;
-            }
-
-            for (String ifaceName : ifaces) {
-                addInterface(ifaceName);
-            }
-        });
-    }
-
-    private void addInterface(String ifaceName) {
-        final InterfaceParams iface = mDeps.getInterfaceParams(ifaceName);
-        if (iface == null) {
-            Log.e(TAG, "Unable to get InterfaceParams for " + ifaceName);
-            return;
-        }
-
-        try {
-            mIndexToIfaceBpfMap.updateEntry(new S32(iface.index), new InterfaceMapValue(ifaceName));
-        } catch (ErrnoException e) {
-            Log.e(TAG, "Unable to update entry for " + ifaceName + ", " + e);
-        }
-    }
-
-    private class InterfaceChangeObserver extends BaseNetdUnsolicitedEventListener {
-        @Override
-        public void onInterfaceAdded(String ifName) {
-            mHandler.post(() -> addInterface(ifName));
-        }
-    }
-
-    /** get interface name by interface index from bpf map */
-    public String getIfNameByIndex(final int index) {
-        try {
-            final InterfaceMapValue value = mIndexToIfaceBpfMap.getValue(new S32(index));
-            if (value == null) {
-                Log.e(TAG, "No if name entry for index " + index);
-                return null;
-            }
-            return value.getInterfaceNameString();
-        } catch (ErrnoException e) {
-            Log.e(TAG, "Failed to get entry for index " + index + ": " + e);
-            return null;
-        }
-    }
-
-    /**
-     * Dump BPF map
-     *
-     * @param pw print writer
-     */
-    public void dump(final IndentingPrintWriter pw) {
-        pw.println("BPF map status:");
-        pw.increaseIndent();
-        BpfDump.dumpMapStatus(mIndexToIfaceBpfMap, pw, "IfaceIndexNameMap",
-                IFACE_INDEX_NAME_MAP_PATH);
-        pw.decreaseIndent();
-        pw.println("BPF map content:");
-        pw.increaseIndent();
-        BpfDump.dumpMap(mIndexToIfaceBpfMap, pw, "IfaceIndexNameMap",
-                (key, value) -> "ifaceIndex=" + key.val
-                        + " ifaceName=" + value.getInterfaceNameString());
-        pw.decreaseIndent();
-    }
-}
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 7b24315..ec10158 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -476,7 +476,7 @@
     private final LocationPermissionChecker mLocationPermissionChecker;
 
     @NonNull
-    private final BpfInterfaceMapUpdater mInterfaceMapUpdater;
+    private final BpfInterfaceMapHelper mInterfaceMapHelper;
 
     @Nullable
     private final SkDestroyListener mSkDestroyListener;
@@ -628,8 +628,7 @@
         mContentObserver = mDeps.makeContentObserver(mHandler, mSettings,
                 mNetworkStatsSubscriptionsMonitor);
         mLocationPermissionChecker = mDeps.makeLocationPermissionChecker(mContext);
-        mInterfaceMapUpdater = mDeps.makeBpfInterfaceMapUpdater(mContext, mHandler);
-        mInterfaceMapUpdater.start();
+        mInterfaceMapHelper = mDeps.makeBpfInterfaceMapHelper();
         mUidCounterSetMap = mDeps.getUidCounterSetMap();
         mCookieTagMap = mDeps.getCookieTagMap();
         mStatsMapA = mDeps.getStatsMapA();
@@ -798,11 +797,10 @@
             return new LocationPermissionChecker(context);
         }
 
-        /** Create BpfInterfaceMapUpdater to update bpf interface map. */
+        /** Create BpfInterfaceMapHelper to read and dump the bpf interface map. */
         @NonNull
-        public BpfInterfaceMapUpdater makeBpfInterfaceMapUpdater(
-                @NonNull Context ctx, @NonNull Handler handler) {
-            return new BpfInterfaceMapUpdater(ctx, handler);
+        public BpfInterfaceMapHelper makeBpfInterfaceMapHelper() {
+            return new BpfInterfaceMapHelper();
         }
 
         /** Get counter sets map for each UID. */
@@ -2889,9 +2887,9 @@
             }
 
             pw.println();
-            pw.println("InterfaceMapUpdater:");
+            pw.println("InterfaceMapHelper:");
             pw.increaseIndent();
-            mInterfaceMapUpdater.dump(pw);
+            mInterfaceMapHelper.dump(pw);
             pw.decreaseIndent();
 
             pw.println();
@@ -3038,7 +3036,7 @@
         BpfDump.dumpMap(statsMap, pw, mapName,
                 "ifaceIndex ifaceName tag_hex uid_int cnt_set rxBytes rxPackets txBytes txPackets",
                 (key, value) -> {
-                    final String ifName = mInterfaceMapUpdater.getIfNameByIndex(key.ifaceIndex);
+                    final String ifName = mInterfaceMapHelper.getIfNameByIndex(key.ifaceIndex);
                     return key.ifaceIndex + " "
                             + (ifName != null ? ifName : "unknown") + " "
                             + "0x" + Long.toHexString(key.tag) + " "
@@ -3056,7 +3054,7 @@
         BpfDump.dumpMap(mIfaceStatsMap, pw, "mIfaceStatsMap",
                 "ifaceIndex ifaceName rxBytes rxPackets txBytes txPackets",
                 (key, value) -> {
-                    final String ifName = mInterfaceMapUpdater.getIfNameByIndex(key.val);
+                    final String ifName = mInterfaceMapHelper.getIfNameByIndex(key.val);
                     return key.val + " "
                             + (ifName != null ? ifName : "unknown") + " "
                             + value.rxBytes + " "
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt b/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
index 10accd4..69fdbf8 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
@@ -31,6 +31,7 @@
 import org.junit.runner.notification.Failure
 import org.junit.runner.notification.RunNotifier
 import org.junit.runners.Parameterized
+import org.mockito.Mockito
 
 /**
  * A runner that can skip tests based on the development SDK as defined in [DevSdkIgnoreRule].
@@ -124,6 +125,9 @@
             notifier.fireTestFailure(Failure(leakMonitorDesc,
                     IllegalStateException("Unexpected thread changes: $threadsDiff")))
         }
+        // Clears up internal state of all inline mocks.
+        // TODO: Call clearInlineMocks() at the end of each test.
+        Mockito.framework().clearInlineMocks()
         notifier.fireTestFinished(leakMonitorDesc)
     }
 
diff --git a/tests/unit/java/com/android/server/net/BpfInterfaceMapHelperTest.java b/tests/unit/java/com/android/server/net/BpfInterfaceMapHelperTest.java
new file mode 100644
index 0000000..7b3bea3
--- /dev/null
+++ b/tests/unit/java/com/android/server/net/BpfInterfaceMapHelperTest.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.net;
+
+import static android.system.OsConstants.EPERM;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.spy;
+
+import android.os.Build;
+import android.system.ErrnoException;
+import android.util.IndentingPrintWriter;
+
+import androidx.test.filters.SmallTest;
+
+import com.android.net.module.util.BaseNetdUnsolicitedEventListener;
+import com.android.net.module.util.IBpfMap;
+import com.android.net.module.util.Struct.S32;
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.TestBpfMap;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.MockitoAnnotations;
+
+import java.io.PrintWriter;
+import java.io.StringWriter;
+
+@SmallTest
+@RunWith(DevSdkIgnoreRunner.class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+public final class BpfInterfaceMapHelperTest {
+    private static final int TEST_INDEX = 1;
+    private static final int TEST_INDEX2 = 2;
+    private static final String TEST_INTERFACE_NAME = "test1";
+    private static final String TEST_INTERFACE_NAME2 = "test2";
+
+    private BaseNetdUnsolicitedEventListener mListener;
+    private BpfInterfaceMapHelper mUpdater;
+    private IBpfMap<S32, InterfaceMapValue> mBpfMap =
+            spy(new TestBpfMap<>(S32.class, InterfaceMapValue.class));
+
+    private class TestDependencies extends BpfInterfaceMapHelper.Dependencies {
+        @Override
+        public IBpfMap<S32, InterfaceMapValue> getInterfaceMap() {
+            return mBpfMap;
+        }
+    }
+
+    @Before
+    public void setUp() throws Exception {
+        MockitoAnnotations.initMocks(this);
+        mUpdater = new BpfInterfaceMapHelper(new TestDependencies());
+    }
+
+    @Test
+    public void testGetIfNameByIndex() throws Exception {
+        mBpfMap.updateEntry(new S32(TEST_INDEX), new InterfaceMapValue(TEST_INTERFACE_NAME));
+        assertEquals(TEST_INTERFACE_NAME, mUpdater.getIfNameByIndex(TEST_INDEX));
+    }
+
+    @Test
+    public void testGetIfNameByIndexNoEntry() {
+        assertNull(mUpdater.getIfNameByIndex(TEST_INDEX));
+    }
+
+    @Test
+    public void testGetIfNameByIndexException() throws Exception {
+        doThrow(new ErrnoException("", EPERM)).when(mBpfMap).getValue(new S32(TEST_INDEX));
+        assertNull(mUpdater.getIfNameByIndex(TEST_INDEX));
+    }
+
+    private void assertDumpContains(final String dump, final String message) {
+        assertTrue(String.format("dump(%s) does not contain '%s'", dump, message),
+                dump.contains(message));
+    }
+
+    private String getDump() {
+        final StringWriter sw = new StringWriter();
+        mUpdater.dump(new IndentingPrintWriter(new PrintWriter(sw), " "));
+        return sw.toString();
+    }
+
+    @Test
+    public void testDump() throws ErrnoException {
+        mBpfMap.updateEntry(new S32(TEST_INDEX), new InterfaceMapValue(TEST_INTERFACE_NAME));
+        mBpfMap.updateEntry(new S32(TEST_INDEX2), new InterfaceMapValue(TEST_INTERFACE_NAME2));
+
+        final String dump = getDump();
+        assertDumpContains(dump, "IfaceIndexNameMap: OK");
+        assertDumpContains(dump, "ifaceIndex=1 ifaceName=test1");
+        assertDumpContains(dump, "ifaceIndex=2 ifaceName=test2");
+    }
+}
diff --git a/tests/unit/java/com/android/server/net/BpfInterfaceMapUpdaterTest.java b/tests/unit/java/com/android/server/net/BpfInterfaceMapUpdaterTest.java
deleted file mode 100644
index c730856..0000000
--- a/tests/unit/java/com/android/server/net/BpfInterfaceMapUpdaterTest.java
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Copyright (C) 2022 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.server.net;
-
-import static android.system.OsConstants.EPERM;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
-import static org.mockito.Mockito.when;
-
-import android.content.Context;
-import android.net.INetd;
-import android.net.MacAddress;
-import android.os.Build;
-import android.os.Handler;
-import android.os.test.TestLooper;
-import android.system.ErrnoException;
-import android.util.IndentingPrintWriter;
-
-import androidx.test.filters.SmallTest;
-
-import com.android.net.module.util.BaseNetdUnsolicitedEventListener;
-import com.android.net.module.util.IBpfMap;
-import com.android.net.module.util.InterfaceParams;
-import com.android.net.module.util.Struct.S32;
-import com.android.testutils.DevSdkIgnoreRule;
-import com.android.testutils.DevSdkIgnoreRunner;
-import com.android.testutils.TestBpfMap;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.ArgumentCaptor;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.io.PrintWriter;
-import java.io.StringWriter;
-
-@SmallTest
-@RunWith(DevSdkIgnoreRunner.class)
-@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
-public final class BpfInterfaceMapUpdaterTest {
-    private static final int TEST_INDEX = 1;
-    private static final int TEST_INDEX2 = 2;
-    private static final String TEST_INTERFACE_NAME = "test1";
-    private static final String TEST_INTERFACE_NAME2 = "test2";
-
-    private final TestLooper mLooper = new TestLooper();
-    private BaseNetdUnsolicitedEventListener mListener;
-    private BpfInterfaceMapUpdater mUpdater;
-    private IBpfMap<S32, InterfaceMapValue> mBpfMap =
-            spy(new TestBpfMap<>(S32.class, InterfaceMapValue.class));
-    @Mock private INetd mNetd;
-    @Mock private Context mContext;
-
-    private class TestDependencies extends BpfInterfaceMapUpdater.Dependencies {
-        @Override
-        public IBpfMap<S32, InterfaceMapValue> getInterfaceMap() {
-            return mBpfMap;
-        }
-
-        @Override
-        public InterfaceParams getInterfaceParams(String ifaceName) {
-            if (ifaceName.equals(TEST_INTERFACE_NAME)) {
-                return new InterfaceParams(TEST_INTERFACE_NAME, TEST_INDEX,
-                        MacAddress.ALL_ZEROS_ADDRESS);
-            } else if (ifaceName.equals(TEST_INTERFACE_NAME2)) {
-                return new InterfaceParams(TEST_INTERFACE_NAME2, TEST_INDEX2,
-                        MacAddress.ALL_ZEROS_ADDRESS);
-            }
-
-            return null;
-        }
-
-        @Override
-        public INetd getINetd(Context ctx) {
-            return mNetd;
-        }
-    }
-
-    @Before
-    public void setUp() throws Exception {
-        MockitoAnnotations.initMocks(this);
-        when(mNetd.interfaceGetList()).thenReturn(new String[] {TEST_INTERFACE_NAME});
-        mUpdater = new BpfInterfaceMapUpdater(mContext, new Handler(mLooper.getLooper()),
-                new TestDependencies());
-    }
-
-    private void verifyStartUpdater() throws Exception {
-        mUpdater.start();
-        mLooper.dispatchAll();
-        final ArgumentCaptor<BaseNetdUnsolicitedEventListener> listenerCaptor =
-                ArgumentCaptor.forClass(BaseNetdUnsolicitedEventListener.class);
-        verify(mNetd).registerUnsolicitedEventListener(listenerCaptor.capture());
-        mListener = listenerCaptor.getValue();
-        verify(mBpfMap).updateEntry(eq(new S32(TEST_INDEX)),
-                eq(new InterfaceMapValue(TEST_INTERFACE_NAME)));
-    }
-
-    @Test
-    public void testUpdateInterfaceMap() throws Exception {
-        verifyStartUpdater();
-
-        mListener.onInterfaceAdded(TEST_INTERFACE_NAME2);
-        mLooper.dispatchAll();
-        verify(mBpfMap).updateEntry(eq(new S32(TEST_INDEX2)),
-                eq(new InterfaceMapValue(TEST_INTERFACE_NAME2)));
-
-        // Check that when onInterfaceRemoved is called, nothing happens.
-        mListener.onInterfaceRemoved(TEST_INTERFACE_NAME);
-        mLooper.dispatchAll();
-        verifyNoMoreInteractions(mBpfMap);
-    }
-
-    @Test
-    public void testGetIfNameByIndex() throws Exception {
-        mBpfMap.updateEntry(new S32(TEST_INDEX), new InterfaceMapValue(TEST_INTERFACE_NAME));
-        assertEquals(TEST_INTERFACE_NAME, mUpdater.getIfNameByIndex(TEST_INDEX));
-    }
-
-    @Test
-    public void testGetIfNameByIndexNoEntry() {
-        assertNull(mUpdater.getIfNameByIndex(TEST_INDEX));
-    }
-
-    @Test
-    public void testGetIfNameByIndexException() throws Exception {
-        doThrow(new ErrnoException("", EPERM)).when(mBpfMap).getValue(new S32(TEST_INDEX));
-        assertNull(mUpdater.getIfNameByIndex(TEST_INDEX));
-    }
-
-    private void assertDumpContains(final String dump, final String message) {
-        assertTrue(String.format("dump(%s) does not contain '%s'", dump, message),
-                dump.contains(message));
-    }
-
-    private String getDump() {
-        final StringWriter sw = new StringWriter();
-        mUpdater.dump(new IndentingPrintWriter(new PrintWriter(sw), " "));
-        return sw.toString();
-    }
-
-    @Test
-    public void testDump() throws ErrnoException {
-        mBpfMap.updateEntry(new S32(TEST_INDEX), new InterfaceMapValue(TEST_INTERFACE_NAME));
-        mBpfMap.updateEntry(new S32(TEST_INDEX2), new InterfaceMapValue(TEST_INTERFACE_NAME2));
-
-        final String dump = getDump();
-        assertDumpContains(dump, "IfaceIndexNameMap: OK");
-        assertDumpContains(dump, "ifaceIndex=1 ifaceName=test1");
-        assertDumpContains(dump, "ifaceIndex=2 ifaceName=test2");
-    }
-}
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
index a5fee5b..3ed51bc 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -254,7 +254,7 @@
     private @Mock AlarmManager mAlarmManager;
     @Mock
     private NetworkStatsSubscriptionsMonitor mNetworkStatsSubscriptionsMonitor;
-    private @Mock BpfInterfaceMapUpdater mBpfInterfaceMapUpdater;
+    private @Mock BpfInterfaceMapHelper mBpfInterfaceMapHelper;
     private HandlerThread mHandlerThread;
     @Mock
     private LocationPermissionChecker mLocationPermissionChecker;
@@ -519,9 +519,8 @@
         }
 
         @Override
-        public BpfInterfaceMapUpdater makeBpfInterfaceMapUpdater(
-                @NonNull Context ctx, @NonNull Handler handler) {
-            return mBpfInterfaceMapUpdater;
+        public BpfInterfaceMapHelper makeBpfInterfaceMapHelper() {
+            return mBpfInterfaceMapHelper;
         }
 
         @Override
@@ -2764,13 +2763,13 @@
 
     @Test
     public void testDumpStatsMap() throws ErrnoException {
-        doReturn("wlan0").when(mBpfInterfaceMapUpdater).getIfNameByIndex(10 /* index */);
+        doReturn("wlan0").when(mBpfInterfaceMapHelper).getIfNameByIndex(10 /* index */);
         doTestDumpStatsMap("wlan0");
     }
 
     @Test
     public void testDumpStatsMapUnknownInterface() throws ErrnoException {
-        doReturn(null).when(mBpfInterfaceMapUpdater).getIfNameByIndex(10 /* index */);
+        doReturn(null).when(mBpfInterfaceMapHelper).getIfNameByIndex(10 /* index */);
         doTestDumpStatsMap("unknown");
     }
 
@@ -2785,13 +2784,13 @@
 
     @Test
     public void testDumpIfaceStatsMap() throws Exception {
-        doReturn("wlan0").when(mBpfInterfaceMapUpdater).getIfNameByIndex(10 /* index */);
+        doReturn("wlan0").when(mBpfInterfaceMapHelper).getIfNameByIndex(10 /* index */);
         doTestDumpIfaceStatsMap("wlan0");
     }
 
     @Test
     public void testDumpIfaceStatsMapUnknownInterface() throws Exception {
-        doReturn(null).when(mBpfInterfaceMapUpdater).getIfNameByIndex(10 /* index */);
+        doReturn(null).when(mBpfInterfaceMapHelper).getIfNameByIndex(10 /* index */);
         doTestDumpIfaceStatsMap("unknown");
     }
 
diff --git a/thread/framework/java/android/net/thread/IStateCallback.aidl b/thread/framework/java/android/net/thread/IStateCallback.aidl
index d7cbda9..9d0a571 100644
--- a/thread/framework/java/android/net/thread/IStateCallback.aidl
+++ b/thread/framework/java/android/net/thread/IStateCallback.aidl
@@ -22,4 +22,5 @@
 oneway interface IStateCallback {
     void onDeviceRoleChanged(int deviceRole);
     void onPartitionIdChanged(long partitionId);
+    void onThreadEnableStateChanged(int enabledState);
 }
diff --git a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
index a9da8d6..485e25d 100644
--- a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
+++ b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
@@ -42,4 +42,6 @@
 
     int getThreadVersion();
     void createRandomizedDataset(String networkName, IActiveOperationalDatasetReceiver receiver);
+
+    void setEnabled(boolean enabled, in IOperationReceiver receiver);
 }
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkController.java b/thread/framework/java/android/net/thread/ThreadNetworkController.java
index 7242ed7..db761a3 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkController.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkController.java
@@ -68,6 +68,15 @@
     /** The device is a Thread Leader. */
     public static final int DEVICE_ROLE_LEADER = 4;
 
+    /** The Thread radio is disabled. */
+    public static final int STATE_DISABLED = 0;
+
+    /** The Thread radio is enabled. */
+    public static final int STATE_ENABLED = 1;
+
+    /** The Thread radio is being disabled. */
+    public static final int STATE_DISABLING = 2;
+
     /** @hide */
     @Retention(RetentionPolicy.SOURCE)
     @IntDef({
@@ -79,6 +88,13 @@
     })
     public @interface DeviceRole {}
 
+    /** @hide */
+    @Retention(RetentionPolicy.SOURCE)
+    @IntDef(
+            prefix = {"STATE_"},
+            value = {STATE_DISABLED, STATE_ENABLED, STATE_DISABLING})
+    public @interface EnabledState {}
+
     /** Thread standard version 1.3. */
     public static final int THREAD_VERSION_1_3 = 4;
 
@@ -106,6 +122,40 @@
         mControllerService = controllerService;
     }
 
+    /**
+     * Enables/Disables the radio of this ThreadNetworkController. The requested enabled state will
+     * be persisted and will survive device reboots.
+     *
+     * <p>When Thread is in {@code STATE_DISABLED}, {@link ThreadNetworkController} APIs which
+     * require the Thread radio will fail with error code {@link
+     * ThreadNetworkException#ERROR_THREAD_DISABLED}. When Thread is in {@code STATE_DISABLING},
+     * {@link ThreadNetworkController} APIs that return a {@link ThreadNetworkException} will fail
+     * with error code {@link ThreadNetworkException#ERROR_BUSY}.
+     *
+     * <p>On success, {@link OutcomeReceiver#onResult} of {@code receiver} is called. It indicates
+     * the operation has completed. But there may be subsequent calls to update the enabled state;
+     * callers of this method should use {@link #registerStateCallback} to subscribe to the Thread
+     * enabled state changes.
+     *
+     * <p>On failure, {@link OutcomeReceiver#onError} of {@code receiver} will be invoked with a
+     * specific {@code ERROR_*} code defined in {@link ThreadNetworkException}.
+     *
+     * @param enabled {@code true} for enabling Thread
+     * @param executor the executor to execute {@code receiver}
+     * @param receiver the receiver to receive result of this operation
+     */
+    @RequiresPermission("android.permission.THREAD_NETWORK_PRIVILEGED")
+    public void setEnabled(
+            boolean enabled,
+            @NonNull @CallbackExecutor Executor executor,
+            @NonNull OutcomeReceiver<Void, ThreadNetworkException> receiver) {
+        try {
+            mControllerService.setEnabled(enabled, new OperationReceiverProxy(executor, receiver));
+        } catch (RemoteException e) {
+            throw e.rethrowFromSystemServer();
+        }
+    }
+
     /** Returns the Thread version this device is operating on. */
     @ThreadVersion
     public int getThreadVersion() {
@@ -170,6 +220,16 @@
          * @param partitionId the new Thread partition ID
          */
         default void onPartitionIdChanged(long partitionId) {}
+
+        /**
+         * The Thread enabled state has changed.
+         *
+         * <p>The Thread enabled state can be set with {@link ThreadNetworkController#setEnabled};
+         * it may also be updated by airplane mode or admin control.
+         *
+         * @param enabledState the new Thread enabled state
+         */
+        default void onThreadEnableStateChanged(@EnabledState int enabledState) {}
     }
 
     private static final class StateCallbackProxy extends IStateCallback.Stub {
@@ -200,6 +260,16 @@
                 Binder.restoreCallingIdentity(identity);
             }
         }
+
+        @Override
+        public void onThreadEnableStateChanged(@EnabledState int enabled) {
+            final long identity = Binder.clearCallingIdentity();
+            try {
+                mExecutor.execute(() -> mCallback.onThreadEnableStateChanged(enabled));
+            } finally {
+                Binder.restoreCallingIdentity(identity);
+            }
+        }
     }
 
     /**
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkException.java b/thread/framework/java/android/net/thread/ThreadNetworkException.java
index af0a84b..23ed53e 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkException.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkException.java
@@ -48,6 +48,7 @@
         ERROR_RESPONSE_BAD_FORMAT,
         ERROR_RESOURCE_EXHAUSTED,
         ERROR_UNKNOWN,
+        ERROR_THREAD_DISABLED,
     })
     public @interface ErrorCode {}
 
@@ -129,6 +130,13 @@
      */
     public static final int ERROR_UNKNOWN = 11;
 
+    /**
+     * The operation failed because the Thread radio is disabled by {@link
+     * ThreadNetworkController#setEnabled}, airplane mode or device admin. The caller should retry
+     * only after Thread is enabled.
+     */
+    public static final int ERROR_THREAD_DISABLED = 12;
+
     private final int mErrorCode;
 
     /** Creates a new {@link ThreadNetworkException} object with given error code and message. */
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 1c51c42..7b9f290 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -27,6 +27,9 @@
 import static android.net.thread.ActiveOperationalDataset.MESH_LOCAL_PREFIX_FIRST_BYTE;
 import static android.net.thread.ActiveOperationalDataset.SecurityPolicy.DEFAULT_ROTATION_TIME_HOURS;
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_DETACHED;
+import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
+import static android.net.thread.ThreadNetworkController.STATE_DISABLING;
+import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
 import static android.net.thread.ThreadNetworkController.THREAD_VERSION_1_3;
 import static android.net.thread.ThreadNetworkException.ERROR_ABORTED;
 import static android.net.thread.ThreadNetworkException.ERROR_BUSY;
@@ -35,6 +38,7 @@
 import static android.net.thread.ThreadNetworkException.ERROR_REJECTED_BY_PEER;
 import static android.net.thread.ThreadNetworkException.ERROR_RESOURCE_EXHAUSTED;
 import static android.net.thread.ThreadNetworkException.ERROR_RESPONSE_BAD_FORMAT;
+import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
 import static android.net.thread.ThreadNetworkException.ERROR_TIMEOUT;
 import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_CHANNEL;
 import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
@@ -48,7 +52,11 @@
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_REASSEMBLY_TIMEOUT;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_REJECTED;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_RESPONSE_TIMEOUT;
+import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_THREAD_DISABLED;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_UNSUPPORTED_CHANNEL;
+import static com.android.server.thread.openthread.IOtDaemon.OT_STATE_DISABLED;
+import static com.android.server.thread.openthread.IOtDaemon.OT_STATE_DISABLING;
+import static com.android.server.thread.openthread.IOtDaemon.OT_STATE_ENABLED;
 import static com.android.server.thread.openthread.IOtDaemon.TUN_IF_NAME;
 
 import android.Manifest.permission;
@@ -160,6 +168,7 @@
     private UpstreamNetworkCallback mUpstreamNetworkCallback;
     private TestNetworkSpecifier mUpstreamTestNetworkSpecifier;
     private final HashMap<Network, String> mNetworkToInterface;
+    private final ThreadPersistentSettings mPersistentSettings;
 
     private BorderRouterConfigurationParcel mBorderRouterConfig;
 
@@ -171,7 +180,8 @@
             Supplier<IOtDaemon> otDaemonSupplier,
             ConnectivityManager connectivityManager,
             TunInterfaceController tunIfController,
-            InfraInterfaceController infraIfController) {
+            InfraInterfaceController infraIfController,
+            ThreadPersistentSettings persistentSettings) {
         mContext = context;
         mHandler = handler;
         mNetworkProvider = networkProvider;
@@ -182,9 +192,11 @@
         mUpstreamNetworkRequest = newUpstreamNetworkRequest();
         mNetworkToInterface = new HashMap<Network, String>();
         mBorderRouterConfig = new BorderRouterConfigurationParcel();
+        mPersistentSettings = persistentSettings;
     }
 
-    public static ThreadNetworkControllerService newInstance(Context context) {
+    public static ThreadNetworkControllerService newInstance(
+            Context context, ThreadPersistentSettings persistentSettings) {
         HandlerThread handlerThread = new HandlerThread("ThreadHandlerThread");
         handlerThread.start();
         NetworkProvider networkProvider =
@@ -197,7 +209,8 @@
                 () -> IOtDaemon.Stub.asInterface(ServiceManagerWrapper.waitForService("ot_daemon")),
                 context.getSystemService(ConnectivityManager.class),
                 new TunInterfaceController(TUN_IF_NAME),
-                new InfraInterfaceController());
+                new InfraInterfaceController(),
+                persistentSettings);
     }
 
     private static Inet6Address bytesToInet6Address(byte[] addressBytes) {
@@ -273,7 +286,9 @@
         if (otDaemon == null) {
             throw new RemoteException("Internal error: failed to start OT daemon");
         }
-        otDaemon.initialize(mTunIfController.getTunFd());
+        otDaemon.initialize(
+                mTunIfController.getTunFd(),
+                mPersistentSettings.get(ThreadPersistentSettings.THREAD_ENABLED));
         otDaemon.registerStateCallback(mOtDaemonCallbackProxy, -1);
         otDaemon.asBinder().linkToDeath(() -> mHandler.post(this::onOtDaemonDied), 0);
         mOtDaemon = otDaemon;
@@ -308,6 +323,26 @@
                 });
     }
 
+    public void setEnabled(boolean isEnabled, @NonNull IOperationReceiver receiver) {
+        enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
+        // Permission is checked above; the actual state change is deferred to mHandler.
+        mHandler.post(() -> setEnabledInternal(isEnabled, new OperationReceiverWrapper(receiver)));
+    }
+
+    private void setEnabledInternal(
+            boolean isEnabled, @NonNull OperationReceiverWrapper receiver) {
+        // Persist the desired enabled state first, regardless of whether the otDaemon call below
+        // succeeds, so that the desired value can be recovered from the persistent settings
+        // after reboot.
+        mPersistentSettings.put(ThreadPersistentSettings.THREAD_ENABLED.key, isEnabled);
+        try {
+            getOtDaemon().setThreadEnabled(isEnabled, newOtStatusReceiver(receiver));
+        } catch (RemoteException e) {
+            Log.e(TAG, "otDaemon.setThreadEnabled failed", e);
+            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        }
+    }
+
     private void requestUpstreamNetwork() {
         if (mUpstreamNetworkCallback != null) {
             throw new AssertionError("The upstream network request is already there.");
@@ -658,6 +693,8 @@
                 return ERROR_REJECTED_BY_PEER;
             case OT_ERROR_UNSUPPORTED_CHANNEL:
                 return ERROR_UNSUPPORTED_CHANNEL;
+            case OT_ERROR_THREAD_DISABLED:
+                return ERROR_THREAD_DISABLED;
             default:
                 return ERROR_INTERNAL_ERROR;
         }
@@ -1001,6 +1038,15 @@
             }
         }
 
+        private void notifyThreadEnabledUpdated(IStateCallback callback, int enabledState) {
+            try {
+                callback.onThreadEnableStateChanged(enabledState);
+                Log.i(TAG, "onThreadEnableStateChanged " + enabledState);
+            } catch (RemoteException ignored) {
+                // do nothing if the client is dead
+            }
+        }
+
         public void unregisterStateCallback(IStateCallback callback) {
             checkOnHandlerThread();
             if (!mStateCallbacks.containsKey(callback)) {
@@ -1065,6 +1111,31 @@
         }
 
         @Override
+        public void onThreadEnabledChanged(int state) {
+            mHandler.post(() -> onThreadEnabledChangedInternal(state));
+        }
+
+        private void onThreadEnabledChangedInternal(int state) {
+            checkOnHandlerThread();
+            for (IStateCallback callback : mStateCallbacks.keySet()) {
+                notifyThreadEnabledUpdated(callback, otStateToAndroidState(state));
+            }
+        }
+
+        private static int otStateToAndroidState(int state) {
+            switch (state) {
+                case OT_STATE_ENABLED:
+                    return STATE_ENABLED;
+                case OT_STATE_DISABLED:
+                    return STATE_DISABLED;
+                case OT_STATE_DISABLING:
+                    return STATE_DISABLING;
+                default:
+                    throw new IllegalArgumentException("Unknown ot state " + state);
+            }
+        }
+
+        @Override
         public void onStateChanged(OtDaemonState newState, long listenerId) {
             mHandler.post(() -> onStateChangedInternal(newState, listenerId));
         }
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
index 23aeb93..a1484a4 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
@@ -310,7 +310,14 @@
         public void onActiveCountryCodeChanged(String countryCode) {
             Log.d(TAG, "Wifi country code is changed to " + countryCode);
             synchronized ("ThreadNetworkCountryCode.this") {
-                mWifiCountryCodeInfo = new CountryCodeInfo(countryCode, COUNTRY_CODE_SOURCE_WIFI);
+                if (isValidCountryCode(countryCode)) {
+                    mWifiCountryCodeInfo =
+                            new CountryCodeInfo(countryCode, COUNTRY_CODE_SOURCE_WIFI);
+                } else {
+                    Log.w(TAG, "WiFi country code " + countryCode + " is invalid");
+                    mWifiCountryCodeInfo = null;
+                }
+
                 updateCountryCode(false /* forceUpdate */);
             }
         }
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index 53f2d4f..5cf27f7 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -18,16 +18,21 @@
 
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 
+import static com.android.net.module.util.DeviceConfigUtils.TETHERING_MODULE_NAME;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.content.ApexEnvironment;
 import android.content.Context;
 import android.net.thread.IThreadNetworkController;
 import android.net.thread.IThreadNetworkManager;
 import android.os.Binder;
 import android.os.ParcelFileDescriptor;
+import android.util.AtomicFile;
 
 import com.android.server.SystemService;
 
+import java.io.File;
 import java.io.FileDescriptor;
 import java.io.PrintWriter;
 import java.util.Collections;
@@ -40,11 +45,18 @@
     private final Context mContext;
     @Nullable private ThreadNetworkCountryCode mCountryCode;
     @Nullable private ThreadNetworkControllerService mControllerService;
+    private final ThreadPersistentSettings mPersistentSettings;
     @Nullable private ThreadNetworkShellCommand mShellCommand;
 
     /** Creates a new {@link ThreadNetworkService} object. */
     public ThreadNetworkService(Context context) {
         mContext = context;
+        mPersistentSettings =
+                new ThreadPersistentSettings(
+                        new AtomicFile(
+                                new File(
+                                        getOrCreateThreadnetworkDir(),
+                                        ThreadPersistentSettings.FILE_NAME)));
     }
 
     /**
@@ -54,7 +66,9 @@
      */
     public void onBootPhase(int phase) {
         if (phase == SystemService.PHASE_SYSTEM_SERVICES_READY) {
-            mControllerService = ThreadNetworkControllerService.newInstance(mContext);
+            mPersistentSettings.initialize();
+            mControllerService =
+                    ThreadNetworkControllerService.newInstance(mContext, mPersistentSettings);
             mControllerService.initialize();
         } else if (phase == SystemService.PHASE_BOOT_COMPLETED) {
             // Country code initialization is delayed to the BOOT_COMPLETED phase because it will
@@ -109,4 +123,19 @@
 
         pw.println();
     }
+
+    /** Gets or creates the "thread" dir in the tethering apex's device protected storage. */
+    private static File getOrCreateThreadnetworkDir() {
+        final File threadnetworkDir;
+        final File apexDataDir =
+                ApexEnvironment.getApexEnvironment(TETHERING_MODULE_NAME)
+                        .getDeviceProtectedDataDir();
+        threadnetworkDir = new File(apexDataDir, "thread");
+
+        if (threadnetworkDir.exists() || threadnetworkDir.mkdirs()) {
+            return threadnetworkDir;
+        }
+        throw new IllegalStateException(
+                "Cannot write into thread network data directory: " + threadnetworkDir);
+    }
 }
diff --git a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
index 7a6c9aa..7a129dc 100644
--- a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
+++ b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
@@ -16,11 +16,19 @@
 
 package android.net.thread.cts;
 
+import static android.Manifest.permission.ACCESS_NETWORK_STATE;
+import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_CHILD;
+import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_LEADER;
+import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_ROUTER;
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_STOPPED;
+import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
+import static android.net.thread.ThreadNetworkController.STATE_DISABLING;
+import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
 import static android.net.thread.ThreadNetworkController.THREAD_VERSION_1_3;
 import static android.net.thread.ThreadNetworkException.ERROR_ABORTED;
 import static android.net.thread.ThreadNetworkException.ERROR_FAILED_PRECONDITION;
 import static android.net.thread.ThreadNetworkException.ERROR_REJECTED_BY_PEER;
+import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
 
 import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
 
@@ -29,12 +37,12 @@
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assume.assumeNotNull;
 
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 
-import android.Manifest.permission;
 import android.content.Context;
 import android.net.ConnectivityManager;
 import android.net.Network;
@@ -54,12 +62,11 @@
 import androidx.test.core.app.ApplicationProvider;
 import androidx.test.filters.LargeTest;
 
+import com.android.net.module.util.ArrayTrackRecord;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 import com.android.testutils.DevSdkIgnoreRunner;
 
-import com.google.common.util.concurrent.SettableFuture;
-
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
@@ -68,9 +75,11 @@
 
 import java.time.Duration;
 import java.time.Instant;
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -81,8 +90,11 @@
 @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) // Thread is available on only U+
 public class ThreadNetworkControllerTest {
     private static final int JOIN_TIMEOUT_MILLIS = 30 * 1000;
+    private static final int LEAVE_TIMEOUT_MILLIS = 2_000;
+    private static final int MIGRATION_TIMEOUT_MILLIS = 40 * 1_000;
     private static final int NETWORK_CALLBACK_TIMEOUT_MILLIS = 10 * 1000;
-    private static final int CALLBACK_TIMEOUT_MILLIS = 1000;
+    private static final int CALLBACK_TIMEOUT_MILLIS = 1_000;
+    private static final int ENABLED_TIMEOUT_MILLIS = 2_000;
     private static final String PERMISSION_THREAD_NETWORK_PRIVILEGED =
             "android.permission.THREAD_NETWORK_PRIVILEGED";
 
@@ -95,7 +107,7 @@
     private Set<String> mGrantedPermissions;
 
     @Before
-    public void setUp() {
+    public void setUp() throws Exception {
         mExecutor = Executors.newSingleThreadExecutor();
         mManager = mContext.getSystemService(ThreadNetworkManager.class);
         mGrantedPermissions = new HashSet<String>();
@@ -103,13 +115,17 @@
         // TODO: we will also need it in tearDown(), it's better to have a Rule to skip
         // tests if a feature is not available.
         assumeNotNull(mManager);
+
+        for (ThreadNetworkController controller : getAllControllers()) {
+            setEnabledAndWait(controller, true);
+        }
     }
 
     @After
     public void tearDown() throws Exception {
         if (mManager != null) {
-            leaveAndWait();
             dropAllPermissions();
+            leaveAndWait();
         }
     }
 
@@ -118,12 +134,10 @@
     }
 
     private void leaveAndWait() throws Exception {
-        grantPermissions(PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Void> future = SettableFuture.create();
-            controller.leave(mExecutor, future::set);
-            future.get();
+            CompletableFuture<Void> future = new CompletableFuture<>();
+            leave(controller, future::complete);
+            future.get(LEAVE_TIMEOUT_MILLIS, MILLISECONDS);
         }
     }
 
@@ -142,8 +156,8 @@
 
     private static ActiveOperationalDataset newRandomizedDataset(
             String networkName, ThreadNetworkController controller) throws Exception {
-        SettableFuture<ActiveOperationalDataset> future = SettableFuture.create();
-        controller.createRandomizedDataset(networkName, directExecutor(), future::set);
+        CompletableFuture<ActiveOperationalDataset> future = new CompletableFuture<>();
+        controller.createRandomizedDataset(networkName, directExecutor(), future::complete);
         return future.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
     }
 
@@ -152,33 +166,140 @@
     }
 
     private static int getDeviceRole(ThreadNetworkController controller) throws Exception {
-        SettableFuture<Integer> future = SettableFuture.create();
-        StateCallback callback = future::set;
+        CompletableFuture<Integer> future = new CompletableFuture<>();
+        StateCallback callback = future::complete;
         controller.registerStateCallback(directExecutor(), callback);
         int role = future.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
         controller.unregisterStateCallback(callback);
         return role;
     }
 
+    private static int waitForAttachedState(ThreadNetworkController controller) throws Exception {
+        List<Integer> attachedRoles = new ArrayList<>();
+        attachedRoles.add(DEVICE_ROLE_CHILD);
+        attachedRoles.add(DEVICE_ROLE_ROUTER);
+        attachedRoles.add(DEVICE_ROLE_LEADER);
+        return waitForStateAnyOf(controller, attachedRoles);
+    }
+
     private static int waitForStateAnyOf(
             ThreadNetworkController controller, List<Integer> deviceRoles) throws Exception {
-        SettableFuture<Integer> future = SettableFuture.create();
+        CompletableFuture<Integer> future = new CompletableFuture<>();
         StateCallback callback =
                 newRole -> {
                     if (deviceRoles.contains(newRole)) {
-                        future.set(newRole);
+                        future.complete(newRole);
                     }
                 };
         controller.registerStateCallback(directExecutor(), callback);
-        int role = future.get();
+        int role = future.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
         controller.unregisterStateCallback(callback);
         return role;
     }
 
+    private static void waitForEnabledState(ThreadNetworkController controller, int state)
+            throws Exception {
+        CompletableFuture<Integer> future = new CompletableFuture<>();
+        StateCallback callback =
+                new ThreadNetworkController.StateCallback() {
+                    @Override
+                    public void onDeviceRoleChanged(int r) {}
+
+                    @Override
+                    public void onThreadEnableStateChanged(int enabled) {
+                        if (enabled == state) {
+                            future.complete(enabled);
+                        }
+                    }
+                };
+        runAsShell(
+                ACCESS_NETWORK_STATE,
+                () -> controller.registerStateCallback(directExecutor(), callback));
+        future.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+        runAsShell(ACCESS_NETWORK_STATE, () -> controller.unregisterStateCallback(callback));
+    }
+
+    private void leave(
+            ThreadNetworkController controller,
+            OutcomeReceiver<Void, ThreadNetworkException> receiver) {
+        runAsShell(
+                PERMISSION_THREAD_NETWORK_PRIVILEGED, () -> controller.leave(mExecutor, receiver));
+    }
+
+    private void scheduleMigration(
+            ThreadNetworkController controller,
+            PendingOperationalDataset pendingDataset,
+            OutcomeReceiver<Void, ThreadNetworkException> receiver) {
+        runAsShell(
+                PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                () -> controller.scheduleMigration(pendingDataset, mExecutor, receiver));
+    }
+
+    private class EnabledStateListener {
+        private ArrayTrackRecord<Integer> mEnabledStates = new ArrayTrackRecord<>();
+        private final ArrayTrackRecord<Integer>.ReadHead mReadHead = mEnabledStates.newReadHead();
+        ThreadNetworkController mController;
+        StateCallback mCallback =
+                new ThreadNetworkController.StateCallback() {
+                    @Override
+                    public void onDeviceRoleChanged(int r) {}
+
+                    @Override
+                    public void onThreadEnableStateChanged(int enabled) {
+                        mEnabledStates.add(enabled);
+                    }
+                };
+
+        EnabledStateListener(ThreadNetworkController controller) {
+            this.mController = controller;
+            runAsShell(
+                    ACCESS_NETWORK_STATE,
+                    () -> controller.registerStateCallback(mExecutor, mCallback));
+        }
+
+        public void expectThreadEnabledState(int enabled) {
+            assertNotNull(mReadHead.poll(ENABLED_TIMEOUT_MILLIS, e -> (e == enabled)));
+        }
+
+        public void unregisterStateCallback() {
+            runAsShell(ACCESS_NETWORK_STATE, () -> mController.unregisterStateCallback(mCallback));
+        }
+    }
+
+    private int booleanToEnabledState(boolean enabled) {
+        return enabled ? STATE_ENABLED : STATE_DISABLED;
+    }
+
+    private void setEnabledAndWait(ThreadNetworkController controller, boolean enabled)
+            throws Exception {
+        CompletableFuture<Void> setFuture = new CompletableFuture<>();
+        runAsShell(
+                PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                () -> controller.setEnabled(enabled, mExecutor, newOutcomeReceiver(setFuture)));
+        setFuture.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+        waitForEnabledState(controller, booleanToEnabledState(enabled));
+    }
+
+    private CompletableFuture joinRandomizedDataset(ThreadNetworkController controller)
+            throws Exception {
+        ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
+        CompletableFuture<Void> joinFuture = new CompletableFuture<>();
+        runAsShell(
+                PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                () -> controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture)));
+        return joinFuture;
+    }
+
+    private void joinRandomizedDatasetAndWait(ThreadNetworkController controller) throws Exception {
+        CompletableFuture<Void> joinFuture = joinRandomizedDataset(controller);
+        joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
+        runAsShell(ACCESS_NETWORK_STATE, () -> assertThat(isAttached(controller)).isTrue());
+    }
+
     private static ActiveOperationalDataset getActiveOperationalDataset(
             ThreadNetworkController controller) throws Exception {
-        SettableFuture<ActiveOperationalDataset> future = SettableFuture.create();
-        OperationalDatasetCallback callback = future::set;
+        CompletableFuture<ActiveOperationalDataset> future = new CompletableFuture<>();
+        OperationalDatasetCallback callback = future::complete;
         controller.registerOperationalDatasetCallback(directExecutor(), callback);
         ActiveOperationalDataset dataset = future.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
         controller.unregisterOperationalDatasetCallback(callback);
@@ -187,27 +308,27 @@
 
     private static PendingOperationalDataset getPendingOperationalDataset(
             ThreadNetworkController controller) throws Exception {
-        SettableFuture<ActiveOperationalDataset> activeFuture = SettableFuture.create();
-        SettableFuture<PendingOperationalDataset> pendingFuture = SettableFuture.create();
+        CompletableFuture<ActiveOperationalDataset> activeFuture = new CompletableFuture<>();
+        CompletableFuture<PendingOperationalDataset> pendingFuture = new CompletableFuture<>();
         controller.registerOperationalDatasetCallback(
                 directExecutor(), newDatasetCallback(activeFuture, pendingFuture));
-        return pendingFuture.get();
+        return pendingFuture.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
     }
 
     private static OperationalDatasetCallback newDatasetCallback(
-            SettableFuture<ActiveOperationalDataset> activeFuture,
-            SettableFuture<PendingOperationalDataset> pendingFuture) {
+            CompletableFuture<ActiveOperationalDataset> activeFuture,
+            CompletableFuture<PendingOperationalDataset> pendingFuture) {
         return new OperationalDatasetCallback() {
             @Override
             public void onActiveOperationalDatasetChanged(
                     ActiveOperationalDataset activeOpDataset) {
-                activeFuture.set(activeOpDataset);
+                activeFuture.complete(activeOpDataset);
             }
 
             @Override
             public void onPendingOperationalDatasetChanged(
                     PendingOperationalDataset pendingOpDataset) {
-                pendingFuture.set(pendingOpDataset);
+                pendingFuture.complete(pendingOpDataset);
             }
         };
     }
@@ -221,16 +342,17 @@
 
     @Test
     public void registerStateCallback_permissionsGranted_returnsCurrentStates() throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE);
+        grantPermissions(ACCESS_NETWORK_STATE);
 
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = deviceRole::set;
+            CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+            StateCallback callback = deviceRole::complete;
 
             try {
                 controller.registerStateCallback(mExecutor, callback);
 
-                assertThat(deviceRole.get()).isEqualTo(DEVICE_ROLE_STOPPED);
+                assertThat(deviceRole.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS))
+                        .isEqualTo(DEVICE_ROLE_STOPPED);
             } finally {
                 controller.unregisterStateCallback(callback);
             }
@@ -238,6 +360,36 @@
     }
 
     @Test
+    public void registerStateCallback_returnsUpdatedEnabledStates() throws Exception {
+        for (ThreadNetworkController controller : getAllControllers()) {
+            CompletableFuture<Void> setFuture1 = new CompletableFuture<>();
+            CompletableFuture<Void> setFuture2 = new CompletableFuture<>();
+            EnabledStateListener listener = new EnabledStateListener(controller);
+
+            runAsShell(
+                    PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                    () -> {
+                        controller.setEnabled(false, mExecutor, newOutcomeReceiver(setFuture1));
+                    });
+            setFuture1.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+
+            runAsShell(
+                    PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                    () -> {
+                        controller.setEnabled(true, mExecutor, newOutcomeReceiver(setFuture2));
+                    });
+            setFuture2.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+
+            listener.expectThreadEnabledState(STATE_ENABLED);
+            listener.expectThreadEnabledState(STATE_DISABLING);
+            listener.expectThreadEnabledState(STATE_DISABLED);
+            listener.expectThreadEnabledState(STATE_ENABLED);
+
+            listener.unregisterStateCallback();
+        }
+    }
+
+    @Test
     public void registerStateCallback_noPermissions_throwsSecurityException() throws Exception {
         dropAllPermissions();
 
@@ -251,11 +403,11 @@
     @Test
     public void registerStateCallback_alreadyRegistered_throwsIllegalArgumentException()
             throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE);
+        grantPermissions(ACCESS_NETWORK_STATE);
 
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = role -> deviceRole.set(role);
+            CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+            StateCallback callback = role -> deviceRole.complete(role);
             controller.registerStateCallback(mExecutor, callback);
 
             assertThrows(
@@ -267,9 +419,9 @@
     @Test
     public void unregisterStateCallback_noPermissions_throwsSecurityException() throws Exception {
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = role -> deviceRole.set(role);
-            grantPermissions(permission.ACCESS_NETWORK_STATE);
+            CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+            StateCallback callback = role -> deviceRole.complete(role);
+            grantPermissions(ACCESS_NETWORK_STATE);
             controller.registerStateCallback(mExecutor, callback);
 
             try {
@@ -278,7 +430,7 @@
                         SecurityException.class,
                         () -> controller.unregisterStateCallback(callback));
             } finally {
-                grantPermissions(permission.ACCESS_NETWORK_STATE);
+                grantPermissions(ACCESS_NETWORK_STATE);
                 controller.unregisterStateCallback(callback);
             }
         }
@@ -286,10 +438,10 @@
 
     @Test
     public void unregisterStateCallback_callbackRegistered_success() throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE);
+        grantPermissions(ACCESS_NETWORK_STATE);
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = role -> deviceRole.set(role);
+            CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+            StateCallback callback = role -> deviceRole.complete(role);
             controller.registerStateCallback(mExecutor, callback);
 
             controller.unregisterStateCallback(callback);
@@ -300,8 +452,8 @@
     public void unregisterStateCallback_callbackNotRegistered_throwsIllegalArgumentException()
             throws Exception {
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = role -> deviceRole.set(role);
+            CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+            StateCallback callback = role -> deviceRole.complete(role);
 
             assertThrows(
                     IllegalArgumentException.class,
@@ -312,10 +464,10 @@
     @Test
     public void unregisterStateCallback_alreadyUnregistered_throwsIllegalArgumentException()
             throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE);
+        grantPermissions(ACCESS_NETWORK_STATE);
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<Integer> deviceRole = SettableFuture.create();
-            StateCallback callback = deviceRole::set;
+            CompletableFuture<Integer> deviceRole = new CompletableFuture<>();
+            StateCallback callback = deviceRole::complete;
             controller.registerStateCallback(mExecutor, callback);
             controller.unregisterStateCallback(callback);
 
@@ -328,18 +480,18 @@
     @Test
     public void registerOperationalDatasetCallback_permissionsGranted_returnsCurrentStates()
             throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
+        grantPermissions(ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<ActiveOperationalDataset> activeFuture = SettableFuture.create();
-            SettableFuture<PendingOperationalDataset> pendingFuture = SettableFuture.create();
+            CompletableFuture<ActiveOperationalDataset> activeFuture = new CompletableFuture<>();
+            CompletableFuture<PendingOperationalDataset> pendingFuture = new CompletableFuture<>();
             var callback = newDatasetCallback(activeFuture, pendingFuture);
 
             try {
                 controller.registerOperationalDatasetCallback(mExecutor, callback);
 
-                assertThat(activeFuture.get()).isNull();
-                assertThat(pendingFuture.get()).isNull();
+                assertThat(activeFuture.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isNull();
+                assertThat(pendingFuture.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isNull();
             } finally {
                 controller.unregisterOperationalDatasetCallback(callback);
             }
@@ -352,8 +504,8 @@
         dropAllPermissions();
 
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<ActiveOperationalDataset> activeFuture = SettableFuture.create();
-            SettableFuture<PendingOperationalDataset> pendingFuture = SettableFuture.create();
+            CompletableFuture<ActiveOperationalDataset> activeFuture = new CompletableFuture<>();
+            CompletableFuture<PendingOperationalDataset> pendingFuture = new CompletableFuture<>();
             var callback = newDatasetCallback(activeFuture, pendingFuture);
 
             assertThrows(
@@ -364,10 +516,10 @@
 
     @Test
     public void unregisterOperationalDatasetCallback_callbackRegistered_success() throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
+        grantPermissions(ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<ActiveOperationalDataset> activeFuture = SettableFuture.create();
-            SettableFuture<PendingOperationalDataset> pendingFuture = SettableFuture.create();
+            CompletableFuture<ActiveOperationalDataset> activeFuture = new CompletableFuture<>();
+            CompletableFuture<PendingOperationalDataset> pendingFuture = new CompletableFuture<>();
             var callback = newDatasetCallback(activeFuture, pendingFuture);
             controller.registerOperationalDatasetCallback(mExecutor, callback);
 
@@ -381,10 +533,10 @@
         dropAllPermissions();
 
         for (ThreadNetworkController controller : getAllControllers()) {
-            SettableFuture<ActiveOperationalDataset> activeFuture = SettableFuture.create();
-            SettableFuture<PendingOperationalDataset> pendingFuture = SettableFuture.create();
+            CompletableFuture<ActiveOperationalDataset> activeFuture = new CompletableFuture<>();
+            CompletableFuture<PendingOperationalDataset> pendingFuture = new CompletableFuture<>();
             var callback = newDatasetCallback(activeFuture, pendingFuture);
-            grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
+            grantPermissions(ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
             controller.registerOperationalDatasetCallback(mExecutor, callback);
 
             try {
@@ -393,24 +545,23 @@
                         SecurityException.class,
                         () -> controller.unregisterOperationalDatasetCallback(callback));
             } finally {
-                grantPermissions(
-                        permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
+                grantPermissions(ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
                 controller.unregisterOperationalDatasetCallback(callback);
             }
         }
     }
 
     private static <V> OutcomeReceiver<V, ThreadNetworkException> newOutcomeReceiver(
-            SettableFuture<V> future) {
+            CompletableFuture<V> future) {
         return new OutcomeReceiver<V, ThreadNetworkException>() {
             @Override
             public void onResult(V result) {
-                future.set(result);
+                future.complete(result);
             }
 
             @Override
             public void onError(ThreadNetworkException e) {
-                future.setException(e);
+                future.completeExceptionally(e);
             }
         };
     }
@@ -421,12 +572,12 @@
 
         for (ThreadNetworkController controller : getAllControllers()) {
             ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
-            SettableFuture<Void> joinFuture = SettableFuture.create();
+            CompletableFuture<Void> joinFuture = new CompletableFuture<>();
 
             controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
+            joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
 
-            grantPermissions(permission.ACCESS_NETWORK_STATE);
+            grantPermissions(ACCESS_NETWORK_STATE);
             assertThat(isAttached(controller)).isTrue();
             assertThat(getActiveOperationalDataset(controller)).isEqualTo(activeDataset);
         }
@@ -446,6 +597,20 @@
     }
 
     @Test
+    public void join_threadDisabled_failsWithErrorThreadDisabled() throws Exception {
+        for (ThreadNetworkController controller : getAllControllers()) {
+            setEnabledAndWait(controller, false);
+
+            CompletableFuture<Void> joinFuture = joinRandomizedDataset(controller);
+
+            ThreadNetworkException thrown =
+                    (ThreadNetworkException)
+                            assertThrows(ExecutionException.class, joinFuture::get).getCause();
+            assertThat(thrown.getErrorCode()).isEqualTo(ERROR_THREAD_DISABLED);
+        }
+    }
+
+    @Test
     public void join_concurrentRequests_firstOneIsAborted() throws Exception {
         grantPermissions(PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
@@ -461,8 +626,8 @@
                     new ActiveOperationalDataset.Builder(activeDataset1)
                             .setNetworkKey(KEY_2)
                             .build();
-            SettableFuture<Void> joinFuture1 = SettableFuture.create();
-            SettableFuture<Void> joinFuture2 = SettableFuture.create();
+            CompletableFuture<Void> joinFuture1 = new CompletableFuture<>();
+            CompletableFuture<Void> joinFuture2 = new CompletableFuture<>();
 
             controller.join(activeDataset1, mExecutor, newOutcomeReceiver(joinFuture1));
             controller.join(activeDataset2, mExecutor, newOutcomeReceiver(joinFuture2));
@@ -471,8 +636,8 @@
                     (ThreadNetworkException)
                             assertThrows(ExecutionException.class, joinFuture1::get).getCause();
             assertThat(thrown.getErrorCode()).isEqualTo(ERROR_ABORTED);
-            joinFuture2.get();
-            grantPermissions(permission.ACCESS_NETWORK_STATE);
+            joinFuture2.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
+            grantPermissions(ACCESS_NETWORK_STATE);
             assertThat(isAttached(controller)).isTrue();
             assertThat(getActiveOperationalDataset(controller)).isEqualTo(activeDataset2);
         }
@@ -480,19 +645,14 @@
 
     @Test
     public void leave_withPrivilegedPermission_success() throws Exception {
-        grantPermissions(PERMISSION_THREAD_NETWORK_PRIVILEGED);
-
         for (ThreadNetworkController controller : getAllControllers()) {
-            ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-            SettableFuture<Void> leaveFuture = SettableFuture.create();
-            controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
+            joinRandomizedDatasetAndWait(controller);
 
-            controller.leave(mExecutor, newOutcomeReceiver(leaveFuture));
-            leaveFuture.get();
+            CompletableFuture<Void> leaveFuture = new CompletableFuture<>();
+            leave(controller, newOutcomeReceiver(leaveFuture));
+            leaveFuture.get(LEAVE_TIMEOUT_MILLIS, MILLISECONDS);
 
-            grantPermissions(permission.ACCESS_NETWORK_STATE);
+            grantPermissions(ACCESS_NETWORK_STATE);
             assertThat(getDeviceRole(controller)).isEqualTo(DEVICE_ROLE_STOPPED);
         }
     }
@@ -507,30 +667,46 @@
     }
 
     @Test
+    public void leave_threadDisabled_success() throws Exception {
+        for (ThreadNetworkController controller : getAllControllers()) {
+            joinRandomizedDatasetAndWait(controller);
+
+            CompletableFuture<Void> leaveFuture = new CompletableFuture<>();
+            setEnabledAndWait(controller, false);
+            leave(controller, newOutcomeReceiver(leaveFuture));
+
+            leaveFuture.get(LEAVE_TIMEOUT_MILLIS, MILLISECONDS);
+            runAsShell(
+                    ACCESS_NETWORK_STATE,
+                    () -> assertThat(getDeviceRole(controller)).isEqualTo(DEVICE_ROLE_STOPPED));
+        }
+    }
+
+    @Test
     public void leave_concurrentRequests_bothSuccess() throws Exception {
         grantPermissions(PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
         for (ThreadNetworkController controller : getAllControllers()) {
             ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-            SettableFuture<Void> leaveFuture1 = SettableFuture.create();
-            SettableFuture<Void> leaveFuture2 = SettableFuture.create();
+            CompletableFuture<Void> joinFuture = new CompletableFuture<>();
+            CompletableFuture<Void> leaveFuture1 = new CompletableFuture<>();
+            CompletableFuture<Void> leaveFuture2 = new CompletableFuture<>();
             controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
+            joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
 
             controller.leave(mExecutor, newOutcomeReceiver(leaveFuture1));
             controller.leave(mExecutor, newOutcomeReceiver(leaveFuture2));
 
-            leaveFuture1.get();
-            leaveFuture2.get();
-            grantPermissions(permission.ACCESS_NETWORK_STATE);
+            leaveFuture1.get(LEAVE_TIMEOUT_MILLIS, MILLISECONDS);
+            leaveFuture2.get(LEAVE_TIMEOUT_MILLIS, MILLISECONDS);
+            grantPermissions(ACCESS_NETWORK_STATE);
             assertThat(getDeviceRole(controller)).isEqualTo(DEVICE_ROLE_STOPPED);
         }
     }
 
     @Test
     public void scheduleMigration_withPrivilegedPermission_newDatasetApplied() throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
+        grantPermissions(ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
         for (ThreadNetworkController controller : getAllControllers()) {
             ActiveOperationalDataset activeDataset1 =
@@ -549,24 +725,24 @@
                             activeDataset2,
                             OperationalDatasetTimestamp.fromInstant(Instant.now()),
                             Duration.ofSeconds(30));
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-            SettableFuture<Void> migrateFuture = SettableFuture.create();
+            CompletableFuture<Void> joinFuture = new CompletableFuture<>();
+            CompletableFuture<Void> migrateFuture = new CompletableFuture<>();
             controller.join(activeDataset1, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
+            joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
 
             controller.scheduleMigration(
                     pendingDataset2, mExecutor, newOutcomeReceiver(migrateFuture));
-            migrateFuture.get();
+            migrateFuture.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
 
-            SettableFuture<Boolean> dataset2IsApplied = SettableFuture.create();
-            SettableFuture<Boolean> pendingDatasetIsRemoved = SettableFuture.create();
+            CompletableFuture<Boolean> dataset2IsApplied = new CompletableFuture<>();
+            CompletableFuture<Boolean> pendingDatasetIsRemoved = new CompletableFuture<>();
             OperationalDatasetCallback datasetCallback =
                     new OperationalDatasetCallback() {
                         @Override
                         public void onActiveOperationalDatasetChanged(
                                 ActiveOperationalDataset activeDataset) {
                             if (activeDataset.equals(activeDataset2)) {
-                                dataset2IsApplied.set(true);
+                                dataset2IsApplied.complete(true);
                             }
                         }
 
@@ -574,20 +750,20 @@
                         public void onPendingOperationalDatasetChanged(
                                 PendingOperationalDataset pendingDataset) {
                             if (pendingDataset == null) {
-                                pendingDatasetIsRemoved.set(true);
+                                pendingDatasetIsRemoved.complete(true);
                             }
                         }
                     };
             controller.registerOperationalDatasetCallback(directExecutor(), datasetCallback);
-            assertThat(dataset2IsApplied.get()).isTrue();
-            assertThat(pendingDatasetIsRemoved.get()).isTrue();
+            assertThat(dataset2IsApplied.get(MIGRATION_TIMEOUT_MILLIS, MILLISECONDS)).isTrue();
+            assertThat(pendingDatasetIsRemoved.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isTrue();
             controller.unregisterOperationalDatasetCallback(datasetCallback);
         }
     }
 
     @Test
     public void scheduleMigration_whenNotAttached_failWithPreconditionError() throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
+        grantPermissions(ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
         for (ThreadNetworkController controller : getAllControllers()) {
             PendingOperationalDataset pendingDataset =
@@ -595,7 +771,7 @@
                             newRandomizedDataset("TestNet", controller),
                             OperationalDatasetTimestamp.fromInstant(Instant.now()),
                             Duration.ofSeconds(30));
-            SettableFuture<Void> migrateFuture = SettableFuture.create();
+            CompletableFuture<Void> migrateFuture = new CompletableFuture<>();
 
             controller.scheduleMigration(
                     pendingDataset, mExecutor, newOutcomeReceiver(migrateFuture));
@@ -610,7 +786,7 @@
     @Test
     public void scheduleMigration_secondRequestHasSmallerTimestamp_rejectedByLeader()
             throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
+        grantPermissions(ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
         for (ThreadNetworkController controller : getAllControllers()) {
             final ActiveOperationalDataset activeDataset =
@@ -638,15 +814,15 @@
                             activeDataset2,
                             new OperationalDatasetTimestamp(20, 0, false),
                             Duration.ofSeconds(30));
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-            SettableFuture<Void> migrateFuture1 = SettableFuture.create();
-            SettableFuture<Void> migrateFuture2 = SettableFuture.create();
+            CompletableFuture<Void> joinFuture = new CompletableFuture<>();
+            CompletableFuture<Void> migrateFuture1 = new CompletableFuture<>();
+            CompletableFuture<Void> migrateFuture2 = new CompletableFuture<>();
             controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
+            joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
 
             controller.scheduleMigration(
                     pendingDataset1, mExecutor, newOutcomeReceiver(migrateFuture1));
-            migrateFuture1.get();
+            migrateFuture1.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
             controller.scheduleMigration(
                     pendingDataset2, mExecutor, newOutcomeReceiver(migrateFuture2));
 
@@ -660,7 +836,7 @@
     @Test
     public void scheduleMigration_secondRequestHasLargerTimestamp_newDatasetApplied()
             throws Exception {
-        grantPermissions(permission.ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
+        grantPermissions(ACCESS_NETWORK_STATE, PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
         for (ThreadNetworkController controller : getAllControllers()) {
             final ActiveOperationalDataset activeDataset =
@@ -688,28 +864,28 @@
                             activeDataset2,
                             new OperationalDatasetTimestamp(200, 0, false),
                             Duration.ofSeconds(30));
-            SettableFuture<Void> joinFuture = SettableFuture.create();
-            SettableFuture<Void> migrateFuture1 = SettableFuture.create();
-            SettableFuture<Void> migrateFuture2 = SettableFuture.create();
+            CompletableFuture<Void> joinFuture = new CompletableFuture<>();
+            CompletableFuture<Void> migrateFuture1 = new CompletableFuture<>();
+            CompletableFuture<Void> migrateFuture2 = new CompletableFuture<>();
             controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture));
-            joinFuture.get();
+            joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
 
             controller.scheduleMigration(
                     pendingDataset1, mExecutor, newOutcomeReceiver(migrateFuture1));
-            migrateFuture1.get();
+            migrateFuture1.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
             controller.scheduleMigration(
                     pendingDataset2, mExecutor, newOutcomeReceiver(migrateFuture2));
-            migrateFuture2.get();
+            migrateFuture2.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS);
 
-            SettableFuture<Boolean> dataset2IsApplied = SettableFuture.create();
-            SettableFuture<Boolean> pendingDatasetIsRemoved = SettableFuture.create();
+            CompletableFuture<Boolean> dataset2IsApplied = new CompletableFuture<>();
+            CompletableFuture<Boolean> pendingDatasetIsRemoved = new CompletableFuture<>();
             OperationalDatasetCallback datasetCallback =
                     new OperationalDatasetCallback() {
                         @Override
                         public void onActiveOperationalDatasetChanged(
                                 ActiveOperationalDataset activeDataset) {
                             if (activeDataset.equals(activeDataset2)) {
-                                dataset2IsApplied.set(true);
+                                dataset2IsApplied.complete(true);
                             }
                         }
 
@@ -717,18 +893,41 @@
                         public void onPendingOperationalDatasetChanged(
                                 PendingOperationalDataset pendingDataset) {
                             if (pendingDataset == null) {
-                                pendingDatasetIsRemoved.set(true);
+                                pendingDatasetIsRemoved.complete(true);
                             }
                         }
                     };
             controller.registerOperationalDatasetCallback(directExecutor(), datasetCallback);
-            assertThat(dataset2IsApplied.get()).isTrue();
-            assertThat(pendingDatasetIsRemoved.get()).isTrue();
+            assertThat(dataset2IsApplied.get(MIGRATION_TIMEOUT_MILLIS, MILLISECONDS)).isTrue();
+            assertThat(pendingDatasetIsRemoved.get(CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isTrue();
             controller.unregisterOperationalDatasetCallback(datasetCallback);
         }
     }
 
     @Test
+    public void scheduleMigration_threadDisabled_failsWithErrorThreadDisabled() throws Exception {
+        for (ThreadNetworkController controller : getAllControllers()) {
+            ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
+            PendingOperationalDataset pendingDataset =
+                    new PendingOperationalDataset(
+                            activeDataset,
+                            OperationalDatasetTimestamp.fromInstant(Instant.now()),
+                            Duration.ofSeconds(30));
+            joinRandomizedDatasetAndWait(controller);
+            CompletableFuture<Void> migrationFuture = new CompletableFuture<>();
+
+            setEnabledAndWait(controller, false);
+
+            scheduleMigration(controller, pendingDataset, newOutcomeReceiver(migrationFuture));
+
+            ThreadNetworkException thrown =
+                    (ThreadNetworkException)
+                            assertThrows(ExecutionException.class, migrationFuture::get).getCause();
+            assertThat(thrown.getErrorCode()).isEqualTo(ERROR_THREAD_DISABLED);
+        }
+    }
+
+    @Test
     public void createRandomizedDataset_wrongNetworkNameLength_throwsIllegalArgumentException() {
         for (ThreadNetworkController controller : getAllControllers()) {
             assertThrows(
@@ -760,11 +959,105 @@
     }
 
     @Test
+    public void setEnabled_permissionsGranted_succeeds() throws Exception {
+        for (ThreadNetworkController controller : getAllControllers()) {
+            CompletableFuture<Void> setFuture1 = new CompletableFuture<>();
+            CompletableFuture<Void> setFuture2 = new CompletableFuture<>();
+
+            runAsShell(
+                    PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                    () -> controller.setEnabled(false, mExecutor, newOutcomeReceiver(setFuture1)));
+            setFuture1.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+            waitForEnabledState(controller, booleanToEnabledState(false));
+
+            runAsShell(
+                    PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                    () -> controller.setEnabled(true, mExecutor, newOutcomeReceiver(setFuture2)));
+            setFuture2.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+            waitForEnabledState(controller, booleanToEnabledState(true));
+        }
+    }
+
+    @Test
+    public void setEnabled_noPermissions_throwsSecurityException() throws Exception {
+        for (ThreadNetworkController controller : getAllControllers()) {
+            CompletableFuture<Void> setFuture = new CompletableFuture<>();
+            assertThrows(
+                    SecurityException.class,
+                    () -> controller.setEnabled(false, mExecutor, newOutcomeReceiver(setFuture)));
+        }
+    }
+
+    @Test
+    public void setEnabled_disable_leavesThreadNetwork() throws Exception {
+        for (ThreadNetworkController controller : getAllControllers()) {
+            joinRandomizedDatasetAndWait(controller);
+
+            setEnabledAndWait(controller, false);
+
+            runAsShell(
+                    ACCESS_NETWORK_STATE,
+                    () -> assertThat(getDeviceRole(controller)).isEqualTo(DEVICE_ROLE_STOPPED));
+        }
+    }
+
+    @Test
+    public void setEnabled_toggleAfterJoin_joinsThreadNetworkAgain() throws Exception {
+        for (ThreadNetworkController controller : getAllControllers()) {
+            joinRandomizedDatasetAndWait(controller);
+
+            setEnabledAndWait(controller, false);
+
+            runAsShell(
+                    ACCESS_NETWORK_STATE,
+                    () -> assertThat(getDeviceRole(controller)).isEqualTo(DEVICE_ROLE_STOPPED));
+
+            setEnabledAndWait(controller, true);
+
+            runAsShell(ACCESS_NETWORK_STATE, () -> waitForAttachedState(controller));
+        }
+    }
+
+    @Test
+    public void setEnabled_enableFollowedByDisable_allSucceed() throws Exception {
+        for (ThreadNetworkController controller : getAllControllers()) {
+            joinRandomizedDatasetAndWait(controller);
+            CompletableFuture<Void> setFuture1 = new CompletableFuture<>();
+            CompletableFuture<Void> setFuture2 = new CompletableFuture<>();
+            EnabledStateListener listener = new EnabledStateListener(controller);
+            listener.expectThreadEnabledState(STATE_ENABLED);
+
+            runAsShell(
+                    PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                    () -> {
+                        controller.setEnabled(true, mExecutor, newOutcomeReceiver(setFuture1));
+                        controller.setEnabled(false, mExecutor, newOutcomeReceiver(setFuture2));
+                    });
+
+            setFuture1.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+            setFuture2.get(ENABLED_TIMEOUT_MILLIS, MILLISECONDS);
+
+            listener.expectThreadEnabledState(STATE_DISABLING);
+            listener.expectThreadEnabledState(STATE_DISABLED);
+
+            runAsShell(
+                    ACCESS_NETWORK_STATE,
+                    () -> assertThat(getDeviceRole(controller)).isEqualTo(DEVICE_ROLE_STOPPED));
+
+            listener.unregisterStateCallback();
+        }
+    }
+    // TODO (b/322437869): add a test case to verify that any command
+    // (join/leave/scheduleMigration/setEnabled) fails with ERROR_BUSY while Thread is in the
+    // DISABLING state. This is not currently tested because DISABLING is very short-lived, so it
+    // is not possible to guarantee that a command is sent before the state changes to DISABLED.
+
+    @Test
     public void threadNetworkCallback_deviceAttached_threadNetworkIsAvailable() throws Exception {
         ThreadNetworkController controller = mManager.getAllThreadNetworkControllers().get(0);
         ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
-        SettableFuture<Void> joinFuture = SettableFuture.create();
-        SettableFuture<Network> networkFuture = SettableFuture.create();
+        CompletableFuture<Void> joinFuture = new CompletableFuture<>();
+        CompletableFuture<Network> networkFuture = new CompletableFuture<>();
         ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
         NetworkRequest networkRequest =
                 new NetworkRequest.Builder()
@@ -774,7 +1067,7 @@
                 new ConnectivityManager.NetworkCallback() {
                     @Override
                     public void onAvailable(Network network) {
-                        networkFuture.set(network);
+                        networkFuture.complete(network);
                     }
                 };
 
@@ -782,12 +1075,11 @@
                 PERMISSION_THREAD_NETWORK_PRIVILEGED,
                 () -> controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture)));
         runAsShell(
-                permission.ACCESS_NETWORK_STATE,
+                ACCESS_NETWORK_STATE,
                 () -> cm.registerNetworkCallback(networkRequest, networkCallback));
 
         joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
-        runAsShell(
-                permission.ACCESS_NETWORK_STATE, () -> assertThat(isAttached(controller)).isTrue());
+        runAsShell(ACCESS_NETWORK_STATE, () -> assertThat(isAttached(controller)).isTrue());
         assertThat(networkFuture.get(NETWORK_CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isNotNull();
     }
 }
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index 44a8ab7..1d83abc 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -24,6 +24,7 @@
 import static com.google.common.io.BaseEncoding.base16;
 import static com.google.common.truth.Truth.assertThat;
 
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
@@ -85,6 +86,7 @@
     @Mock private TunInterfaceController mMockTunIfController;
     @Mock private ParcelFileDescriptor mMockTunFd;
     @Mock private InfraInterfaceController mMockInfraIfController;
+    @Mock private ThreadPersistentSettings mMockPersistentSettings;
     private Context mContext;
     private TestLooper mTestLooper;
     private FakeOtDaemon mFakeOtDaemon;
@@ -104,6 +106,8 @@
 
         when(mMockTunIfController.getTunFd()).thenReturn(mMockTunFd);
 
+        when(mMockPersistentSettings.get(any())).thenReturn(true);
+
         mService =
                 new ThreadNetworkControllerService(
                         ApplicationProvider.getApplicationContext(),
@@ -112,7 +116,8 @@
                         () -> mFakeOtDaemon,
                         mMockConnectivityManager,
                         mMockTunIfController,
-                        mMockInfraIfController);
+                        mMockInfraIfController,
+                        mMockPersistentSettings);
         mService.setTestNetworkAgent(mMockNetworkAgent);
     }
 
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
index 670449d..5ca6511 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
@@ -86,6 +86,7 @@
     private static final String TEST_COUNTRY_CODE_US = "US";
     private static final String TEST_COUNTRY_CODE_CN = "CN";
     private static final String TEST_COUNTRY_CODE_INVALID = "INVALID";
+    private static final String TEST_WIFI_DEFAULT_COUNTRY_CODE = "00";
     private static final int TEST_SIM_SLOT_INDEX_0 = 0;
     private static final int TEST_SIM_SLOT_INDEX_1 = 1;
 
@@ -259,6 +260,21 @@
     }
 
     @Test
+    public void wifiCountryCode_wifiDefaultCountryCodeIsActive_wifiCountryCodeIsNotUsed() {
+        mThreadNetworkCountryCode.initialize();
+
+        verify(mWifiManager)
+                .registerActiveCountryCodeChangedCallback(
+                        any(), mWifiCountryCodeReceiverCaptor.capture());
+        mWifiCountryCodeReceiverCaptor
+                .getValue()
+                .onActiveCountryCodeChanged(TEST_WIFI_DEFAULT_COUNTRY_CODE);
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode())
+                .isNotEqualTo(TEST_WIFI_DEFAULT_COUNTRY_CODE);
+    }
+
+    @Test
     public void wifiCountryCode_wifiCountryCodeIsInactive_defaultCountryCodeIsUsed() {
         mThreadNetworkCountryCode.initialize();
         verify(mWifiManager)