[Thread] add more Thread shell commands

A summary of changes in this commit:
1. add new commands mirroring the java APIs: "join", "migrate" and
   "leave"
2. Use THREAD_NETWORK_TESTING for guarding access to shell commands
   which are for testing
3. Refactor ThreadNetworkShellCommand for readability

Bug: 329368792
Change-Id: Ibdb256674533704dc1d000a15c162f41cec5047f
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 0c200fd..3d743ab 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -1058,7 +1058,7 @@
     }
 
     @Override
-    public void leave(@NonNull IOperationReceiver receiver) throws RemoteException {
+    public void leave(@NonNull IOperationReceiver receiver) {
         enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
         mHandler.post(() -> leaveInternal(new OperationReceiverWrapper(receiver)));
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index 30c67ca..4c22278 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -72,7 +72,10 @@
             // PHASE_ACTIVITY_MANAGER_READY and PHASE_THIRD_PARTY_APPS_CAN_START
             mCountryCode.initialize();
             mShellCommand =
-                    new ThreadNetworkShellCommand(requireNonNull(mControllerService), mCountryCode);
+                    new ThreadNetworkShellCommand(
+                            mContext,
+                            requireNonNull(mControllerService),
+                            requireNonNull(mCountryCode));
         }
     }
 
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
index c6a1618..54155ee 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
@@ -16,50 +16,57 @@
 
 package com.android.server.thread;
 
-import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.content.Context;
+import android.net.thread.ActiveOperationalDataset;
 import android.net.thread.IOperationReceiver;
+import android.net.thread.OperationalDatasetTimestamp;
+import android.net.thread.PendingOperationalDataset;
 import android.net.thread.ThreadNetworkException;
-import android.os.Binder;
-import android.os.Process;
 import android.text.TextUtils;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.modules.utils.BasicShellCommandHandler;
+import com.android.net.module.util.HexDump;
 
 import java.io.PrintWriter;
 import java.time.Duration;
-import java.util.List;
+import java.time.Instant;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
 /**
- * Interprets and executes 'adb shell cmd thread_network [args]'.
+ * Interprets and executes 'adb shell cmd thread_network <subcommand>'.
+ *
+ * <p>Subcommands which don't have an equivalent Java API now require the
+ * "android.permission.THREAD_NETWORK_TESTING" permission. For a specific subcommand, it also
+ * requires the same permissions of the equivalent Java / AIDL API.
  *
  * <p>To add new commands: - onCommand: Add a case "<command>" execute. Return a 0 if command
  * executed successfully. - onHelp: add a description string.
- *
- * <p>Permissions: currently root permission is required for some commands. Others will enforce the
- * corresponding API permissions.
  */
