Merge "Reset the LISTEN_ACTIVITY_TIMEOUT_MS timer" into main
diff --git a/Tethering/src/com/android/networkstack/tethering/Tethering.java b/Tethering/src/com/android/networkstack/tethering/Tethering.java
index 3277363..da3b584 100644
--- a/Tethering/src/com/android/networkstack/tethering/Tethering.java
+++ b/Tethering/src/com/android/networkstack/tethering/Tethering.java
@@ -2734,84 +2734,79 @@
         }
     }
 
-    private IpServer.Callback makeControlCallback() {
-        return new IpServer.Callback() {
-            @Override
-            public void updateInterfaceState(IpServer who, int state, int lastError) {
-                notifyInterfaceStateChange(who, state, lastError);
+    /**
+     * Callback passed to every {@link IpServer} created by Tethering. Records the last state and
+     * error reported by the current IpServer of each interface in {@code mTetherStates}, forwards
+     * state, LinkProperties and downstream change requests to {@code TetherMainSM}, and handles
+     * DHCP lease changes via {@code maybeDhcpLeasesChanged()}.
+     */
+    private class ControlCallback extends IpServer.Callback {
+        @Override
+        public void updateInterfaceState(IpServer who, int state, int lastError) {
+            final String iface = who.interfaceName();
+            final TetherState tetherState = mTetherStates.get(iface);
+            if (tetherState != null && tetherState.ipServer.equals(who)) {
+                tetherState.lastState = state;
+                tetherState.lastError = lastError;
+            } else {
+                if (DBG) Log.d(TAG, "got notification from stale iface " + iface);
             }
 
-            @Override
-            public void updateLinkProperties(IpServer who, LinkProperties newLp) {
-                notifyLinkPropertiesChanged(who, newLp);
-            }
+            mLog.log(String.format("OBSERVED iface=%s state=%s error=%s", iface, state, lastError));
 
-            @Override
-            public void dhcpLeasesChanged() {
-                maybeDhcpLeasesChanged();
+            // If TetherMainSM is in ErrorState, TetherMainSM stays there.
+            // Thus we give a chance for TetherMainSM to recover to InitialState
+            // by sending CMD_CLEAR_ERROR
+            if (lastError == TETHER_ERROR_INTERNAL_ERROR) {
+                mTetherMainSM.sendMessage(TetherMainSM.CMD_CLEAR_ERROR, who);
             }
-
-            @Override
-            public void requestEnableTethering(int tetheringType, boolean enabled) {
-                mTetherMainSM.sendMessage(TetherMainSM.EVENT_REQUEST_CHANGE_DOWNSTREAM,
-                        tetheringType, 0, enabled ? Boolean.TRUE : Boolean.FALSE);
+            int which;
+            switch (state) {
+                case IpServer.STATE_UNAVAILABLE:
+                case IpServer.STATE_AVAILABLE:
+                    which = TetherMainSM.EVENT_IFACE_SERVING_STATE_INACTIVE;
+                    break;
+                case IpServer.STATE_TETHERED:
+                case IpServer.STATE_LOCAL_ONLY:
+                    which = TetherMainSM.EVENT_IFACE_SERVING_STATE_ACTIVE;
+                    break;
+                default:
+                    Log.wtf(TAG, "Unknown interface state: " + state);
+                    return;
             }
-        };
-    }
-
-    // TODO: Move into TetherMainSM.
-    private void notifyInterfaceStateChange(IpServer who, int state, int error) {
-        final String iface = who.interfaceName();
-        final TetherState tetherState = mTetherStates.get(iface);
-        if (tetherState != null && tetherState.ipServer.equals(who)) {
-            tetherState.lastState = state;
-            tetherState.lastError = error;
-        } else {
-            if (DBG) Log.d(TAG, "got notification from stale iface " + iface);
+            mTetherMainSM.sendMessage(which, state, 0, who);
+            sendTetherStateChangedBroadcast();
         }
 
-        mLog.log(String.format("OBSERVED iface=%s state=%s error=%s", iface, state, error));
-
-        // If TetherMainSM is in ErrorState, TetherMainSM stays there.
-        // Thus we give a chance for TetherMainSM to recover to InitialState
-        // by sending CMD_CLEAR_ERROR
-        if (error == TETHER_ERROR_INTERNAL_ERROR) {
-            mTetherMainSM.sendMessage(TetherMainSM.CMD_CLEAR_ERROR, who);
-        }
-        int which;
-        switch (state) {
-            case IpServer.STATE_UNAVAILABLE:
-            case IpServer.STATE_AVAILABLE:
-                which = TetherMainSM.EVENT_IFACE_SERVING_STATE_INACTIVE;
-                break;
-            case IpServer.STATE_TETHERED:
-            case IpServer.STATE_LOCAL_ONLY:
-                which = TetherMainSM.EVENT_IFACE_SERVING_STATE_ACTIVE;
-                break;
-            default:
-                Log.wtf(TAG, "Unknown interface state: " + state);
+        @Override
+        public void updateLinkProperties(IpServer who, LinkProperties newLp) {
+            final String iface = who.interfaceName();
+            final int state;
+            final TetherState tetherState = mTetherStates.get(iface);
+            if (tetherState != null && tetherState.ipServer.equals(who)) {
+                state = tetherState.lastState;
+            } else {
+                mLog.log("got notification from stale iface " + iface);
                 return;
-        }
-        mTetherMainSM.sendMessage(which, state, 0, who);
-        sendTetherStateChangedBroadcast();
-    }
+            }
 
-    private void notifyLinkPropertiesChanged(IpServer who, LinkProperties newLp) {
-        final String iface = who.interfaceName();
-        final int state;
-        final TetherState tetherState = mTetherStates.get(iface);
-        if (tetherState != null && tetherState.ipServer.equals(who)) {
-            state = tetherState.lastState;
-        } else {
-            mLog.log("got notification from stale iface " + iface);
-            return;
+            mLog.log(String.format(
+                    "OBSERVED LinkProperties update iface=%s state=%s lp=%s",
+                    iface, IpServer.getStateString(state), newLp));
+            final int which = TetherMainSM.EVENT_IFACE_UPDATE_LINKPROPERTIES;
+            mTetherMainSM.sendMessage(which, state, 0, newLp);
         }
 
-        mLog.log(String.format(
-                "OBSERVED LinkProperties update iface=%s state=%s lp=%s",
-                iface, IpServer.getStateString(state), newLp));
-        final int which = TetherMainSM.EVENT_IFACE_UPDATE_LINKPROPERTIES;
-        mTetherMainSM.sendMessage(which, state, 0, newLp);
+        @Override
+        public void dhcpLeasesChanged() {
+            maybeDhcpLeasesChanged();
+        }
+
+        @Override
+        public void requestEnableTethering(int tetheringType, boolean enabled) {
+            mTetherMainSM.sendMessage(TetherMainSM.EVENT_REQUEST_CHANGE_DOWNSTREAM,
+                    tetheringType, 0, enabled ? Boolean.TRUE : Boolean.FALSE);
+        }
     }
 
     private boolean hasSystemFeature(final String feature) {
@@ -2853,7 +2848,7 @@
         mLog.i("adding IpServer for: " + iface);
         final TetherState tetherState = new TetherState(
                 new IpServer(iface, mHandler, interfaceType, mLog, mNetd, mBpfCoordinator,
-                        mRoutingCoordinator, makeControlCallback(), mConfig,
+                        mRoutingCoordinator, new ControlCallback(), mConfig,
                         mPrivateAddressCoordinator, mTetheringMetrics,
                         mDeps.getIpServerDependencies()), isNcm);
         mTetherStates.put(iface, tetherState);
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
index 01600b8..47ecf58 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
@@ -1280,9 +1280,9 @@
         final Ipv6DownstreamRule rule = buildTestDownstreamRule(mobileIfIndex, NEIGH_A, MAC_A);
 
         final TetherDownstream6Key key = rule.makeTetherDownstream6Key();
-        assertEquals(key.iif, mobileIfIndex);
-        assertEquals(key.dstMac, MacAddress.ALL_ZEROS_ADDRESS);  // rawip upstream
-        assertTrue(Arrays.equals(key.neigh6, NEIGH_A.getAddress()));
+        assertEquals(mobileIfIndex, key.iif);
+        assertEquals(MacAddress.ALL_ZEROS_ADDRESS, key.dstMac);  // rawip upstream
+        assertArrayEquals(NEIGH_A.getAddress(), key.neigh6);
         // iif (4) + dstMac(6) + padding(2) + neigh6 (16) = 28.
         assertEquals(28, key.writeToBytes().length);
     }
@@ -1293,12 +1293,12 @@
         final Ipv6DownstreamRule rule = buildTestDownstreamRule(mobileIfIndex, NEIGH_A, MAC_A);
 
         final Tether6Value value = rule.makeTether6Value();
-        assertEquals(value.oif, DOWNSTREAM_IFINDEX);
-        assertEquals(value.ethDstMac, MAC_A);
-        assertEquals(value.ethSrcMac, DOWNSTREAM_MAC);
-        assertEquals(value.ethProto, ETH_P_IPV6);
-        assertEquals(value.pmtu, NetworkStackConstants.ETHER_MTU);
+        assertEquals(DOWNSTREAM_IFINDEX, value.oif);
+        assertEquals(MAC_A, value.ethDstMac);
+        assertEquals(DOWNSTREAM_MAC, value.ethSrcMac);
+        assertEquals(ETH_P_IPV6, value.ethProto);
+        assertEquals(NetworkStackConstants.ETHER_MTU, value.pmtu);
         // oif (4) + ethDstMac (6) + ethSrcMac (6) + ethProto (2) + pmtu (2) = 20.
         assertEquals(20, value.writeToBytes().length);
     }
 
diff --git a/common/flags.aconfig b/common/flags.aconfig
index b85c2fe..ebfa13a 100644
--- a/common/flags.aconfig
+++ b/common/flags.aconfig
@@ -41,3 +41,10 @@
   description: "Block network access for apps in a low importance background state"
   bug: "304347838"
 }
+
+flag {
+  name: "register_nsd_offload_engine"
+  namespace: "android_core_networking"
+  description: "Controls access to the registerOffloadEngine API in NsdManager"
+  bug: "294777050"
+}
diff --git a/framework-t/src/android/net/nsd/NsdManager.java b/framework-t/src/android/net/nsd/NsdManager.java
index ef0e34b..dae8914 100644
--- a/framework-t/src/android/net/nsd/NsdManager.java
+++ b/framework-t/src/android/net/nsd/NsdManager.java
@@ -34,10 +34,10 @@
 import android.content.Context;
 import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
+import android.net.ConnectivityThread;
 import android.net.Network;
 import android.net.NetworkRequest;
 import android.os.Handler;
-import android.os.HandlerThread;
 import android.os.Looper;
 import android.os.Message;
 import android.os.RemoteException;
@@ -632,10 +632,9 @@
      */
     public NsdManager(Context context, INsdManager service) {
         mContext = context;
-
-        HandlerThread t = new HandlerThread("NsdManager");
-        t.start();
-        mHandler = new ServiceHandler(t.getLooper());
+        // Share the singleton ConnectivityThread looper among all NsdManager instances instead
+        // of starting a separate HandlerThread for each of them.
+        mHandler = new ServiceHandler(ConnectivityThread.getInstanceLooper());
 
         try {
             mService = service.connect(new NsdCallbackImpl(mHandler), CompatChanges.isChangeEnabled(
diff --git a/tests/unit/java/android/net/nsd/NsdManagerTest.java b/tests/unit/java/android/net/nsd/NsdManagerTest.java
index 0965193..550a9ee 100644
--- a/tests/unit/java/android/net/nsd/NsdManagerTest.java
+++ b/tests/unit/java/android/net/nsd/NsdManagerTest.java
@@ -51,6 +51,7 @@
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+@DevSdkIgnoreRunner.MonitorThreadLeak
 @RunWith(DevSdkIgnoreRunner.class)
 @SmallTest
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 771edb2..ffc8aa1 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -145,6 +145,7 @@
 // TODOs:
 //  - test client can send requests and receive replies
 //  - test NSD_ON ENABLE/DISABLED listening
+@DevSdkIgnoreRunner.MonitorThreadLeak
 @RunWith(DevSdkIgnoreRunner.class)
 @SmallTest
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
diff --git a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
index 51e4d88..89dcd39 100644
--- a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
+++ b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
@@ -16,7 +16,6 @@
 
 package android.net.thread;
 
-import android.net.Network;
 import android.net.thread.ActiveOperationalDataset;
 import android.net.thread.IActiveOperationalDatasetReceiver;
 import android.net.thread.IOperationalDatasetCallback;
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkController.java b/thread/framework/java/android/net/thread/ThreadNetworkController.java
index 5c5fda9..34b0b06 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkController.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkController.java
@@ -26,6 +26,7 @@
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
 import android.annotation.SystemApi;
+import android.os.Binder;
 import android.os.OutcomeReceiver;
 import android.os.RemoteException;
 
@@ -98,7 +99,8 @@
     private final Map<OperationalDatasetCallback, OperationalDatasetCallbackProxy>
             mOpDatasetCallbackMap = new HashMap<>();
 
-    ThreadNetworkController(@NonNull IThreadNetworkController controllerService) {
+    /** @hide */
+    public ThreadNetworkController(@NonNull IThreadNetworkController controllerService) {
         requireNonNull(controllerService, "controllerService cannot be null");
         mControllerService = controllerService;
     }
@@ -180,12 +182,24 @@
 
         @Override
         public void onDeviceRoleChanged(@DeviceRole int deviceRole) {
-            mExecutor.execute(() -> mCallback.onDeviceRoleChanged(deviceRole));
+            // Clear the binder calling identity so that a callback run inline by the executor
+            // (e.g. Runnable::run) observes this app's identity instead of the remote caller's.
+            final long identity = Binder.clearCallingIdentity();
+            try {
+                mExecutor.execute(() -> mCallback.onDeviceRoleChanged(deviceRole));
+            } finally {
+                Binder.restoreCallingIdentity(identity);
+            }
         }
 
         @Override
         public void onPartitionIdChanged(long partitionId) {
-            mExecutor.execute(() -> mCallback.onPartitionIdChanged(partitionId));
+            final long identity = Binder.clearCallingIdentity();
+            try {
+                mExecutor.execute(() -> mCallback.onPartitionIdChanged(partitionId));
+            } finally {
+                Binder.restoreCallingIdentity(identity);
+            }
         }
     }
 
@@ -282,13 +296,24 @@
         @Override
         public void onActiveOperationalDatasetChanged(
                 @Nullable ActiveOperationalDataset activeDataset) {
-            mExecutor.execute(() -> mCallback.onActiveOperationalDatasetChanged(activeDataset));
+            final long identity = Binder.clearCallingIdentity();
+            try {
+                mExecutor.execute(() -> mCallback.onActiveOperationalDatasetChanged(activeDataset));
+            } finally {
+                Binder.restoreCallingIdentity(identity);
+            }
         }
 
         @Override
         public void onPendingOperationalDatasetChanged(
                 @Nullable PendingOperationalDataset pendingDataset) {
-            mExecutor.execute(() -> mCallback.onPendingOperationalDatasetChanged(pendingDataset));
+            final long identity = Binder.clearCallingIdentity();
+            try {
+                mExecutor.execute(
+                        () -> mCallback.onPendingOperationalDatasetChanged(pendingDataset));
+            } finally {
+                Binder.restoreCallingIdentity(identity);
+            }
         }
     }
 
@@ -481,7 +506,13 @@
             OutcomeReceiver<T, ThreadNetworkException> receiver,
             int errorCode,
             String errorMsg) {
-        executor.execute(() -> receiver.onError(new ThreadNetworkException(errorCode, errorMsg)));
+        final long identity = Binder.clearCallingIdentity();
+        try {
+            executor.execute(
+                    () -> receiver.onError(new ThreadNetworkException(errorCode, errorMsg)));
+        } finally {
+            Binder.restoreCallingIdentity(identity);
+        }
     }
 
     private static final class ActiveDatasetReceiverProxy
@@ -498,7 +529,12 @@
 
         @Override
         public void onSuccess(ActiveOperationalDataset dataset) {
-            mExecutor.execute(() -> mResultReceiver.onResult(dataset));
+            final long identity = Binder.clearCallingIdentity();
+            try {
+                mExecutor.execute(() -> mResultReceiver.onResult(dataset));
+            } finally {
+                Binder.restoreCallingIdentity(identity);
+            }
         }
 
         @Override
@@ -520,7 +556,12 @@
 
         @Override
         public void onSuccess() {
-            mExecutor.execute(() -> mResultReceiver.onResult(null));
+            final long identity = Binder.clearCallingIdentity();
+            try {
+                mExecutor.execute(() -> mResultReceiver.onResult(null));
+            } finally {
+                Binder.restoreCallingIdentity(identity);
+            }
         }
 
         @Override
diff --git a/thread/tests/unit/Android.bp b/thread/tests/unit/Android.bp
index 5863673..8092693 100644
--- a/thread/tests/unit/Android.bp
+++ b/thread/tests/unit/Android.bp
@@ -33,11 +33,11 @@
     static_libs: [
         "androidx.test.ext.junit",
         "compatibility-device-util-axt",
-        "ctstestrunner-axt",
         "framework-connectivity-pre-jarjar",
         "framework-connectivity-t-pre-jarjar",
         "guava",
         "guava-android-testlib",
+        "mockito-target-minus-junit4",
         "net-tests-utils",
         "truth",
     ],
@@ -45,6 +45,7 @@
         "android.test.base",
         "android.test.runner",
     ],
+    jarjar_rules: ":connectivity-jarjar-rules",
     // Test coverage system runs on different devices. Need to
     // compile for all architectures.
     compile_multilib: "both",
diff --git a/thread/tests/unit/AndroidTest.xml b/thread/tests/unit/AndroidTest.xml
index 663ff74..597c6a8 100644
--- a/thread/tests/unit/AndroidTest.xml
+++ b/thread/tests/unit/AndroidTest.xml
@@ -27,6 +27,7 @@
 
     <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
         <option name="package" value="android.net.thread.unittests" />
+        <option name="hidden-api-checks" value="false"/>
         <!-- Ignores tests introduced by guava-android-testlib -->
         <option name="exclude-annotation" value="org.junit.Ignore"/>
     </test>
diff --git a/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java b/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java
new file mode 100644
index 0000000..2f120b2
--- /dev/null
+++ b/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java
@@ -0,0 +1,372 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread;
+
+import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_CHILD;
+import static android.net.thread.ThreadNetworkException.ERROR_UNAVAILABLE;
+import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_CHANNEL;
+import static android.os.Process.SYSTEM_UID;
+
+import static com.google.common.io.BaseEncoding.base16;
+import static com.google.common.truth.Truth.assertThat;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.doAnswer;
+
+import android.net.thread.IActiveOperationalDatasetReceiver;
+import android.net.thread.IOperationReceiver;
+import android.net.thread.IOperationalDatasetCallback;
+import android.net.thread.IStateCallback;
+import android.net.thread.IThreadNetworkController;
+import android.net.thread.ThreadNetworkController.OperationalDatasetCallback;
+import android.net.thread.ThreadNetworkController.StateCallback;
+import android.os.Binder;
+import android.os.OutcomeReceiver;
+import android.os.Process;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+
+import java.time.Duration;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/** Unit tests for {@link ThreadNetworkController}. */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class ThreadNetworkControllerTest {
+
+    @Mock private IThreadNetworkController mMockService;
+    private ThreadNetworkController mController;
+    // Calling identity captured in setUp() and restored in tearDown(), so that the UID set by
+    // setBinderUid() does not leak into later tests.
+    private long mOriginalCallingIdentity;
+
+    // A valid Thread Active Operational Dataset generated from OpenThread CLI "dataset new":
+    // Active Timestamp: 1
+    // Channel: 19
+    // Channel Mask: 0x07FFF800
+    // Ext PAN ID: ACC214689BC40BDF
+    // Mesh Local Prefix: fd64:db12:25f4:7e0b::/64
+    // Network Key: F26B3153760F519A63BAFDDFFC80D2AF
+    // Network Name: OpenThread-d9a0
+    // PAN ID: 0xD9A0
+    // PSKc: A245479C836D551B9CA557F7B9D351B4
+    // Security Policy: 672 onrcb
+    private static final byte[] DEFAULT_DATASET_TLVS =
+            base16().decode(
+                            "0E080000000000010000000300001335060004001FFFE002"
+                                    + "08ACC214689BC40BDF0708FD64DB1225F47E0B0510F26B31"
+                                    + "53760F519A63BAFDDFFC80D2AF030F4F70656E5468726561"
+                                    + "642D643961300102D9A00410A245479C836D551B9CA557F7"
+                                    + "B9D351B40C0402A0FFF8");
+
+    private static final ActiveOperationalDataset DEFAULT_DATASET =
+            ActiveOperationalDataset.fromThreadTlvs(DEFAULT_DATASET_TLVS);
+
+    @Before
+    public void setUp() {
+        MockitoAnnotations.initMocks(this);
+        mController = new ThreadNetworkController(mMockService);
+        mOriginalCallingIdentity = Binder.clearCallingIdentity();
+    }
+
+    @After
+    public void tearDown() {
+        Binder.restoreCallingIdentity(mOriginalCallingIdentity);
+    }
+
+    private static void setBinderUid(int uid) {
+        // TODO: generally, it's not a good practice to depend on the implementation detail to set
+        // a custom UID, but Connectivity, Wifi, UWB and other modules use this trick. Maybe
+        // define an interface (e.g. CallerIdentityInjector) for easier mocking.
+        Binder.restoreCallingIdentity((((long) uid) << 32) | Binder.getCallingPid());
+    }
+
+    private static IStateCallback getStateCallback(InvocationOnMock invocation) {
+        return (IStateCallback) invocation.getArguments()[0];
+    }
+
+    private static IOperationReceiver getOperationReceiver(InvocationOnMock invocation) {
+        return (IOperationReceiver) invocation.getArguments()[0];
+    }
+
+    private static IOperationReceiver getJoinReceiver(InvocationOnMock invocation) {
+        return (IOperationReceiver) invocation.getArguments()[1];
+    }
+
+    private static IOperationReceiver getScheduleMigrationReceiver(InvocationOnMock invocation) {
+        return (IOperationReceiver) invocation.getArguments()[1];
+    }
+
+    private static IActiveOperationalDatasetReceiver getCreateDatasetReceiver(
+            InvocationOnMock invocation) {
+        return (IActiveOperationalDatasetReceiver) invocation.getArguments()[1];
+    }
+
+    private static IOperationalDatasetCallback getOperationalDatasetCallback(
+            InvocationOnMock invocation) {
+        return (IOperationalDatasetCallback) invocation.getArguments()[0];
+    }
+
+    @Test
+    public void registerStateCallback_callbackIsInvokedWithCallingAppIdentity() throws Exception {
+        setBinderUid(SYSTEM_UID);
+        doAnswer(
+                        invoke -> {
+                            getStateCallback(invoke).onDeviceRoleChanged(DEVICE_ROLE_CHILD);
+                            return null;
+                        })
+                .when(mMockService)
+                .registerStateCallback(any(IStateCallback.class));
+        AtomicInteger callbackUid = new AtomicInteger(0);
+        StateCallback callback = state -> callbackUid.set(Binder.getCallingUid());
+
+        try {
+            mController.registerStateCallback(Runnable::run, callback);
+
+            assertThat(callbackUid.get()).isNotEqualTo(SYSTEM_UID);
+            assertThat(callbackUid.get()).isEqualTo(Process.myUid());
+        } finally {
+            mController.unregisterStateCallback(callback);
+        }
+    }
+
+    @Test
+    public void registerOperationalDatasetCallback_callbackIsInvokedWithCallingAppIdentity()
+            throws Exception {
+        setBinderUid(SYSTEM_UID);
+        doAnswer(
+                        invoke -> {
+                            getOperationalDatasetCallback(invoke)
+                                    .onActiveOperationalDatasetChanged(null);
+                            getOperationalDatasetCallback(invoke)
+                                    .onPendingOperationalDatasetChanged(null);
+                            return null;
+                        })
+                .when(mMockService)
+                .registerOperationalDatasetCallback(any(IOperationalDatasetCallback.class));
+        AtomicInteger activeCallbackUid = new AtomicInteger(0);
+        AtomicInteger pendingCallbackUid = new AtomicInteger(0);
+        OperationalDatasetCallback callback =
+                new OperationalDatasetCallback() {
+                    @Override
+                    public void onActiveOperationalDatasetChanged(
+                            ActiveOperationalDataset dataset) {
+                        activeCallbackUid.set(Binder.getCallingUid());
+                    }
+
+                    @Override
+                    public void onPendingOperationalDatasetChanged(
+                            PendingOperationalDataset dataset) {
+                        pendingCallbackUid.set(Binder.getCallingUid());
+                    }
+                };
+
+        try {
+            mController.registerOperationalDatasetCallback(Runnable::run, callback);
+
+            assertThat(activeCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+            assertThat(activeCallbackUid.get()).isEqualTo(Process.myUid());
+            assertThat(pendingCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+            assertThat(pendingCallbackUid.get()).isEqualTo(Process.myUid());
+        } finally {
+            mController.unregisterOperationalDatasetCallback(callback);
+        }
+    }
+
+    @Test
+    public void createRandomizedDataset_callbackIsInvokedWithCallingAppIdentity() throws Exception {
+        setBinderUid(SYSTEM_UID);
+        AtomicInteger successCallbackUid = new AtomicInteger(0);
+        AtomicInteger errorCallbackUid = new AtomicInteger(0);
+
+        doAnswer(
+                        invoke -> {
+                            getCreateDatasetReceiver(invoke).onSuccess(DEFAULT_DATASET);
+                            return null;
+                        })
+                .when(mMockService)
+                .createRandomizedDataset(anyString(), any(IActiveOperationalDatasetReceiver.class));
+        mController.createRandomizedDataset(
+                "TestNet",
+                Runnable::run,
+                dataset -> successCallbackUid.set(Binder.getCallingUid()));
+        doAnswer(
+                        invoke -> {
+                            getCreateDatasetReceiver(invoke).onError(ERROR_UNSUPPORTED_CHANNEL, "");
+                            return null;
+                        })
+                .when(mMockService)
+                .createRandomizedDataset(anyString(), any(IActiveOperationalDatasetReceiver.class));
+        mController.createRandomizedDataset(
+                "TestNet",
+                Runnable::run,
+                new OutcomeReceiver<>() {
+                    @Override
+                    public void onResult(ActiveOperationalDataset dataset) {}
+
+                    @Override
+                    public void onError(ThreadNetworkException e) {
+                        errorCallbackUid.set(Binder.getCallingUid());
+                    }
+                });
+
+        assertThat(successCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(successCallbackUid.get()).isEqualTo(Process.myUid());
+        assertThat(errorCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(errorCallbackUid.get()).isEqualTo(Process.myUid());
+    }
+
+    @Test
+    public void join_callbackIsInvokedWithCallingAppIdentity() throws Exception {
+        setBinderUid(SYSTEM_UID);
+        AtomicInteger successCallbackUid = new AtomicInteger(0);
+        AtomicInteger errorCallbackUid = new AtomicInteger(0);
+
+        doAnswer(
+                        invoke -> {
+                            getJoinReceiver(invoke).onSuccess();
+                            return null;
+                        })
+                .when(mMockService)
+                .join(any(ActiveOperationalDataset.class), any(IOperationReceiver.class));
+        mController.join(
+                DEFAULT_DATASET,
+                Runnable::run,
+                v -> successCallbackUid.set(Binder.getCallingUid()));
+        doAnswer(
+                        invoke -> {
+                            getJoinReceiver(invoke).onError(ERROR_UNAVAILABLE, "");
+                            return null;
+                        })
+                .when(mMockService)
+                .join(any(ActiveOperationalDataset.class), any(IOperationReceiver.class));
+        mController.join(
+                DEFAULT_DATASET,
+                Runnable::run,
+                new OutcomeReceiver<>() {
+                    @Override
+                    public void onResult(Void unused) {}
+
+                    @Override
+                    public void onError(ThreadNetworkException e) {
+                        errorCallbackUid.set(Binder.getCallingUid());
+                    }
+                });
+
+        assertThat(successCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(successCallbackUid.get()).isEqualTo(Process.myUid());
+        assertThat(errorCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(errorCallbackUid.get()).isEqualTo(Process.myUid());
+    }
+
+    @Test
+    public void scheduleMigration_callbackIsInvokedWithCallingAppIdentity() throws Exception {
+        setBinderUid(SYSTEM_UID);
+        final PendingOperationalDataset pendingDataset =
+                new PendingOperationalDataset(
+                        DEFAULT_DATASET,
+                        new OperationalDatasetTimestamp(100, 0, false),
+                        Duration.ZERO);
+        AtomicInteger successCallbackUid = new AtomicInteger(0);
+        AtomicInteger errorCallbackUid = new AtomicInteger(0);
+
+        doAnswer(
+                        invoke -> {
+                            getScheduleMigrationReceiver(invoke).onSuccess();
+                            return null;
+                        })
+                .when(mMockService)
+                .scheduleMigration(
+                        any(PendingOperationalDataset.class), any(IOperationReceiver.class));
+        mController.scheduleMigration(
+                pendingDataset, Runnable::run, v -> successCallbackUid.set(Binder.getCallingUid()));
+        doAnswer(
+                        invoke -> {
+                            getScheduleMigrationReceiver(invoke).onError(ERROR_UNAVAILABLE, "");
+                            return null;
+                        })
+                .when(mMockService)
+                .scheduleMigration(
+                        any(PendingOperationalDataset.class), any(IOperationReceiver.class));
+        mController.scheduleMigration(
+                pendingDataset,
+                Runnable::run,
+                new OutcomeReceiver<>() {
+                    @Override
+                    public void onResult(Void unused) {}
+
+                    @Override
+                    public void onError(ThreadNetworkException e) {
+                        errorCallbackUid.set(Binder.getCallingUid());
+                    }
+                });
+
+        assertThat(successCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(successCallbackUid.get()).isEqualTo(Process.myUid());
+        assertThat(errorCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(errorCallbackUid.get()).isEqualTo(Process.myUid());
+    }
+
+    @Test
+    public void leave_callbackIsInvokedWithCallingAppIdentity() throws Exception {
+        setBinderUid(SYSTEM_UID);
+        AtomicInteger successCallbackUid = new AtomicInteger(0);
+        AtomicInteger errorCallbackUid = new AtomicInteger(0);
+
+        doAnswer(
+                        invoke -> {
+                            getOperationReceiver(invoke).onSuccess();
+                            return null;
+                        })
+                .when(mMockService)
+                .leave(any(IOperationReceiver.class));
+        mController.leave(Runnable::run, v -> successCallbackUid.set(Binder.getCallingUid()));
+        doAnswer(
+                        invoke -> {
+                            getOperationReceiver(invoke).onError(ERROR_UNAVAILABLE, "");
+                            return null;
+                        })
+                .when(mMockService)
+                .leave(any(IOperationReceiver.class));
+        mController.leave(
+                Runnable::run,
+                new OutcomeReceiver<>() {
+                    @Override
+                    public void onResult(Void unused) {}
+
+                    @Override
+                    public void onError(ThreadNetworkException e) {
+                        errorCallbackUid.set(Binder.getCallingUid());
+                    }
+                });
+
+        assertThat(successCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(successCallbackUid.get()).isEqualTo(Process.myUid());
+        assertThat(errorCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(errorCallbackUid.get()).isEqualTo(Process.myUid());
+    }
+}