[Thread] do not restart ot-daemon when Thread is disabled

Current implementation will auto-restart ot-daemon even when Thread
has been disabled by ThreadNetworkController#setEnabled(false). This
commit fixes this issue and also adds the "enable" / "disable" shell
command to ease testing.

Bug: 328538612
Change-Id: Icc9360ccf5d7daa6345def75d1860f046600b85b
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 1235c30..815a36e9 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -301,7 +301,13 @@
                 .build();
     }
 
-    private void initializeOtDaemon() {
+    private void maybeInitializeOtDaemon() {
+        if (!isEnabled()) {
+            return;
+        }
+
+        Log.i(TAG, "Starting OT daemon...");
+
         try {
             getOtDaemon();
         } catch (RemoteException e) {
@@ -371,14 +377,14 @@
 
     private void onOtDaemonDied() {
         checkOnHandlerThread();
-        Log.w(TAG, "OT daemon is dead, clean up and restart it...");
+        Log.w(TAG, "OT daemon is dead, clean up...");
 
         OperationReceiverWrapper.onOtDaemonDied();
         mOtDaemonCallbackProxy.onOtDaemonDied();
         mTunIfController.onOtDaemonDied();
         mNsdPublisher.onOtDaemonDied();
         mOtDaemon = null;
-        initializeOtDaemon();
+        maybeInitializeOtDaemon();
     }
 
     public void initialize() {
@@ -396,7 +402,7 @@
                     requestThreadNetwork();
                     mUserRestricted = isThreadUserRestricted();
                     registerUserRestrictionsReceiver();
-                    initializeOtDaemon();
+                    maybeInitializeOtDaemon();
                 });
     }
 
@@ -955,6 +961,13 @@
             String countryCode, @NonNull OperationReceiverWrapper receiver) {
         checkOnHandlerThread();
 
+        // Fails early to avoid waking up ot-daemon by the ThreadNetworkCountryCode class
+        if (!isEnabled()) {
+            receiver.onError(
+                    ERROR_THREAD_DISABLED, "Can't set country code when Thread is disabled");
+            return;
+        }
+
         try {
             getOtDaemon().setCountryCode(countryCode, newOtStatusReceiver(receiver));
         } catch (RemoteException e) {
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index 5664922..37c1cf1 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -18,6 +18,8 @@
 
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 
+import static java.util.Objects.requireNonNull;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
@@ -66,7 +68,8 @@
             // PHASE_ACTIVITY_MANAGER_READY and PHASE_THIRD_PARTY_APPS_CAN_START
             mCountryCode = ThreadNetworkCountryCode.newInstance(mContext, mControllerService);
             mCountryCode.initialize();
-            mShellCommand = new ThreadNetworkShellCommand(mCountryCode);
+            mShellCommand =
+                    new ThreadNetworkShellCommand(requireNonNull(mControllerService), mCountryCode);
         }
     }
 
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
index c17c5a7..431232b 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
@@ -16,7 +16,10 @@
 
 package com.android.server.thread;
 
+import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.thread.IOperationReceiver;
+import android.net.thread.ThreadNetworkException;
 import android.os.Binder;
 import android.os.Process;
 import android.text.TextUtils;
@@ -25,7 +28,12 @@
 import com.android.modules.utils.BasicShellCommandHandler;
 
 import java.io.PrintWriter;
