Merge "Set Thread Network country code from location country code" into main
diff --git a/thread/service/Android.bp b/thread/service/Android.bp
index 35ae3c2..92cdedc 100644
--- a/thread/service/Android.bp
+++ b/thread/service/Android.bp
@@ -35,9 +35,11 @@
     libs: [
         "framework-connectivity-pre-jarjar",
         "framework-connectivity-t-pre-jarjar",
+        "framework-location.stubs.module_lib",
         "service-connectivity-pre-jarjar",
     ],
     static_libs: [
+        "modules-utils-shell-command-handler",
         "net-utils-device-common",
         "net-utils-device-common-netlink",
         "ot-daemon-aidl-java",
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 2cd1be3..5ae310c 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -54,6 +54,8 @@
 import android.Manifest.permission;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.RequiresPermission;
+import android.annotation.TargetApi;
 import android.content.Context;
 import android.net.ConnectivityManager;
 import android.net.IpPrefix;
@@ -82,6 +84,7 @@
 import android.net.thread.PendingOperationalDataset;
 import android.net.thread.ThreadNetworkController;
 import android.net.thread.ThreadNetworkController.DeviceRole;
+import android.os.Build;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.os.IBinder;
@@ -121,6 +124,7 @@
  * `mHandlerThread` 2. In the @Override methods, the actual work MUST be dispatched to the
  * HandlerThread except for arguments or permissions checking
  */
+@TargetApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
 final class ThreadNetworkControllerService extends IThreadNetworkController.Stub {
     private static final String TAG = "ThreadNetworkService";
 
@@ -742,6 +746,32 @@
         }
     }
 