-public class ThreadNetworkShellCommand extends BasicShellCommandHandler {
+public final class ThreadNetworkShellCommand extends BasicShellCommandHandler {
     private static final Duration SET_ENABLED_TIMEOUT = Duration.ofSeconds(2);
+    private static final Duration LEAVE_TIMEOUT = Duration.ofSeconds(2);
+    private static final Duration MIGRATE_TIMEOUT = Duration.ofSeconds(2);
     private static final Duration FORCE_STOP_TIMEOUT = Duration.ofSeconds(1);
+    private static final String PERMISSION_THREAD_NETWORK_TESTING =
+            "android.permission.THREAD_NETWORK_TESTING";
 
-    // These don't require root access.
-    private static final List<String> NON_PRIVILEGED_COMMANDS =
-            List.of("help", "get-country-code", "enable", "disable");
+    private final Context mContext;
+    private final ThreadNetworkControllerService mControllerService;
+    private final ThreadNetworkCountryCode mCountryCode;
 
-    @NonNull private final ThreadNetworkControllerService mControllerService;
-    @NonNull private final ThreadNetworkCountryCode mCountryCode;
     @Nullable private PrintWriter mOutputWriter;
     @Nullable private PrintWriter mErrorWriter;
 
-    ThreadNetworkShellCommand(
-            @NonNull ThreadNetworkControllerService controllerService,
-            @NonNull ThreadNetworkCountryCode countryCode) {
+    public ThreadNetworkShellCommand(
+            Context context,
+            ThreadNetworkControllerService controllerService,
+            ThreadNetworkCountryCode countryCode) {
+        mContext = context;
         mControllerService = controllerService;
         mCountryCode = countryCode;
     }
@@ -79,79 +86,120 @@
     }
 
     @Override
+    public void onHelp() {
+        final PrintWriter pw = getOutputWriter();
+        pw.println("Thread network commands:");
+        pw.println("  help or -h");
+        pw.println("    Print this help text.");
+        pw.println("  enable");
+        pw.println("    Enables Thread radio");
+        pw.println("  disable");
+        pw.println("    Disables Thread radio");
+        pw.println("  join <active-dataset-tlvs>");
+        pw.println("    Joins a network of the given dataset");
+        pw.println("  migrate <active-dataset-tlvs> <delay-seconds>");
+        pw.println("    Migrate to the given network by a specific delay");
+        pw.println("  leave");
+        pw.println("    Leave the current network and erase datasets");
+        pw.println("  force-stop-ot-daemon enabled | disabled ");
+        pw.println("    force stop ot-daemon service");
+        pw.println("  get-country-code");
+        pw.println("    Gets country code as a two-letter string");
+        pw.println("  force-country-code enabled <two-letter code> | disabled ");
+        pw.println("    Sets country code to <two-letter code> or left for normal value");
+    }
+
+    @Override
     public int onCommand(String cmd) {
-        // Treat no command as help command.
+        // Treat no command as the "help" command
         if (TextUtils.isEmpty(cmd)) {
             cmd = "help";
         }
 
-        final PrintWriter pw = getOutputWriter();
-        final PrintWriter perr = getErrorWriter();
-
-        // Explicit exclusion from root permission
-        if (!NON_PRIVILEGED_COMMANDS.contains(cmd)) {
-            final int uid = Binder.getCallingUid();
-
-            if (uid != Process.ROOT_UID) {
-                perr.println(
-                        "Uid "
-                                + uid
-                                + " does not have access to "
-                                + cmd
-                                + " thread command "
-                                + "(or such command doesn't exist)");
-                return -1;
-            }
-        }
-
         switch (cmd) {
             case "enable":
                 return setThreadEnabled(true);
             case "disable":
                 return setThreadEnabled(false);
+            case "join":
+                return join();
+            case "leave":
+                return leave();
+            case "migrate":
+                return migrate();
             case "force-stop-ot-daemon":
                 return forceStopOtDaemon();
             case "force-country-code":
-                boolean enabled;
-                try {
-                    enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
-                } catch (IllegalArgumentException e) {
-                    perr.println("Invalid argument: " + e.getMessage());
-                    return -1;
-                }
-
-                if (enabled) {
-                    String countryCode = getNextArgRequired();
-                    if (!ThreadNetworkCountryCode.isValidCountryCode(countryCode)) {
-                        perr.println(
-                                "Invalid argument: Country code must be a 2-Character"
-                                        + " string. But got country code "
-                                        + countryCode
-                                        + " instead");
-                        return -1;
-                    }
-                    mCountryCode.setOverrideCountryCode(countryCode);
-                    pw.println("Set Thread country code: " + countryCode);
-
-                } else {
-                    mCountryCode.clearOverrideCountryCode();
-                }
-                return 0;
+                return forceCountryCode();
             case "get-country-code":
-                pw.println("Thread country code = " + mCountryCode.getCountryCode());
-                return 0;
+                return getCountryCode();
             default:
                 return handleDefaultCommands(cmd);
         }
     }
 
+    private void ensureTestingPermission() {
+        mContext.enforceCallingOrSelfPermission(
+                PERMISSION_THREAD_NETWORK_TESTING,
+                "Permission " + PERMISSION_THREAD_NETWORK_TESTING + " is missing!");
+    }
+
     private int setThreadEnabled(boolean enabled) {
         CompletableFuture<Void> setEnabledFuture = new CompletableFuture<>();
         mControllerService.setEnabled(enabled, newOperationReceiver(setEnabledFuture));
-        return waitForFuture(setEnabledFuture, FORCE_STOP_TIMEOUT, getErrorWriter());
+        return waitForFuture(setEnabledFuture, SET_ENABLED_TIMEOUT, getErrorWriter());
+    }
+
+    private int join() {
+        byte[] datasetTlvs = HexDump.hexStringToByteArray(getNextArgRequired());
+        ActiveOperationalDataset dataset;
+        try {
+            dataset = ActiveOperationalDataset.fromThreadTlvs(datasetTlvs);
+        } catch (IllegalArgumentException e) {
+            getErrorWriter().println("Invalid dataset argument: " + e.getMessage());
+            return -1;
+        }
+        // Do not wait for join to complete because this can take 8 to 30 seconds
+        mControllerService.join(dataset, new IOperationReceiver.Default());
+        return 0;
+    }
+
+    private int leave() {
+        CompletableFuture<Void> leaveFuture = new CompletableFuture<>();
+        mControllerService.leave(newOperationReceiver(leaveFuture));
+        return waitForFuture(leaveFuture, LEAVE_TIMEOUT, getErrorWriter());
+    }
+
+    private int migrate() {
+        byte[] datasetTlvs = HexDump.hexStringToByteArray(getNextArgRequired());
+        ActiveOperationalDataset dataset;
+        try {
+            dataset = ActiveOperationalDataset.fromThreadTlvs(datasetTlvs);
+        } catch (IllegalArgumentException e) {
+            getErrorWriter().println("Invalid dataset argument: " + e.getMessage());
+            return -1;
+        }
+
+        int delaySeconds;
+        try {
+            delaySeconds = Integer.parseInt(getNextArgRequired());
+        } catch (NumberFormatException e) {
+            getErrorWriter().println("Invalid delay argument: " + e.getMessage());
+            return -1;
+        }
+
+        PendingOperationalDataset pendingDataset =
+                new PendingOperationalDataset(
+                        dataset,
+                        OperationalDatasetTimestamp.fromInstant(Instant.now()),
+                        Duration.ofSeconds(delaySeconds));
+        CompletableFuture<Void> migrateFuture = new CompletableFuture<>();
+        mControllerService.scheduleMigration(pendingDataset, newOperationReceiver(migrateFuture));
+        return waitForFuture(migrateFuture, MIGRATE_TIMEOUT, getErrorWriter());
     }
 
     private int forceStopOtDaemon() {
+        ensureTestingPermission();
         final PrintWriter errorWriter = getErrorWriter();
         boolean enabled;
         try {
@@ -166,6 +214,40 @@
         return waitForFuture(forceStopFuture, FORCE_STOP_TIMEOUT, getErrorWriter());
     }
 
+    private int forceCountryCode() {
+        ensureTestingPermission();
+        final PrintWriter perr = getErrorWriter();
+        boolean enabled;
+        try {
+            enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
+        } catch (IllegalArgumentException e) {
+            perr.println("Invalid argument: " + e.getMessage());
+            return -1;
+        }
+
+        if (enabled) {
+            String countryCode = getNextArgRequired();
+            if (!ThreadNetworkCountryCode.isValidCountryCode(countryCode)) {
+                perr.println(
+                        "Invalid argument: Country code must be a 2-letter"
+                                + " string. But got country code "
+                                + countryCode
+                                + " instead");
+                return -1;
+            }
+            mCountryCode.setOverrideCountryCode(countryCode);
+        } else {
+            mCountryCode.clearOverrideCountryCode();
+        }
+        return 0;
+    }
+
+    private int getCountryCode() {
+        ensureTestingPermission();
+        getOutputWriter().println("Thread country code = " + mCountryCode.getCountryCode());
+        return 0;
+    }
+
     private static IOperationReceiver newOperationReceiver(CompletableFuture<Void> future) {
         return new IOperationReceiver.Stub() {
             @Override
@@ -224,33 +306,4 @@
         String nextArg = getNextArgRequired();
         return argTrueOrFalse(nextArg, trueString, falseString);
     }
-
-    private void onHelpNonPrivileged(PrintWriter pw) {
-        pw.println("  enable");
-        pw.println("    Enables Thread radio");
-        pw.println("  disable");
-        pw.println("    Disables Thread radio");
-        pw.println("  get-country-code");
-        pw.println("    Gets country code as a two-letter string");
-    }
-
-    private void onHelpPrivileged(PrintWriter pw) {
-        pw.println("  force-country-code enabled <two-letter code> | disabled ");
-        pw.println("    Sets country code to <two-letter code> or left for normal value");
-        pw.println("  force-stop-ot-daemon enabled | disabled ");
-        pw.println("    force stop ot-daemon service");
-    }
-
-    @Override
-    public void onHelp() {
-        final PrintWriter pw = getOutputWriter();
-        pw.println("Thread network commands:");
-        pw.println("  help or -h");
-        pw.println("    Print this help text.");
-        onHelpNonPrivileged(pw);
-        if (Binder.getCallingUid() == Process.ROOT_UID) {
-            onHelpPrivileged(pw);
-        }
-        pw.println();
-    }
 }
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
index 9f2d0cb..dfb3129 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
@@ -16,22 +16,29 @@
 
 package com.android.server.thread;
 
-import static org.mockito.ArgumentMatchers.anyBoolean;
+import static com.google.common.io.BaseEncoding.base16;
+import static com.google.common.truth.Truth.assertThat;
+
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyString;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.contains;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.validateMockitoUsage;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import android.content.Context;
+import android.net.thread.ActiveOperationalDataset;
+import android.net.thread.PendingOperationalDataset;
 import android.os.Binder;
-import android.os.Process;
 
+import androidx.test.core.app.ApplicationProvider;
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
 
@@ -39,6 +46,7 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -49,19 +57,43 @@
 @RunWith(AndroidJUnit4.class)
 @SmallTest
 public class ThreadNetworkShellCommandTest {
-    private static final String TAG = "ThreadNetworkShellCommandTTest";
-    @Mock ThreadNetworkControllerService mControllerService;
-    @Mock ThreadNetworkCountryCode mCountryCode;
-    @Mock PrintWriter mErrorWriter;
-    @Mock PrintWriter mOutputWriter;
+    // A valid Thread Active Operational Dataset generated from OpenThread CLI "dataset new":
+    // Active Timestamp: 1
+    // Channel: 19
+    // Channel Mask: 0x07FFF800
+    // Ext PAN ID: ACC214689BC40BDF
+    // Mesh Local Prefix: fd64:db12:25f4:7e0b::/64
+    // Network Key: F26B3153760F519A63BAFDDFFC80D2AF
+    // Network Name: OpenThread-d9a0
+    // PAN ID: 0xD9A0
+    // PSKc: A245479C836D551B9CA557F7B9D351B4
+    // Security Policy: 672 onrcb
+    private static final String DEFAULT_ACTIVE_DATASET_TLVS =
+            "0E080000000000010000000300001335060004001FFFE002"
+                    + "08ACC214689BC40BDF0708FD64DB1225F47E0B0510F26B31"
+                    + "53760F519A63BAFDDFFC80D2AF030F4F70656E5468726561"
+                    + "642D643961300102D9A00410A245479C836D551B9CA557F7"
+                    + "B9D351B40C0402A0FFF8";
 
-    ThreadNetworkShellCommand mShellCommand;
+    @Mock private ThreadNetworkControllerService mControllerService;
+    @Mock private ThreadNetworkCountryCode mCountryCode;
+    @Mock private PrintWriter mErrorWriter;
+    @Mock private PrintWriter mOutputWriter;
+
+    private Context mContext;
+    private ThreadNetworkShellCommand mShellCommand;
 
     @Before
     public void setUp() throws Exception {
         MockitoAnnotations.initMocks(this);
 
-        mShellCommand = new ThreadNetworkShellCommand(mControllerService, mCountryCode);
+        mContext = spy(ApplicationProvider.getApplicationContext());
+        doNothing()
+                .when(mContext)
+                .enforceCallingOrSelfPermission(
+                        eq("android.permission.THREAD_NETWORK_TESTING"), anyString());
+
+        mShellCommand = new ThreadNetworkShellCommand(mContext, mControllerService, mCountryCode);
         mShellCommand.setPrintWriters(mOutputWriter, mErrorWriter);
     }
 
@@ -71,8 +103,23 @@
     }
 
     @Test
-    public void getCountryCode_executeInUnrootedShell_allowed() {
-        BinderUtil.setUid(Process.SHELL_UID);
+    public void getCountryCode_testingPermissionIsChecked() {
+        when(mCountryCode.getCountryCode()).thenReturn("US");
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"get-country-code"});
+
+        verify(mContext, times(1))
+                .enforceCallingOrSelfPermission(
+                        eq("android.permission.THREAD_NETWORK_TESTING"), anyString());
+    }
+
+    @Test
+    public void getCountryCode_currentCountryCodePrinted() {
         when(mCountryCode.getCountryCode()).thenReturn("US");
 
         mShellCommand.exec(
@@ -86,9 +133,7 @@
     }
 
     @Test
-    public void forceSetCountryCodeEnabled_executeInUnrootedShell_notAllowed() {
-        BinderUtil.setUid(Process.SHELL_UID);
-
+    public void forceSetCountryCodeEnabled_testingPermissionIsChecked() {
         mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
@@ -96,14 +141,13 @@
                 new FileDescriptor(),
                 new String[] {"force-country-code", "enabled", "US"});
 
-        verify(mCountryCode, never()).setOverrideCountryCode(eq("US"));
-        verify(mErrorWriter).println(contains("force-country-code"));
+        verify(mContext, times(1))
+                .enforceCallingOrSelfPermission(
+                        eq("android.permission.THREAD_NETWORK_TESTING"), anyString());
     }
 
     @Test
-    public void forceSetCountryCodeEnabled_executeInRootedShell_allowed() {
-        BinderUtil.setUid(Process.ROOT_UID);
-
+    public void forceSetCountryCodeEnabled_countryCodeIsOverridden() {
         mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
@@ -115,24 +159,7 @@
     }
 
     @Test
-    public void forceSetCountryCodeDisabled_executeInUnrootedShell_notAllowed() {
-        BinderUtil.setUid(Process.SHELL_UID);
-
-        mShellCommand.exec(
-                new Binder(),
-                new FileDescriptor(),
-                new FileDescriptor(),
-                new FileDescriptor(),
-                new String[] {"force-country-code", "disabled"});
-
-        verify(mCountryCode, never()).setOverrideCountryCode(any());
-        verify(mErrorWriter).println(contains("force-country-code"));
-    }
-
-    @Test
-    public void forceSetCountryCodeDisabled_executeInRootedShell_allowed() {
-        BinderUtil.setUid(Process.ROOT_UID);
-
+    public void forceSetCountryCodeDisabled_overriddenCountryCodeIsCleared() {
         mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
@@ -144,9 +171,7 @@
     }
 
     @Test
-    public void forceStopOtDaemon_executeInUnrootedShell_failedAndServiceApiNotCalled() {
-        BinderUtil.setUid(Process.SHELL_UID);
-
+    public void forceStopOtDaemon_testingPermissionIsChecked() {
         mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
@@ -154,14 +179,13 @@
                 new FileDescriptor(),
                 new String[] {"force-stop-ot-daemon", "enabled"});
 
-        verify(mControllerService, never()).forceStopOtDaemonForTest(anyBoolean(), any());
-        verify(mErrorWriter, atLeastOnce()).println(contains("force-stop-ot-daemon"));
-        verify(mOutputWriter, never()).println();
+        verify(mContext, times(1))
+                .enforceCallingOrSelfPermission(
+                        eq("android.permission.THREAD_NETWORK_TESTING"), anyString());
     }
 
     @Test
     public void forceStopOtDaemon_serviceThrows_failed() {
-        BinderUtil.setUid(Process.ROOT_UID);
         doThrow(new SecurityException(""))
                 .when(mControllerService)
                 .forceStopOtDaemonForTest(eq(true), any());
@@ -179,7 +203,6 @@
 
     @Test
     public void forceStopOtDaemon_serviceApiTimeout_failedWithTimeoutError() {
-        BinderUtil.setUid(Process.ROOT_UID);
         doNothing().when(mControllerService).forceStopOtDaemonForTest(eq(true), any());
 
         mShellCommand.exec(
@@ -193,4 +216,89 @@
         verify(mErrorWriter, atLeastOnce()).println(contains("timeout"));
         verify(mOutputWriter, never()).println();
     }
+
+    @Test
+    public void join_controllerServiceJoinIsCalled() {
+        doNothing().when(mControllerService).join(any(), any());
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"join", DEFAULT_ACTIVE_DATASET_TLVS});
+
+        var activeDataset =
+                ActiveOperationalDataset.fromThreadTlvs(
+                        base16().decode(DEFAULT_ACTIVE_DATASET_TLVS));
+        verify(mControllerService, times(1)).join(eq(activeDataset), any());
+        verify(mErrorWriter, never()).println();
+    }
+
+    @Test
+    public void join_invalidDataset_controllerServiceJoinIsNotCalled() {
+        doNothing().when(mControllerService).join(any(), any());
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"join", "000102"});
+
+        verify(mControllerService, never()).join(any(), any());
+        verify(mErrorWriter, times(1)).println(contains("Invalid dataset argument"));
+    }
+
+    @Test
+    public void migrate_controllerServiceMigrateIsCalled() {
+        doNothing().when(mControllerService).scheduleMigration(any(), any());
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"migrate", DEFAULT_ACTIVE_DATASET_TLVS, "300"});
+
+        ArgumentCaptor<PendingOperationalDataset> captor =
+                ArgumentCaptor.forClass(PendingOperationalDataset.class);
+        verify(mControllerService, times(1)).scheduleMigration(captor.capture(), any());
+        assertThat(captor.getValue().getActiveOperationalDataset())
+                .isEqualTo(
+                        ActiveOperationalDataset.fromThreadTlvs(
+                                base16().decode(DEFAULT_ACTIVE_DATASET_TLVS)));
+        assertThat(captor.getValue().getDelayTimer().toSeconds()).isEqualTo(300);
+        verify(mErrorWriter, never()).println();
+    }
+
+    @Test
+    public void migrate_invalidDataset_controllerServiceMigrateIsNotCalled() {
+        doNothing().when(mControllerService).scheduleMigration(any(), any());
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"migrate", "000102", "300"});
+
+        verify(mControllerService, never()).scheduleMigration(any(), any());
+        verify(mErrorWriter, times(1)).println(contains("Invalid dataset argument"));
+    }
+
+    @Test
+    public void leave_controllerServiceLeaveIsCalled() {
+        doNothing().when(mControllerService).leave(any());
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"leave"});
+
+        verify(mControllerService, times(1)).leave(any());
+        verify(mErrorWriter, never()).println();
+    }
 }