+import java.time.Duration;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 
 /**
  * Interprets and executes 'adb shell cmd thread_network [args]'.
@@ -37,16 +45,21 @@
  * corresponding API permissions.
  */
 public class ThreadNetworkShellCommand extends BasicShellCommandHandler {
-    private static final String TAG = "ThreadNetworkShellCommand";
+    private static final Duration SET_ENABLED_TIMEOUT = Duration.ofSeconds(2);
 
     // These don't require root access.
-    private static final List<String> NON_PRIVILEGED_COMMANDS = List.of("help", "get-country-code");
+    private static final List<String> NON_PRIVILEGED_COMMANDS =
+            List.of("help", "get-country-code", "enable", "disable");
 
-    @Nullable private final ThreadNetworkCountryCode mCountryCode;
+    @NonNull private final ThreadNetworkControllerService mControllerService;
+    @NonNull private final ThreadNetworkCountryCode mCountryCode;
     @Nullable private PrintWriter mOutputWriter;
     @Nullable private PrintWriter mErrorWriter;
 
-    ThreadNetworkShellCommand(@Nullable ThreadNetworkCountryCode countryCode) {
+    ThreadNetworkShellCommand(
+            @NonNull ThreadNetworkControllerService controllerService,
+            @NonNull ThreadNetworkCountryCode countryCode) {
+        mControllerService = controllerService;
         mCountryCode = countryCode;
     }
 
@@ -91,14 +104,12 @@
         }
 
         switch (cmd) {
+            case "enable":
+                return setThreadEnabled(true);
+            case "disable":
+                return setThreadEnabled(false);
             case "force-country-code":
                 boolean enabled;
-
-                if (mCountryCode == null) {
-                    perr.println("Thread country code operations are not supported");
-                    return -1;
-                }
-
                 try {
                     enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
                 } catch (IllegalArgumentException e) {
@@ -124,11 +135,6 @@
                 }
                 return 0;
             case "get-country-code":
-                if (mCountryCode == null) {
-                    perr.println("Thread country code operations are not supported");
-                    return -1;
-                }
-
                 pw.println("Thread country code = " + mCountryCode.getCountryCode());
                 return 0;
             default:
@@ -136,6 +142,40 @@
         }
     }
 
+    private int setThreadEnabled(boolean enabled) {
+        final PrintWriter perr = getErrorWriter();
+
+        CompletableFuture<Void> setEnabledFuture = new CompletableFuture<>();
+        mControllerService.setEnabled(
+                enabled,
+                new IOperationReceiver.Stub() {
+                    @Override
+                    public void onSuccess() {
+                        setEnabledFuture.complete(null);
+                    }
+
+                    @Override
+                    public void onError(int errorCode, String errorMessage) {
+                        setEnabledFuture.completeExceptionally(
+                                new ThreadNetworkException(errorCode, errorMessage));
+                    }
+                });
+
+        try {
+            setEnabledFuture.get(SET_ENABLED_TIMEOUT.toSeconds(), TimeUnit.SECONDS);
+            return 0;
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            perr.println("Failed: " + e.getMessage());
+        } catch (ExecutionException e) {
+            perr.println("Failed: " + e.getCause().getMessage());
+        } catch (TimeoutException e) {
+            perr.println("Failed: command timeout for " + SET_ENABLED_TIMEOUT);
+        }
+
+        return -1;
+    }
+
     private static boolean argTrueOrFalse(String arg, String trueString, String falseString) {
         if (trueString.equals(arg)) {
             return true;
@@ -159,6 +199,10 @@
     }
 
     private void onHelpNonPrivileged(PrintWriter pw) {
+        pw.println("  enable");
+        pw.println("    Enables Thread radio");
+        pw.println("  disable");
+        pw.println("    Disables Thread radio");
         pw.println("  get-country-code");
         pw.println("    Gets country code as a two-letter string");
     }
diff --git a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
index 353db10..5fe4325 100644
--- a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
+++ b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
@@ -164,15 +164,13 @@
          * </pre>
          */
 
-        // Let ftd join the network.
         FullThreadDevice ftd = mFtds.get(0);
         startFtdChild(ftd);
 
-        // Infra device sends an echo request to FTD's OMR.
         mInfraDevice.sendEchoRequest(ftd.getOmrAddress());
 
         // Infra device receives an echo reply sent by FTD.
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, null /* srcAddress */));
+        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
     }
 
     @Test
diff --git a/thread/tests/integration/src/android/net/thread/ThreadNetworkShellCommandTest.java b/thread/tests/integration/src/android/net/thread/ThreadNetworkShellCommandTest.java
new file mode 100644
index 0000000..d24fd47
--- /dev/null
+++ b/thread/tests/integration/src/android/net/thread/ThreadNetworkShellCommandTest.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread;
+
+import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
+import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
+
+import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Context;
+import android.net.thread.utils.ThreadFeatureCheckerRule;
+import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresThreadFeature;
+import android.net.thread.utils.ThreadNetworkControllerWrapper;
+
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.filters.LargeTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/** Integration tests for {@link ThreadNetworkShellCommand}. */
+@LargeTest
+@RequiresThreadFeature
+@RunWith(AndroidJUnit4.class)
+public class ThreadNetworkShellCommandTest {
+    @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
+
+    private final Context mContext = ApplicationProvider.getApplicationContext();
+    private final ThreadNetworkControllerWrapper mController =
+            ThreadNetworkControllerWrapper.newInstance(mContext);
+
+    @Test
+    public void enable_threadStateIsEnabled() throws Exception {
+        runThreadCommand("enable");
+
+        assertThat(mController.getEnabledState()).isEqualTo(STATE_ENABLED);
+    }
+
+    @Test
+    public void disable_threadStateIsDisabled() throws Exception {
+        runThreadCommand("disable");
+
+        assertThat(mController.getEnabledState()).isEqualTo(STATE_DISABLED);
+    }
+
+    @Test
+    public void forceCountryCode_setCN_getCountryCodeReturnsCN() {
+        runThreadCommand("force-country-code enabled CN");
+
+        final String result = runThreadCommand("get-country-code");
+        assertThat(result).contains("Thread country code = CN");
+    }
+
+    private static String runThreadCommand(String cmd) {
+        return runShellCommand("cmd thread_network " + cmd);
+    }
+}
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index f54edfe..830890d 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -312,13 +312,13 @@
     }
 
     @Test