+    /**
+     * Checks the privileged permission, then posts the country code update to the handler.
+     *
+     * @param countryCode 2-character country code string (as defined in ISO 3166) to set.
+     * @param receiver the receiver to receive the result of this operation
+     */
+    @RequiresPermission(PERMISSION_THREAD_NETWORK_PRIVILEGED)
+    public void setCountryCode(@NonNull String countryCode, @NonNull IOperationReceiver receiver) {
+        enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
+
+        OperationReceiverWrapper receiverWrapper = new OperationReceiverWrapper(receiver);
+        mHandler.post(() -> setCountryCodeInternal(countryCode, receiverWrapper));
+    }
+
+    private void setCountryCodeInternal(
+            String countryCode, @NonNull OperationReceiverWrapper receiver) {
+        checkOnHandlerThread();
+        // Forward to the ot-daemon; a RemoteException is surfaced as ERROR_INTERNAL_ERROR.
+        try {
+            getOtDaemon().setCountryCode(countryCode, newOtStatusReceiver(receiver));
+        } catch (RemoteException e) {
+            Log.e(TAG, "otDaemon.setCountryCode failed", e);
+            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        }
+    }
+
     private void enableBorderRouting(String infraIfName) {
         if (mBorderRouterConfig.isBorderRoutingEnabled
                 && infraIfName.equals(mBorderRouterConfig.infraInterfaceName)) {
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
new file mode 100644
index 0000000..b05183c
--- /dev/null
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
@@ -0,0 +1,304 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import android.annotation.Nullable;
+import android.annotation.StringDef;
+import android.annotation.TargetApi;
+import android.location.Address;
+import android.location.Geocoder;
+import android.location.Location;
+import android.location.LocationManager;
+import android.net.thread.IOperationReceiver;
+import android.os.Build;
+import android.util.Log;
+
+import com.android.internal.annotations.VisibleForTesting;
+
+import java.io.FileDescriptor;
+import java.io.PrintWriter;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.time.Instant;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+
+/**
+ * Provide functions for making changes to Thread Network country code. This Country Code is from
+ * location. This class sends Country Code to Thread Network native layer.
+ *
+ * <p>This class is thread-safe.
+ */
+@TargetApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+public class ThreadNetworkCountryCode {
+    private static final String TAG = "ThreadNetworkCountryCode";
+    // To be used when there is no country code available.
+    @VisibleForTesting public static final String DEFAULT_COUNTRY_CODE = "WW";
+
+    // Wait 1 hour between updates.
+    private static final long TIME_BETWEEN_LOCATION_UPDATES_MS = 1000L * 60 * 60 * 1;
+    // Minimum distance before an update is triggered, in meters. We don't need this to be too
+    // exact because all we care about is what country the user is in.
+    private static final float DISTANCE_BETWEEN_LOCALTION_UPDATES_METERS = 5_000.0f;
+
+    /** List of country code sources. */
+    @Retention(RetentionPolicy.SOURCE)
+    @StringDef(
+            prefix = "COUNTRY_CODE_SOURCE_",
+            value = {
+                COUNTRY_CODE_SOURCE_DEFAULT,
+                COUNTRY_CODE_SOURCE_LOCATION,
+                COUNTRY_CODE_SOURCE_OVERRIDE,
+            })
+    private @interface CountryCodeSource {}
+
+    private static final String COUNTRY_CODE_SOURCE_DEFAULT = "Default";
+    private static final String COUNTRY_CODE_SOURCE_LOCATION = "Location";
+    private static final String COUNTRY_CODE_SOURCE_OVERRIDE = "Override";
+    private static final CountryCodeInfo DEFAULT_COUNTRY_CODE_INFO =
+            new CountryCodeInfo(DEFAULT_COUNTRY_CODE, COUNTRY_CODE_SOURCE_DEFAULT);
+
+    private final LocationManager mLocationManager;
+    @Nullable private final Geocoder mGeocoder;
+    private final ThreadNetworkControllerService mThreadNetworkControllerService;
+
+    @Nullable private CountryCodeInfo mCurrentCountryCodeInfo;
+    @Nullable private CountryCodeInfo mLocationCountryCodeInfo;
+    @Nullable private CountryCodeInfo mOverrideCountryCodeInfo;
+
+    /** Immutable holder for a country code (stored upper-cased), its source and update time. */
+    private static final class CountryCodeInfo {
+        private final String mCountryCode;
+        @CountryCodeSource private final String mSource;
+        private final Instant mUpdatedTimestamp;
+
+        public CountryCodeInfo(
+                String countryCode, @CountryCodeSource String countryCodeSource, Instant instant) {
+            // Same normalization updateCountryCode() applies, so "us" and "US" compare equal.
+            mCountryCode = countryCode.toUpperCase(Locale.ROOT);
+            mSource = countryCodeSource;
+            mUpdatedTimestamp = instant;
+        }
+
+        /** Creates an entry stamped with the current time. */
+        public CountryCodeInfo(String countryCode, @CountryCodeSource String countryCodeSource) {
+            this(countryCode, countryCodeSource, Instant.now());
+        }
+
+        public String getCountryCode() {
+            return mCountryCode;
+        }
+
+        /** Returns {@code true} if {@code countryCodeInfo} is non-null and has the same code. */
+        public boolean isCountryCodeMatch(CountryCodeInfo countryCodeInfo) {
+            return countryCodeInfo != null
+                    && Objects.equals(countryCodeInfo.mCountryCode, mCountryCode);
+        }
+
+        @Override
+        public String toString() {
+            return "CountryCodeInfo{ mCountryCode: "
+                    + mCountryCode
+                    + ", mSource: "
+                    + mSource
+                    + ", mUpdatedTimestamp: "
+                    + mUpdatedTimestamp
+                    + "}";
+        }
+    }
+
+    private boolean isLocationUseForCountryCodeEnabled() {
+        // TODO: b/311324956 read the configuration from the overlay configuration.
+        return true;
+    }
+
+    public ThreadNetworkCountryCode(
+            LocationManager locationManager,
+            ThreadNetworkControllerService threadNetworkControllerService,
+            @Nullable Geocoder geocoder) {
+        mLocationManager = locationManager;
+        mThreadNetworkControllerService = threadNetworkControllerService;
+        mGeocoder = geocoder;
+    }
+
+    /** Sets up this country code module to listen to location country code changes. */
+    public synchronized void initialize() {
+        registerGeocoderCountryCodeCallback();
+        updateCountryCode(false /* forceUpdate */);
+    }
+
+    private synchronized void registerGeocoderCountryCodeCallback() {
+        if ((mGeocoder != null) && isLocationUseForCountryCodeEnabled()) {
+            mLocationManager.requestLocationUpdates(
+                    LocationManager.PASSIVE_PROVIDER,
+                    TIME_BETWEEN_LOCATION_UPDATES_MS,
+                    DISTANCE_BETWEEN_LOCALTION_UPDATES_METERS,
+                    location -> setCountryCodeFromGeocodingLocation(location));
+        }
+    }
+
+    private synchronized void geocodeListener(List<Address> addresses) {
+        if (addresses != null && !addresses.isEmpty()) {
+            String countryCode = addresses.get(0).getCountryCode();
+
+            if (isValidCountryCode(countryCode)) {
+                Log.d(TAG, "Set location country code to: " + countryCode);
+                mLocationCountryCodeInfo =
+                        new CountryCodeInfo(countryCode, COUNTRY_CODE_SOURCE_LOCATION);
+            } else {
+                Log.d(TAG, "Received invalid location country code");
+                mLocationCountryCodeInfo = null;
+            }
+
+            updateCountryCode(false /* forceUpdate */);
+        }
+    }
+
+    private synchronized void setCountryCodeFromGeocodingLocation(@Nullable Location location) {
+        if ((location == null) || (mGeocoder == null)) return;
+
+        if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.TIRAMISU) {
+            Log.wtf(
+                    TAG,
+                    "Unexpected call to set country code from the Geocoding location, "
+                            + "Thread code never runs under T or lower.");
+            return;
+        }
+
+        mGeocoder.getFromLocation(
+                location.getLatitude(),
+                location.getLongitude(),
+                1 /* maxResults */,
+                this::geocodeListener);
+    }
+
+    /**
+     * Priority order of country code sources (we stop at the first known country code source):
+     *
+     * <ul>
+     *   <li>1. Override country code - forced via shell command (local/automated testing).
+     *   <li>2. Location Country code - Country code retrieved from LocationManager passive location
+     *       provider.
+     *   <li>3. Default country code - {@code DEFAULT_COUNTRY_CODE_INFO}, used when neither is set.
+     * </ul>
+     *
+     * @return the selected country code information.
+     */
+    private CountryCodeInfo pickCountryCode() {
+        if (mOverrideCountryCodeInfo != null) {
+            return mOverrideCountryCodeInfo;
+        }
+
+        if (mLocationCountryCodeInfo != null) {
+            return mLocationCountryCodeInfo;
+        }
+
+        return DEFAULT_COUNTRY_CODE_INFO;
+    }
+
+    private IOperationReceiver newOperationReceiver(CountryCodeInfo countryCodeInfo) {
+        return new IOperationReceiver.Stub() {
+            @Override
+            public void onSuccess() {
+                synchronized ("ThreadNetworkCountryCode.this") {
+                    mCurrentCountryCodeInfo = countryCodeInfo;
+                }
+            }
+
+            @Override
+            public void onError(int otError, String message) {
+                Log.e(
+                        TAG,
+                        "Error "
+                                + otError
+                                + ": "
+                                + message
+                                + ". Failed to set country code "
+                                + countryCodeInfo);
+            }
+        };
+    }
+
+    /**
+     * Updates country code to the Thread native layer.
+     *
+     * @param forceUpdate Force update the country code even if it was the same as previously cached
+     *     value.
+     */
+    @VisibleForTesting
+    public synchronized void updateCountryCode(boolean forceUpdate) {
+        CountryCodeInfo countryCodeInfo = pickCountryCode();
+
+        if (!forceUpdate && countryCodeInfo.isCountryCodeMatch(mCurrentCountryCodeInfo)) {
+            Log.i(TAG, "Ignoring already set country code " + countryCodeInfo.getCountryCode());
+            return;
+        }
+
+        Log.i(TAG, "Set country code: " + countryCodeInfo);
+        mThreadNetworkControllerService.setCountryCode(
+                countryCodeInfo.getCountryCode().toUpperCase(Locale.ROOT),
+                newOperationReceiver(countryCodeInfo));
+    }
+
+    /** Returns the current country code or {@code null} if no country code is set. */
+    @Nullable
+    public synchronized String getCountryCode() {
+        return (mCurrentCountryCodeInfo != null) ? mCurrentCountryCodeInfo.getCountryCode() : null;
+    }
+
+    /**
+     * Returns {@code true} if {@code countryCode} is a valid country code.
+     *
+     * <p>A country code is valid if it consists of exactly 2 ASCII letters, in either case.
+     */
+    public static boolean isValidCountryCode(String countryCode) {
+        // Character.isLetter() would also accept non-ASCII letters, which no ISO 3166 code uses.
+        return countryCode != null
+                && countryCode.matches("[A-Za-z]{2}");
+    }
+
+    /**
+     * Overrides any existing country code.
+     *
+     * @param countryCode A 2-Character alphabetical country code (as defined in ISO 3166).
+     * @throws IllegalArgumentException if {@code countryCode} is an invalid country code.
+     */
+    public synchronized void setOverrideCountryCode(String countryCode) {
+        if (!isValidCountryCode(countryCode)) {
+            throw new IllegalArgumentException("The override country code is invalid");
+        }
+
+        mOverrideCountryCodeInfo = new CountryCodeInfo(countryCode, COUNTRY_CODE_SOURCE_OVERRIDE);
+        updateCountryCode(true /* forceUpdate */);
+    }
+
+    /** Clears the country code previously set through {@link #setOverrideCountryCode} method. */
+    public synchronized void clearOverrideCountryCode() {
+        mOverrideCountryCodeInfo = null;
+        updateCountryCode(true /* forceUpdate */);
+    }
+
+    /** Dumps the current state of this ThreadNetworkCountryCode object. */
+    public synchronized void dump(FileDescriptor fd, PrintWriter pw, String[] args) {
+        pw.println("---- Dump of ThreadNetworkCountryCode begin ----");
+        pw.println("mOverrideCountryCodeInfo: " + mOverrideCountryCodeInfo);
+        pw.println("mLocationCountryCodeInfo: " + mLocationCountryCodeInfo);
+        pw.println("mCurrentCountryCodeInfo: " + mCurrentCountryCodeInfo);
+        pw.println("---- Dump of ThreadNetworkCountryCode end ------");
+    }
+}
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index cc694a1..287bb8a 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -16,13 +16,22 @@
 
 package com.android.server.thread;
 
+import static android.content.pm.PackageManager.PERMISSION_GRANTED;
+
+import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
+import android.location.Geocoder;
+import android.location.LocationManager;
 import android.net.thread.IThreadNetworkController;
 import android.net.thread.IThreadNetworkManager;
+import android.os.Binder;
+import android.os.ParcelFileDescriptor;
 
 import com.android.server.SystemService;
 
+import java.io.FileDescriptor;
+import java.io.PrintWriter;
 import java.util.Collections;
 import java.util.List;
 
@@ -31,7 +40,9 @@
  */
 public class ThreadNetworkService extends IThreadNetworkManager.Stub {
     private final Context mContext;
+    @Nullable private ThreadNetworkCountryCode mCountryCode;
     @Nullable private ThreadNetworkControllerService mControllerService;
+    @Nullable private ThreadNetworkShellCommand mShellCommand;
 
     /** Creates a new {@link ThreadNetworkService} object. */
     public ThreadNetworkService(Context context) {
@@ -47,6 +58,13 @@
         if (phase == SystemService.PHASE_BOOT_COMPLETED) {
             mControllerService = ThreadNetworkControllerService.newInstance(mContext);
             mControllerService.initialize();
+            mCountryCode =
+                    new ThreadNetworkCountryCode(
+                            mContext.getSystemService(LocationManager.class),
+                            mControllerService,
+                            Geocoder.isPresent() ? new Geocoder(mContext) : null);
+            mCountryCode.initialize();
+            mShellCommand = new ThreadNetworkShellCommand(mCountryCode);
         }
     }
 
@@ -57,4 +75,40 @@
         }
         return Collections.singletonList(mControllerService);
     }