-    public void userRestriction_initWithUserRestricted_threadIsDisabled() {
+    public void userRestriction_initWithUserRestricted_otDaemonNotStarted() {
         when(mMockUserManager.hasUserRestriction(eq(DISALLOW_THREAD_NETWORK))).thenReturn(true);
 
         mService.initialize();
         mTestLooper.dispatchAll();
 
-        assertThat(mFakeOtDaemon.getEnabledState()).isEqualTo(STATE_DISABLED);
+        assertThat(mFakeOtDaemon.isInitialized()).isFalse();
     }
 
     @Test
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
index c7e0eca..f469152 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
@@ -45,8 +45,8 @@
 @SmallTest
 public class ThreadNetworkShellCommandTest {
     private static final String TAG = "ThreadNetworkShellCommandTTest";
-    @Mock ThreadNetworkService mThreadNetworkService;
-    @Mock ThreadNetworkCountryCode mThreadNetworkCountryCode;
+    @Mock ThreadNetworkControllerService mControllerService;
+    @Mock ThreadNetworkCountryCode mCountryCode;
     @Mock PrintWriter mErrorWriter;
     @Mock PrintWriter mOutputWriter;
 
@@ -56,7 +56,8 @@
     public void setUp() throws Exception {
         MockitoAnnotations.initMocks(this);
 
-        mThreadNetworkShellCommand = new ThreadNetworkShellCommand(mThreadNetworkCountryCode);
+        mThreadNetworkShellCommand =
+                new ThreadNetworkShellCommand(mControllerService, mCountryCode);
         mThreadNetworkShellCommand.setPrintWriters(mOutputWriter, mErrorWriter);
     }
 
@@ -68,7 +69,7 @@
     @Test
     public void getCountryCode_executeInUnrootedShell_allowed() {
         BinderUtil.setUid(Process.SHELL_UID);
-        when(mThreadNetworkCountryCode.getCountryCode()).thenReturn("US");
+        when(mCountryCode.getCountryCode()).thenReturn("US");
 
         mThreadNetworkShellCommand.exec(
                 new Binder(),
@@ -91,7 +92,7 @@
                 new FileDescriptor(),
                 new String[] {"force-country-code", "enabled", "US"});
 
-        verify(mThreadNetworkCountryCode, never()).setOverrideCountryCode(eq("US"));
+        verify(mCountryCode, never()).setOverrideCountryCode(eq("US"));
         verify(mErrorWriter).println(contains("force-country-code"));
     }
 
@@ -106,7 +107,7 @@
                 new FileDescriptor(),
                 new String[] {"force-country-code", "enabled", "US"});
 
-        verify(mThreadNetworkCountryCode).setOverrideCountryCode(eq("US"));
+        verify(mCountryCode).setOverrideCountryCode(eq("US"));
     }
 
     @Test
@@ -120,7 +121,7 @@
                 new FileDescriptor(),
                 new String[] {"force-country-code", "disabled"});
 
-        verify(mThreadNetworkCountryCode, never()).setOverrideCountryCode(any());
+        verify(mCountryCode, never()).setOverrideCountryCode(any());
         verify(mErrorWriter).println(contains("force-country-code"));
     }
 
@@ -135,6 +136,6 @@
                 new FileDescriptor(),
                 new String[] {"force-country-code", "disabled"});
 
-        verify(mThreadNetworkCountryCode).clearOverrideCountryCode();
+        verify(mCountryCode).clearOverrideCountryCode();
     }
 }