+
+    @Override
+    public int handleShellCommand(
+            @NonNull ParcelFileDescriptor in,
+            @NonNull ParcelFileDescriptor out,
+            @NonNull ParcelFileDescriptor err,
+            @NonNull String[] args) {
+        // Without a shell command handler there is nothing to run; report failure.
+        final ThreadNetworkShellCommand shell = mShellCommand;
+        if (shell == null) {
+            return -1;
+        }
+        final FileDescriptor inFd = in.getFileDescriptor();
+        final FileDescriptor outFd = out.getFileDescriptor();
+        final FileDescriptor errFd = err.getFileDescriptor();
+        return shell.exec(this, inFd, outFd, errFd, args);
+    }
+
+    /** Dumps the country code state; callers without the DUMP permission get a denial message. */
+    @Override
+    protected void dump(FileDescriptor fd, PrintWriter pw, String[] args) {
+        if (mContext.checkCallingOrSelfPermission(android.Manifest.permission.DUMP)
+                != PERMISSION_GRANTED) {
+            pw.println(
+                    "Permission Denial: can't dump ThreadNetworkService from pid="
+                            + Binder.getCallingPid()
+                            + ", uid="
+                            + Binder.getCallingUid());
+            return;
+        }
+        // mCountryCode is created during PHASE_BOOT_COMPLETED, so it may not exist yet.
+        if (mCountryCode != null) {
+            mCountryCode.dump(fd, pw, args);
+        }
+        pw.println();
+    }
 }
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
new file mode 100644
index 0000000..c17c5a7
--- /dev/null
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
@@ -0,0 +1,183 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import android.annotation.Nullable;
+import android.os.Binder;
+import android.os.Process;
+import android.text.TextUtils;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.modules.utils.BasicShellCommandHandler;
+
+import java.io.PrintWriter;
+import java.util.List;
+
+/**
+ * Interprets and executes 'adb shell cmd thread_network [args]'.
+ *
+ * <p>To add a new command: in onCommand, add a {@code case "<command>"} that runs it and returns
+ * 0 on success; in onHelp, add a description string.
+ *
+ * <p>Permissions: currently root permission is required for some commands. Others will enforce the
+ * corresponding API permissions.
+ */
+public class ThreadNetworkShellCommand extends BasicShellCommandHandler {
+    private static final String TAG = "ThreadNetworkShellCommand";
+
+    // These don't require root access.
+    private static final List<String> NON_PRIVILEGED_COMMANDS = List.of("help", "get-country-code");
+
+    @Nullable private final ThreadNetworkCountryCode mCountryCode;
+    @Nullable private PrintWriter mOutputWriter;
+    @Nullable private PrintWriter mErrorWriter;
+
+    ThreadNetworkShellCommand(@Nullable ThreadNetworkCountryCode countryCode) {
+        mCountryCode = countryCode;
+    }
+
+    @VisibleForTesting
+    public void setPrintWriters(PrintWriter outputWriter, PrintWriter errorWriter) {
+        mOutputWriter = outputWriter;
+        mErrorWriter = errorWriter;
+    }
+
+    private PrintWriter getOutputWriter() {
+        return (mOutputWriter != null) ? mOutputWriter : getOutPrintWriter();
+    }
+
+    private PrintWriter getErrorWriter() {
+        return (mErrorWriter != null) ? mErrorWriter : getErrPrintWriter();
+    }
+
+    @Override
+    public int onCommand(String cmd) {
+        // Treat no command as help command.
+        if (TextUtils.isEmpty(cmd)) {
+            cmd = "help";
+        }
+
+        final PrintWriter pw = getOutputWriter();
+        final PrintWriter perr = getErrorWriter();
+
+        // Explicit exclusion from root permission
+        if (!NON_PRIVILEGED_COMMANDS.contains(cmd)) {
+            final int uid = Binder.getCallingUid();
+
+            if (uid != Process.ROOT_UID) {
+                perr.println(
+                        "Uid "
+                                + uid
+                                + " does not have access to "
+                                + cmd
+                                + " thread command "
+                                + "(or such command doesn't exist)");
+                return -1;
+            }
+        }
+
+        switch (cmd) {
+            case "force-country-code":
+                boolean enabled;
+
+                if (mCountryCode == null) {
+                    perr.println("Thread country code operations are not supported");
+                    return -1;
+                }
+
+                try {
+                    enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
+                } catch (IllegalArgumentException e) {
+                    perr.println("Invalid argument: " + e.getMessage());
+                    return -1;
+                }
+
+                if (enabled) {
+                    String countryCode = getNextArgRequired();
+                    if (!ThreadNetworkCountryCode.isValidCountryCode(countryCode)) {
+                        perr.println(
+                                "Invalid argument: Country code must be a 2-Character"
+                                        + " string. But got country code "
+                                        + countryCode
+                                        + " instead");
+                        return -1;
+                    }
+                    mCountryCode.setOverrideCountryCode(countryCode);
+                    pw.println("Set Thread country code: " + countryCode);
+
+                } else {
+                    mCountryCode.clearOverrideCountryCode();
+                }
+                return 0;
+            case "get-country-code":
+                if (mCountryCode == null) {
+                    perr.println("Thread country code operations are not supported");
+                    return -1;
+                }
+
+                pw.println("Thread country code = " + mCountryCode.getCountryCode());
+                return 0;
+            default:
+                return handleDefaultCommands(cmd);
+        }
+    }
+
+    private static boolean argTrueOrFalse(String arg, String trueString, String falseString) {
+        if (trueString.equals(arg)) {
+            return true;
+        } else if (falseString.equals(arg)) {
+            return false;
+        } else {
+            throw new IllegalArgumentException(
+                    "Expected '"
+                            + trueString
+                            + "' or '"
+                            + falseString
+                            + "' as next arg but got '"
+                            + arg
+                            + "'");
+        }
+    }
+
+    private boolean getNextArgRequiredTrueOrFalse(String trueString, String falseString) {
+        String nextArg = getNextArgRequired();
+        return argTrueOrFalse(nextArg, trueString, falseString);
+    }
+
+    private void onHelpNonPrivileged(PrintWriter pw) {
+        pw.println("  get-country-code");
+        pw.println("    Gets country code as a two-letter string");
+    }
+
+    private void onHelpPrivileged(PrintWriter pw) {
+        pw.println("  force-country-code enabled <two-letter code> | disabled ");
+        pw.println("    Sets country code to <two-letter code> or left for normal value");
+    }
+
+    @Override
+    public void onHelp() {
+        final PrintWriter pw = getOutputWriter();
+        pw.println("Thread network commands:");
+        pw.println("  help or -h");
+        pw.println("    Print this help text.");
+        onHelpNonPrivileged(pw);
+        if (Binder.getCallingUid() == Process.ROOT_UID) {
+            onHelpPrivileged(pw);
+        }
+        pw.println();
+    }
+}
diff --git a/thread/tests/unit/Android.bp b/thread/tests/unit/Android.bp
index 8092693..74b4a35 100644
--- a/thread/tests/unit/Android.bp
+++ b/thread/tests/unit/Android.bp
@@ -31,14 +31,15 @@
         "general-tests",
     ],
     static_libs: [
-        "androidx.test.ext.junit",
-        "compatibility-device-util-axt",
+        "frameworks-base-testutils",
         "framework-connectivity-pre-jarjar",
         "framework-connectivity-t-pre-jarjar",
+        "framework-location.stubs.module_lib",
         "guava",
         "guava-android-testlib",
-        "mockito-target-minus-junit4",
+        "mockito-target-extended-minus-junit4",
         "net-tests-utils",
+        "service-thread-pre-jarjar",
         "truth",
     ],
     libs: [
@@ -46,6 +47,11 @@
         "android.test.runner",
     ],
     jarjar_rules: ":connectivity-jarjar-rules",
+    jni_libs: [
+        // these are needed for Extended Mockito
+        "libdexmakerjvmtiagent",
+        "libstaticjvmtiagent",
+    ],
     // Test coverage system runs on different devices. Need to
     // compile for all architectures.
     compile_multilib: "both",
diff --git a/thread/tests/unit/src/android/net/thread/ActiveOperationalDatasetTest.java b/thread/tests/unit/src/android/net/thread/ActiveOperationalDatasetTest.java
index 7284968..e92dcb9 100644
--- a/thread/tests/unit/src/android/net/thread/ActiveOperationalDatasetTest.java
+++ b/thread/tests/unit/src/android/net/thread/ActiveOperationalDatasetTest.java
@@ -33,12 +33,8 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
-import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
-import java.security.SecureRandom;
-import java.util.Random;
-
 /** Unit tests for {@link ActiveOperationalDataset}. */
 @SmallTest
 @RunWith(AndroidJUnit4.class)
@@ -62,9 +58,6 @@
                                     + "642D643961300102D9A00410A245479C836D551B9CA557F7"
                                     + "B9D351B40C0402A0FFF8");
 
-    @Mock private Random mockRandom;
-    @Mock private SecureRandom mockSecureRandom;
-
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
diff --git a/thread/tests/unit/src/com/android/server/thread/BinderUtil.java b/thread/tests/unit/src/com/android/server/thread/BinderUtil.java
new file mode 100644
index 0000000..3614bce
--- /dev/null
+++ b/thread/tests/unit/src/com/android/server/thread/BinderUtil.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import android.os.Binder;
+
+/** Utilities for faking the calling uid in Binder. */
+public class BinderUtil {
+    /**
+     * Fake the calling uid in Binder.
+     * <p>Restores an identity token: {@code uid} in the upper 32 bits, current calling pid below.
+     * @param uid the calling uid that Binder should return from now on
+     */
+    public static void setUid(int uid) {
+        Binder.restoreCallingIdentity((((long) uid) << 32) | Binder.getCallingPid());
+    }
+}
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
new file mode 100644
index 0000000..f51aa0a
--- /dev/null
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
@@ -0,0 +1,191 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
+
+import static com.android.server.thread.ThreadNetworkCountryCode.DEFAULT_COUNTRY_CODE;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyDouble;
+import static org.mockito.Mockito.anyFloat;
+import static org.mockito.Mockito.anyInt;
+import static org.mockito.Mockito.anyLong;
+import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.clearInvocations;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import android.content.pm.PackageManager;
+import android.location.Address;
+import android.location.Geocoder;
+import android.location.Location;
+import android.location.LocationListener;
+import android.location.LocationManager;
+import android.net.thread.IOperationReceiver;
+
+import androidx.test.filters.SmallTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.stubbing.Answer;
+
+import java.util.List;
+import java.util.Locale;
+
+/** Unit tests for {@link ThreadNetworkCountryCode}. */
+@RunWith(AndroidJUnit4.class)
+@SmallTest
+public class ThreadNetworkCountryCodeTest {
+    private static final String TEST_COUNTRY_CODE_US = "US";
+    private static final String TEST_COUNTRY_CODE_CN = "CN";
+
+    @Mock LocationManager mLocationManager;
+    @Mock Geocoder mGeocoder;
+    @Mock ThreadNetworkControllerService mThreadNetworkControllerService;
+    @Mock PackageManager mPackageManager;
+    @Mock Location mLocation;
+
+    private ThreadNetworkCountryCode mThreadNetworkCountryCode;
+    // When true, the stubbed setCountryCode() reports ERROR_INTERNAL_ERROR instead of success.
+    private boolean mErrorSetCountryCode;
+
+    @Captor private ArgumentCaptor<LocationListener> mLocationListenerCaptor;
+    @Captor private ArgumentCaptor<Geocoder.GeocodeListener> mGeocodeListenerCaptor;
+    @Captor private ArgumentCaptor<IOperationReceiver> mOperationReceiverCaptor;
+
+    @Before
+    public void setUp() throws Exception {
+        MockitoAnnotations.initMocks(this);
+
+        when(mLocation.getLatitude()).thenReturn(0.0);
+        when(mLocation.getLongitude()).thenReturn(0.0);
+
+        // Stubbed setCountryCode(): calls onError if mErrorSetCountryCode is set, else onSuccess.
+        Answer<Void> setCountryCodeCallback =
+                invocation -> {
+                    IOperationReceiver cb = invocation.getArgument(1);
+                    if (mErrorSetCountryCode) {
+                        cb.onError(ERROR_INTERNAL_ERROR, "Invalid country code");
+                    } else {
+                        cb.onSuccess();
+                    }
+                    return null;
+                };
+
+        doAnswer(setCountryCodeCallback)
+                .when(mThreadNetworkControllerService)
+                .setCountryCode(any(), any(IOperationReceiver.class));
+
+        mThreadNetworkCountryCode =
+                new ThreadNetworkCountryCode(
+                        mLocationManager, mThreadNetworkControllerService, mGeocoder);
+    }
+
+    private static Address newAddress(String countryCode) {
+        Address address = new Address(Locale.ROOT);
+        address.setCountryCode(countryCode);
+        return address;
+    }
+
+    @Test
+    public void initialize_defaultCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(DEFAULT_COUNTRY_CODE);
+    }
+
+    @Test
+    public void locationCountryCode_locationChanged_locationCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+        // Feed a location to the captured listener, then answer the resulting geocode request.
+        verify(mLocationManager)
+                .requestLocationUpdates(
+                        anyString(), anyLong(), anyFloat(), mLocationListenerCaptor.capture());
+        mLocationListenerCaptor.getValue().onLocationChanged(mLocation);
+        verify(mGeocoder)
+                .getFromLocation(
+                        anyDouble(), anyDouble(), anyInt(), mGeocodeListenerCaptor.capture());
+        mGeocodeListenerCaptor.getValue().onGeocode(List.of(newAddress(TEST_COUNTRY_CODE_US)));
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(TEST_COUNTRY_CODE_US);
+    }
+
+    @Test
+    public void updateCountryCode_noForceUpdateDefaultCountryCode_noCountryCodeIsUpdated() {
+        mThreadNetworkCountryCode.initialize();
+        clearInvocations(mThreadNetworkControllerService);
+
+        mThreadNetworkCountryCode.updateCountryCode(false /* forceUpdate */);
+
+        verify(mThreadNetworkControllerService, never()).setCountryCode(any(), any());
+    }
+
+    @Test
+    public void updateCountryCode_forceUpdateDefaultCountryCode_countryCodeIsUpdated() {
+        mThreadNetworkCountryCode.initialize();
+        clearInvocations(mThreadNetworkControllerService);
+
+        mThreadNetworkCountryCode.updateCountryCode(true /* forceUpdate */);
+
+        verify(mThreadNetworkControllerService)
+                .setCountryCode(eq(DEFAULT_COUNTRY_CODE), mOperationReceiverCaptor.capture());
+    }
+
+    @Test
+    public void setOverrideCountryCode_defaultCountryCodeAvailable_overrideCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+
+        mThreadNetworkCountryCode.setOverrideCountryCode(TEST_COUNTRY_CODE_CN);
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(TEST_COUNTRY_CODE_CN);
+    }
+
+    @Test
+    public void clearOverrideCountryCode_defaultCountryCodeAvailable_defaultCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+        mThreadNetworkCountryCode.setOverrideCountryCode(TEST_COUNTRY_CODE_CN);
+
+        mThreadNetworkCountryCode.clearOverrideCountryCode();
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(DEFAULT_COUNTRY_CODE);
+    }
+
+    @Test
+    public void setCountryCodeFailed_defaultCountryCodeAvailable_countryCodeIsNotUpdated() {
+        mThreadNetworkCountryCode.initialize();
+
+        mErrorSetCountryCode = true;
+        mThreadNetworkCountryCode.setOverrideCountryCode(TEST_COUNTRY_CODE_CN);
+
+        verify(mThreadNetworkControllerService)
+                .setCountryCode(eq(TEST_COUNTRY_CODE_CN), mOperationReceiverCaptor.capture());
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(DEFAULT_COUNTRY_CODE);
+    }
+}
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
new file mode 100644
index 0000000..c7e0eca
--- /dev/null
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
@@ -0,0 +1,140 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.contains;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.validateMockitoUsage;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import android.os.Binder;
+import android.os.Process;
+
+import androidx.test.filters.SmallTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.FileDescriptor;
+import java.io.PrintWriter;
+
+/** Unit tests for {@link ThreadNetworkShellCommand}. */
+@RunWith(AndroidJUnit4.class)
+@SmallTest
+public class ThreadNetworkShellCommandTest {
+    @Mock ThreadNetworkService mThreadNetworkService;
+    @Mock ThreadNetworkCountryCode mThreadNetworkCountryCode;
+    @Mock PrintWriter mErrorWriter;
+    @Mock PrintWriter mOutputWriter;
+
+    // Object under test; rebuilt in setUp() before every test.
+    ThreadNetworkShellCommand mThreadNetworkShellCommand;
+
+    @Before
+    public void setUp() throws Exception {
+        MockitoAnnotations.initMocks(this);
+
+        mThreadNetworkShellCommand = new ThreadNetworkShellCommand(mThreadNetworkCountryCode);
+        // Send the command's output and error text to mocks so the tests can verify it.
+        mThreadNetworkShellCommand.setPrintWriters(mOutputWriter, mErrorWriter);
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        // Report Mockito framework misuse in the test that caused it.
+        validateMockitoUsage();
+    }
+
+    /**
+     * Runs the shell command under test with the given arguments.
+     *
+     * <p>Every call passes a fresh {@link Binder} and three fresh {@link FileDescriptor}s.
+     *
+     * @param args the command name followed by its arguments
+     */
+    private void execShellCommand(String... args) {
+        mThreadNetworkShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                args);
+    }
+
+    /** An unrooted shell (SHELL_UID) may run "get-country-code" and gets the code printed. */
+    @Test
+    public void getCountryCode_executeInUnrootedShell_allowed() {
+        BinderUtil.setUid(Process.SHELL_UID);
+        when(mThreadNetworkCountryCode.getCountryCode()).thenReturn("US");
+
+        execShellCommand("get-country-code");
+
+        verify(mOutputWriter).println(contains("US"));
+    }
+
+    /** An unrooted shell may not force a country code; an error naming the command is printed. */
+    @Test
+    public void forceSetCountryCodeEnabled_executeInUnrootedShell_notAllowed() {
+        BinderUtil.setUid(Process.SHELL_UID);
+
+        execShellCommand("force-country-code", "enabled", "US");
+
+        verify(mThreadNetworkCountryCode, never()).setOverrideCountryCode(eq("US"));
+        verify(mErrorWriter).println(contains("force-country-code"));
+    }
+
+    /** A rooted shell (ROOT_UID) may force a country code, which sets the override. */
+    @Test
+    public void forceSetCountryCodeEnabled_executeInRootedShell_allowed() {
+        BinderUtil.setUid(Process.ROOT_UID);
+
+        execShellCommand("force-country-code", "enabled", "US");
+
+        verify(mThreadNetworkCountryCode).setOverrideCountryCode(eq("US"));
+    }
+
+    /** An unrooted shell may not disable the forced country code; the override is not cleared. */
+    @Test
+    public void forceSetCountryCodeDisabled_executeInUnrootedShell_notAllowed() {
+        BinderUtil.setUid(Process.SHELL_UID);
+
+        execShellCommand("force-country-code", "disabled");
+
+        // "disabled" is handled by clearOverrideCountryCode(), which must not run here.
+        verify(mThreadNetworkCountryCode, never()).clearOverrideCountryCode();
+        verify(mThreadNetworkCountryCode, never()).setOverrideCountryCode(any());
+        verify(mErrorWriter).println(contains("force-country-code"));
+    }
+
+    /** A rooted shell may disable the forced country code, which clears the override. */
+    @Test
+    public void forceSetCountryCodeDisabled_executeInRootedShell_allowed() {
+        BinderUtil.setUid(Process.ROOT_UID);
+
+        execShellCommand("force-country-code", "disabled");
+
+        verify(mThreadNetworkCountryCode).clearOverrideCountryCode();
+    }
+}