Set Thread Network country code from location country code

This CL selects the Thread country code from location country code.
If the location country code is not avaliable, the country code `WW`
will be selected as the default Thread country code.

This CL also adds Shell commands for developers to override the Thread
country code for testing.

Bug: b/309357909
Test: Run `atest ThreadNetworkUnitTests`.

Change-Id: Id87c293005f0e75922a72854b40c41837b74397f
diff --git a/thread/service/Android.bp b/thread/service/Android.bp
index 35ae3c2..92cdedc 100644
--- a/thread/service/Android.bp
+++ b/thread/service/Android.bp
@@ -35,9 +35,11 @@
     libs: [
         "framework-connectivity-pre-jarjar",
         "framework-connectivity-t-pre-jarjar",
+        "framework-location.stubs.module_lib",
         "service-connectivity-pre-jarjar",
     ],
     static_libs: [
+        "modules-utils-shell-command-handler",
         "net-utils-device-common",
         "net-utils-device-common-netlink",
         "ot-daemon-aidl-java",
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 2cd1be3..5ae310c 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -54,6 +54,8 @@
 import android.Manifest.permission;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.RequiresPermission;
+import android.annotation.TargetApi;
 import android.content.Context;
 import android.net.ConnectivityManager;
 import android.net.IpPrefix;
@@ -82,6 +84,7 @@
 import android.net.thread.PendingOperationalDataset;
 import android.net.thread.ThreadNetworkController;
 import android.net.thread.ThreadNetworkController.DeviceRole;
+import android.os.Build;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.os.IBinder;
@@ -121,6 +124,7 @@
  * `mHandlerThread` 2. In the @Override methods, the actual work MUST be dispatched to the
  * HandlerThread except for arguments or permissions checking
  */
+@TargetApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
 final class ThreadNetworkControllerService extends IThreadNetworkController.Stub {
     private static final String TAG = "ThreadNetworkService";
 
@@ -742,6 +746,32 @@
         }
     }
 
+    /**
+     * Sets the country code.
+     *
+     * @param countryCode 2 characters string country code (as defined in ISO 3166) to set.
+     * @param receiver the receiver to receive result of this operation
+     */
+    @RequiresPermission(PERMISSION_THREAD_NETWORK_PRIVILEGED)
+    public void setCountryCode(@NonNull String countryCode, @NonNull IOperationReceiver receiver) {
+        enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
+
+        OperationReceiverWrapper receiverWrapper = new OperationReceiverWrapper(receiver);
+        mHandler.post(() -> setCountryCodeInternal(countryCode, receiverWrapper));
+    }
+
+    private void setCountryCodeInternal(
+            String countryCode, @NonNull OperationReceiverWrapper receiver) {
+        checkOnHandlerThread();
+
+        try {
+            getOtDaemon().setCountryCode(countryCode, newOtStatusReceiver(receiver));
+        } catch (RemoteException e) {
+            Log.e(TAG, "otDaemon.setCountryCode failed", e);
+            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        }
+    }
+
     private void enableBorderRouting(String infraIfName) {
         if (mBorderRouterConfig.isBorderRoutingEnabled
                 && infraIfName.equals(mBorderRouterConfig.infraInterfaceName)) {
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
new file mode 100644
index 0000000..b05183c
--- /dev/null
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
@@ -0,0 +1,304 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import android.annotation.Nullable;
+import android.annotation.StringDef;
+import android.annotation.TargetApi;
+import android.location.Address;
+import android.location.Geocoder;
+import android.location.Location;
+import android.location.LocationManager;
+import android.net.thread.IOperationReceiver;
+import android.os.Build;
+import android.util.Log;
+
+import com.android.internal.annotations.VisibleForTesting;
+
+import java.io.FileDescriptor;
+import java.io.PrintWriter;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.time.Instant;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+
+/**
+ * Provide functions for making changes to Thread Network country code. This Country Code is from
+ * location. This class sends Country Code to Thread Network native layer.
+ *
+ * <p>This class is thread-safe.
+ */
+@TargetApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+public class ThreadNetworkCountryCode {
+    private static final String TAG = "ThreadNetworkCountryCode";
+    // To be used when there is no country code available.
+    @VisibleForTesting public static final String DEFAULT_COUNTRY_CODE = "WW";
+
+    // Wait 1 hour between updates.
+    private static final long TIME_BETWEEN_LOCATION_UPDATES_MS = 1000L * 60 * 60 * 1;
+    // Minimum distance before an update is triggered, in meters. We don't need this to be too
+    // exact because all we care about is what country the user is in.
+    private static final float DISTANCE_BETWEEN_LOCALTION_UPDATES_METERS = 5_000.0f;
+
+    /** List of country code sources. */
+    @Retention(RetentionPolicy.SOURCE)
+    @StringDef(
+            prefix = "COUNTRY_CODE_SOURCE_",
+            value = {
+                COUNTRY_CODE_SOURCE_DEFAULT,
+                COUNTRY_CODE_SOURCE_LOCATION,
+                COUNTRY_CODE_SOURCE_OVERRIDE,
+            })
+    private @interface CountryCodeSource {}
+
+    private static final String COUNTRY_CODE_SOURCE_DEFAULT = "Default";
+    private static final String COUNTRY_CODE_SOURCE_LOCATION = "Location";
+    private static final String COUNTRY_CODE_SOURCE_OVERRIDE = "Override";
+    private static final CountryCodeInfo DEFAULT_COUNTRY_CODE_INFO =
+            new CountryCodeInfo(DEFAULT_COUNTRY_CODE, COUNTRY_CODE_SOURCE_DEFAULT);
+
+    private final LocationManager mLocationManager;
+    @Nullable private final Geocoder mGeocoder;
+    private final ThreadNetworkControllerService mThreadNetworkControllerService;
+
+    @Nullable private CountryCodeInfo mCurrentCountryCodeInfo;
+    @Nullable private CountryCodeInfo mLocationCountryCodeInfo;
+    @Nullable private CountryCodeInfo mOverrideCountryCodeInfo;
+
+    /** Container class to store Thread country code information. */
+    private static final class CountryCodeInfo {
+        private String mCountryCode;
+        @CountryCodeSource private String mSource;
+        private final Instant mUpdatedTimestamp;
+
+        public CountryCodeInfo(
+                String countryCode, @CountryCodeSource String countryCodeSource, Instant instant) {
+            mCountryCode = countryCode;
+            mSource = countryCodeSource;
+            mUpdatedTimestamp = instant;
+        }
+
+        public CountryCodeInfo(String countryCode, @CountryCodeSource String countryCodeSource) {
+            this(countryCode, countryCodeSource, Instant.now());
+        }
+
+        public String getCountryCode() {
+            return mCountryCode;
+        }
+
+        public boolean isCountryCodeMatch(CountryCodeInfo countryCodeInfo) {
+            if (countryCodeInfo == null) {
+                return false;
+            }
+
+            return Objects.equals(countryCodeInfo.mCountryCode, mCountryCode);
+        }
+
+        @Override
+        public String toString() {
+            return "CountryCodeInfo{ mCountryCode: "
+                    + mCountryCode
+                    + ", mSource: "
+                    + mSource
+                    + ", mUpdatedTimestamp: "
+                    + mUpdatedTimestamp
+                    + "}";
+        }
+    }
+
+    private boolean isLocationUseForCountryCodeEnabled() {
+        // TODO: b/311324956 read the configuration from the overlay configuration.
+        return true;
+    }
+
+    public ThreadNetworkCountryCode(
+            LocationManager locationManager,
+            ThreadNetworkControllerService threadNetworkControllerService,
+            @Nullable Geocoder geocoder) {
+        mLocationManager = locationManager;
+        mThreadNetworkControllerService = threadNetworkControllerService;
+        mGeocoder = geocoder;
+    }
+
+    /** Sets up this country code module to listen to location country code changes. */
+    public synchronized void initialize() {
+        registerGeocoderCountryCodeCallback();
+        updateCountryCode(false /* forceUpdate */);
+    }
+
+    private synchronized void registerGeocoderCountryCodeCallback() {
+        if ((mGeocoder != null) && isLocationUseForCountryCodeEnabled()) {
+            mLocationManager.requestLocationUpdates(
+                    LocationManager.PASSIVE_PROVIDER,
+                    TIME_BETWEEN_LOCATION_UPDATES_MS,
+                    DISTANCE_BETWEEN_LOCALTION_UPDATES_METERS,
+                    location -> setCountryCodeFromGeocodingLocation(location));
+        }
+    }
+
+    private synchronized void geocodeListener(List<Address> addresses) {
+        if (addresses != null && !addresses.isEmpty()) {
+            String countryCode = addresses.get(0).getCountryCode();
+
+            if (isValidCountryCode(countryCode)) {
+                Log.d(TAG, "Set location country code to: " + countryCode);
+                mLocationCountryCodeInfo =
+                        new CountryCodeInfo(countryCode, COUNTRY_CODE_SOURCE_LOCATION);
+            } else {
+                Log.d(TAG, "Received invalid location country code");
+                mLocationCountryCodeInfo = null;
+            }
+
+            updateCountryCode(false /* forceUpdate */);
+        }
+    }
+
+    private synchronized void setCountryCodeFromGeocodingLocation(@Nullable Location location) {
+        if ((location == null) || (mGeocoder == null)) return;
+
+        if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.TIRAMISU) {
+            Log.wtf(
+                    TAG,
+                    "Unexpected call to set country code from the Geocoding location, "
+                            + "Thread code never runs under T or lower.");
+            return;
+        }
+
+        mGeocoder.getFromLocation(
+                location.getLatitude(),
+                location.getLongitude(),
+                1 /* maxResults */,
+                this::geocodeListener);
+    }
+
+    /**
+     * Priority order of country code sources (we stop at the first known country code source):
+     *
+     * <ul>
+     *   <li>1. Override country code - Country code forced via shell command (local/automated
+     *       testing)
+     *   <li>2. Location Country code - Country code retrieved from LocationManager passive location
+     *       provider.
+     * </ul>
+     *
+     * @return the selected country code information.
+     */
+    private CountryCodeInfo pickCountryCode() {
+        if (mOverrideCountryCodeInfo != null) {
+            return mOverrideCountryCodeInfo;
+        }
+
+        if (mLocationCountryCodeInfo != null) {
+            return mLocationCountryCodeInfo;
+        }
+
+        return DEFAULT_COUNTRY_CODE_INFO;
+    }
+
+    private IOperationReceiver newOperationReceiver(CountryCodeInfo countryCodeInfo) {
+        return new IOperationReceiver.Stub() {
+            @Override
+            public void onSuccess() {
+                synchronized ("ThreadNetworkCountryCode.this") {
+                    mCurrentCountryCodeInfo = countryCodeInfo;
+                }
+            }
+
+            @Override
+            public void onError(int otError, String message) {
+                Log.e(
+                        TAG,
+                        "Error "
+                                + otError
+                                + ": "
+                                + message
+                                + ". Failed to set country code "
+                                + countryCodeInfo);
+            }
+        };
+    }
+
+    /**
+     * Updates country code to the Thread native layer.
+     *
+     * @param forceUpdate Force update the country code even if it was the same as previously cached
+     *     value.
+     */
+    @VisibleForTesting
+    public synchronized void updateCountryCode(boolean forceUpdate) {
+        CountryCodeInfo countryCodeInfo = pickCountryCode();
+
+        if (!forceUpdate && countryCodeInfo.isCountryCodeMatch(mCurrentCountryCodeInfo)) {
+            Log.i(TAG, "Ignoring already set country code " + countryCodeInfo.getCountryCode());
+            return;
+        }
+
+        Log.i(TAG, "Set country code: " + countryCodeInfo);
+        mThreadNetworkControllerService.setCountryCode(
+                countryCodeInfo.getCountryCode().toUpperCase(Locale.ROOT),
+                newOperationReceiver(countryCodeInfo));
+    }
+
+    /** Returns the current country code or {@code null} if no country code is set. */
+    @Nullable
+    public synchronized String getCountryCode() {
+        return (mCurrentCountryCodeInfo != null) ? mCurrentCountryCodeInfo.getCountryCode() : null;
+    }
+
+    /**
+     * Returns {@code true} if {@code countryCode} is a valid country code.
+     *
+     * <p>A country code is valid if it consists of 2 alphabets.
+     */
+    public static boolean isValidCountryCode(String countryCode) {
+        return countryCode != null
+                && countryCode.length() == 2
+                && countryCode.chars().allMatch(Character::isLetter);
+    }
+
+    /**
+     * Overrides any existing country code.
+     *
+     * @param countryCode A 2-Character alphabetical country code (as defined in ISO 3166).
+     * @throws IllegalArgumentException if {@code countryCode} is an invalid country code.
+     */
+    public synchronized void setOverrideCountryCode(String countryCode) {
+        if (!isValidCountryCode(countryCode)) {
+            throw new IllegalArgumentException("The override country code is invalid");
+        }
+
+        mOverrideCountryCodeInfo = new CountryCodeInfo(countryCode, COUNTRY_CODE_SOURCE_OVERRIDE);
+        updateCountryCode(true /* forceUpdate */);
+    }
+
+    /** Clears the country code previously set through {@link #setOverrideCountryCode} method. */
+    public synchronized void clearOverrideCountryCode() {
+        mOverrideCountryCodeInfo = null;
+        updateCountryCode(true /* forceUpdate */);
+    }
+
+    /** Dumps the current state of this ThreadNetworkCountryCode object. */
+    public synchronized void dump(FileDescriptor fd, PrintWriter pw, String[] args) {
+        pw.println("---- Dump of ThreadNetworkCountryCode begin ----");
+        pw.println("mOverrideCountryCodeInfo: " + mOverrideCountryCodeInfo);
+        pw.println("mLocationCountryCodeInfo: " + mLocationCountryCodeInfo);
+        pw.println("mCurrentCountryCodeInfo: " + mCurrentCountryCodeInfo);
+        pw.println("---- Dump of ThreadNetworkCountryCode end ------");
+    }
+}
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index cc694a1..287bb8a 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -16,13 +16,22 @@
 
 package com.android.server.thread;
 
+import static android.content.pm.PackageManager.PERMISSION_GRANTED;
+
+import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
+import android.location.Geocoder;
+import android.location.LocationManager;
 import android.net.thread.IThreadNetworkController;
 import android.net.thread.IThreadNetworkManager;
+import android.os.Binder;
+import android.os.ParcelFileDescriptor;
 
 import com.android.server.SystemService;
 
+import java.io.FileDescriptor;
+import java.io.PrintWriter;
 import java.util.Collections;
 import java.util.List;
 
@@ -31,7 +40,9 @@
  */
 public class ThreadNetworkService extends IThreadNetworkManager.Stub {
     private final Context mContext;
+    @Nullable private ThreadNetworkCountryCode mCountryCode;
     @Nullable private ThreadNetworkControllerService mControllerService;
+    @Nullable private ThreadNetworkShellCommand mShellCommand;
 
     /** Creates a new {@link ThreadNetworkService} object. */
     public ThreadNetworkService(Context context) {
@@ -47,6 +58,13 @@
         if (phase == SystemService.PHASE_BOOT_COMPLETED) {
             mControllerService = ThreadNetworkControllerService.newInstance(mContext);
             mControllerService.initialize();
+            mCountryCode =
+                    new ThreadNetworkCountryCode(
+                            mContext.getSystemService(LocationManager.class),
+                            mControllerService,
+                            Geocoder.isPresent() ? new Geocoder(mContext) : null);
+            mCountryCode.initialize();
+            mShellCommand = new ThreadNetworkShellCommand(mCountryCode);
         }
     }
 
@@ -57,4 +75,40 @@
         }
         return Collections.singletonList(mControllerService);
     }
+
+    @Override
+    public int handleShellCommand(
+            @NonNull ParcelFileDescriptor in,
+            @NonNull ParcelFileDescriptor out,
+            @NonNull ParcelFileDescriptor err,
+            @NonNull String[] args) {
+        if (mShellCommand == null) {
+            return -1;
+        }
+        return mShellCommand.exec(
+                this,
+                in.getFileDescriptor(),
+                out.getFileDescriptor(),
+                err.getFileDescriptor(),
+                args);
+    }
+
+    @Override
+    protected void dump(FileDescriptor fd, PrintWriter pw, String[] args) {
+        if (mContext.checkCallingOrSelfPermission(android.Manifest.permission.DUMP)
+                != PERMISSION_GRANTED) {
+            pw.println(
+                    "Permission Denial: can't dump ThreadNetworkService from from pid="
+                            + Binder.getCallingPid()
+                            + ", uid="
+                            + Binder.getCallingUid());
+            return;
+        }
+
+        if (mCountryCode != null) {
+            mCountryCode.dump(fd, pw, args);
+        }
+
+        pw.println();
+    }
 }
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
new file mode 100644
index 0000000..c17c5a7
--- /dev/null
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
@@ -0,0 +1,183 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import android.annotation.Nullable;
+import android.os.Binder;
+import android.os.Process;
+import android.text.TextUtils;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.modules.utils.BasicShellCommandHandler;
+
+import java.io.PrintWriter;
+import java.util.List;
+
+/**
+ * Interprets and executes 'adb shell cmd thread_network [args]'.
+ *
+ * <p>To add new commands: - onCommand: Add a case "<command>" execute. Return a 0 if command
+ * executed successfully. - onHelp: add a description string.
+ *
+ * <p>Permissions: currently root permission is required for some commands. Others will enforce the
+ * corresponding API permissions.
+ */
+public class ThreadNetworkShellCommand extends BasicShellCommandHandler {
+    private static final String TAG = "ThreadNetworkShellCommand";
+
+    // These don't require root access.
+    private static final List<String> NON_PRIVILEGED_COMMANDS = List.of("help", "get-country-code");
+
+    @Nullable private final ThreadNetworkCountryCode mCountryCode;
+    @Nullable private PrintWriter mOutputWriter;
+    @Nullable private PrintWriter mErrorWriter;
+
+    ThreadNetworkShellCommand(@Nullable ThreadNetworkCountryCode countryCode) {
+        mCountryCode = countryCode;
+    }
+
+    @VisibleForTesting
+    public void setPrintWriters(PrintWriter outputWriter, PrintWriter errorWriter) {
+        mOutputWriter = outputWriter;
+        mErrorWriter = errorWriter;
+    }
+
+    private PrintWriter getOutputWriter() {
+        return (mOutputWriter != null) ? mOutputWriter : getOutPrintWriter();
+    }
+
+    private PrintWriter getErrorWriter() {
+        return (mErrorWriter != null) ? mErrorWriter : getErrPrintWriter();
+    }
+
+    @Override
+    public int onCommand(String cmd) {
+        // Treat no command as help command.
+        if (TextUtils.isEmpty(cmd)) {
+            cmd = "help";
+        }
+
+        final PrintWriter pw = getOutputWriter();
+        final PrintWriter perr = getErrorWriter();
+
+        // Explicit exclusion from root permission
+        if (!NON_PRIVILEGED_COMMANDS.contains(cmd)) {
+            final int uid = Binder.getCallingUid();
+
+            if (uid != Process.ROOT_UID) {
+                perr.println(
+                        "Uid "
+                                + uid
+                                + " does not have access to "
+                                + cmd
+                                + " thread command "
+                                + "(or such command doesn't exist)");
+                return -1;
+            }
+        }
+
+        switch (cmd) {
+            case "force-country-code":
+                boolean enabled;
+
+                if (mCountryCode == null) {
+                    perr.println("Thread country code operations are not supported");
+                    return -1;
+                }
+
+                try {
+                    enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
+                } catch (IllegalArgumentException e) {
+                    perr.println("Invalid argument: " + e.getMessage());
+                    return -1;
+                }
+
+                if (enabled) {
+                    String countryCode = getNextArgRequired();
+                    if (!ThreadNetworkCountryCode.isValidCountryCode(countryCode)) {
+                        perr.println(
+                                "Invalid argument: Country code must be a 2-Character"
+                                        + " string. But got country code "
+                                        + countryCode
+                                        + " instead");
+                        return -1;
+                    }
+                    mCountryCode.setOverrideCountryCode(countryCode);
+                    pw.println("Set Thread country code: " + countryCode);
+
+                } else {
+                    mCountryCode.clearOverrideCountryCode();
+                }
+                return 0;
+            case "get-country-code":
+                if (mCountryCode == null) {
+                    perr.println("Thread country code operations are not supported");
+                    return -1;
+                }
+
+                pw.println("Thread country code = " + mCountryCode.getCountryCode());
+                return 0;
+            default:
+                return handleDefaultCommands(cmd);
+        }
+    }
+
+    private static boolean argTrueOrFalse(String arg, String trueString, String falseString) {
+        if (trueString.equals(arg)) {
+            return true;
+        } else if (falseString.equals(arg)) {
+            return false;
+        } else {
+            throw new IllegalArgumentException(
+                    "Expected '"
+                            + trueString
+                            + "' or '"
+                            + falseString
+                            + "' as next arg but got '"
+                            + arg
+                            + "'");
+        }
+    }
+
+    private boolean getNextArgRequiredTrueOrFalse(String trueString, String falseString) {
+        String nextArg = getNextArgRequired();
+        return argTrueOrFalse(nextArg, trueString, falseString);
+    }
+
+    private void onHelpNonPrivileged(PrintWriter pw) {
+        pw.println("  get-country-code");
+        pw.println("    Gets country code as a two-letter string");
+    }
+
+    private void onHelpPrivileged(PrintWriter pw) {
+        pw.println("  force-country-code enabled <two-letter code> | disabled ");
+        pw.println("    Sets country code to <two-letter code> or left for normal value");
+    }
+
+    @Override
+    public void onHelp() {
+        final PrintWriter pw = getOutputWriter();
+        pw.println("Thread network commands:");
+        pw.println("  help or -h");
+        pw.println("    Print this help text.");
+        onHelpNonPrivileged(pw);
+        if (Binder.getCallingUid() == Process.ROOT_UID) {
+            onHelpPrivileged(pw);
+        }
+        pw.println();
+    }
+}
diff --git a/thread/tests/unit/Android.bp b/thread/tests/unit/Android.bp
index 8092693..74b4a35 100644
--- a/thread/tests/unit/Android.bp
+++ b/thread/tests/unit/Android.bp
@@ -31,14 +31,15 @@
         "general-tests",
     ],
     static_libs: [
-        "androidx.test.ext.junit",
-        "compatibility-device-util-axt",
+        "frameworks-base-testutils",
         "framework-connectivity-pre-jarjar",
         "framework-connectivity-t-pre-jarjar",
+        "framework-location.stubs.module_lib",
         "guava",
         "guava-android-testlib",
-        "mockito-target-minus-junit4",
+        "mockito-target-extended-minus-junit4",
         "net-tests-utils",
+        "service-thread-pre-jarjar",
         "truth",
     ],
     libs: [
@@ -46,6 +47,11 @@
         "android.test.runner",
     ],
     jarjar_rules: ":connectivity-jarjar-rules",
+    jni_libs: [
+        // these are needed for Extended Mockito
+        "libdexmakerjvmtiagent",
+        "libstaticjvmtiagent",
+    ],
     // Test coverage system runs on different devices. Need to
     // compile for all architectures.
     compile_multilib: "both",
diff --git a/thread/tests/unit/src/android/net/thread/ActiveOperationalDatasetTest.java b/thread/tests/unit/src/android/net/thread/ActiveOperationalDatasetTest.java
index 7284968..e92dcb9 100644
--- a/thread/tests/unit/src/android/net/thread/ActiveOperationalDatasetTest.java
+++ b/thread/tests/unit/src/android/net/thread/ActiveOperationalDatasetTest.java
@@ -33,12 +33,8 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
-import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
-import java.security.SecureRandom;
-import java.util.Random;
-
 /** Unit tests for {@link ActiveOperationalDataset}. */
 @SmallTest
 @RunWith(AndroidJUnit4.class)
@@ -62,9 +58,6 @@
                                     + "642D643961300102D9A00410A245479C836D551B9CA557F7"
                                     + "B9D351B40C0402A0FFF8");
 
-    @Mock private Random mockRandom;
-    @Mock private SecureRandom mockSecureRandom;
-
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
diff --git a/thread/tests/unit/src/com/android/server/thread/BinderUtil.java b/thread/tests/unit/src/com/android/server/thread/BinderUtil.java
new file mode 100644
index 0000000..3614bce
--- /dev/null
+++ b/thread/tests/unit/src/com/android/server/thread/BinderUtil.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import android.os.Binder;
+
+/** Utilities for faking the calling uid in Binder. */
+public class BinderUtil {
+    /**
+     * Fake the calling uid in Binder.
+     *
+     * @param uid the calling uid that Binder should return from now on
+     */
+    public static void setUid(int uid) {
+        Binder.restoreCallingIdentity((((long) uid) << 32) | Binder.getCallingPid());
+    }
+}
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
new file mode 100644
index 0000000..f51aa0a
--- /dev/null
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
@@ -0,0 +1,191 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
+
+import static com.android.server.thread.ThreadNetworkCountryCode.DEFAULT_COUNTRY_CODE;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyDouble;
+import static org.mockito.Mockito.anyFloat;
+import static org.mockito.Mockito.anyInt;
+import static org.mockito.Mockito.anyLong;
+import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.clearInvocations;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import android.content.pm.PackageManager;
+import android.location.Address;
+import android.location.Geocoder;
+import android.location.Location;
+import android.location.LocationListener;
+import android.location.LocationManager;
+import android.net.thread.IOperationReceiver;
+
+import androidx.test.filters.SmallTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.stubbing.Answer;
+
+import java.util.List;
+import java.util.Locale;
+
+/** Unit tests for {@link ThreadNetworkCountryCode}. */
+@RunWith(AndroidJUnit4.class)
+@SmallTest
+public class ThreadNetworkCountryCodeTest {
+    private static final String TEST_COUNTRY_CODE_US = "US";
+    private static final String TEST_COUNTRY_CODE_CN = "CN";
+
+    @Mock LocationManager mLocationManager;
+    @Mock Geocoder mGeocoder;
+    @Mock ThreadNetworkControllerService mThreadNetworkControllerService;
+    @Mock PackageManager mPackageManager;
+    @Mock Location mLocation;
+
+    private ThreadNetworkCountryCode mThreadNetworkCountryCode;
+    private boolean mErrorSetCountryCode;
+
+    @Captor private ArgumentCaptor<LocationListener> mLocationListenerCaptor;
+    @Captor private ArgumentCaptor<Geocoder.GeocodeListener> mGeocodeListenerCaptor;
+    @Captor private ArgumentCaptor<IOperationReceiver> mOperationReceiverCaptor;
+
+    @Before
+    public void setUp() throws Exception {
+        MockitoAnnotations.initMocks(this);
+
+        when(mLocation.getLatitude()).thenReturn(0.0);
+        when(mLocation.getLongitude()).thenReturn(0.0);
+
+        Answer setCountryCodeCallback =
+                invocation -> {
+                    Object[] args = invocation.getArguments();
+                    IOperationReceiver cb = (IOperationReceiver) args[1];
+
+                    if (mErrorSetCountryCode) {
+                        cb.onError(ERROR_INTERNAL_ERROR, new String("Invalid country code"));
+                    } else {
+                        cb.onSuccess();
+                    }
+                    return new Object();
+                };
+
+        doAnswer(setCountryCodeCallback)
+                .when(mThreadNetworkControllerService)
+                .setCountryCode(any(), any(IOperationReceiver.class));
+
+        mThreadNetworkCountryCode =
+                new ThreadNetworkCountryCode(
+                        mLocationManager, mThreadNetworkControllerService, mGeocoder);
+    }
+
+    private static Address newAddress(String countryCode) {
+        Address address = new Address(Locale.ROOT);
+        address.setCountryCode(countryCode);
+        return address;
+    }
+
+    @Test
+    public void initialize_defaultCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(DEFAULT_COUNTRY_CODE);
+    }
+
+    @Test
+    public void locationCountryCode_locationChanged_locationCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+
+        verify(mLocationManager)
+                .requestLocationUpdates(
+                        anyString(), anyLong(), anyFloat(), mLocationListenerCaptor.capture());
+        mLocationListenerCaptor.getValue().onLocationChanged(mLocation);
+        verify(mGeocoder)
+                .getFromLocation(
+                        anyDouble(), anyDouble(), anyInt(), mGeocodeListenerCaptor.capture());
+        mGeocodeListenerCaptor.getValue().onGeocode(List.of(newAddress(TEST_COUNTRY_CODE_US)));
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(TEST_COUNTRY_CODE_US);
+    }
+
+    @Test
+    public void updateCountryCode_noForceUpdateDefaultCountryCode_noCountryCodeIsUpdated() {
+        mThreadNetworkCountryCode.initialize();
+        clearInvocations(mThreadNetworkControllerService);
+
+        mThreadNetworkCountryCode.updateCountryCode(false /* forceUpdate */);
+
+        verify(mThreadNetworkControllerService, never()).setCountryCode(any(), any());
+    }
+
+    @Test
+    public void updateCountryCode_forceUpdateDefaultCountryCode_countryCodeIsUpdated() {
+        mThreadNetworkCountryCode.initialize();
+        clearInvocations(mThreadNetworkControllerService);
+
+        mThreadNetworkCountryCode.updateCountryCode(true /* forceUpdate */);
+
+        verify(mThreadNetworkControllerService)
+                .setCountryCode(eq(DEFAULT_COUNTRY_CODE), mOperationReceiverCaptor.capture());
+    }
+
+    @Test
+    public void setOverrideCountryCode_defaultCountryCodeAvailable_overrideCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+
+        mThreadNetworkCountryCode.setOverrideCountryCode(TEST_COUNTRY_CODE_CN);
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(TEST_COUNTRY_CODE_CN);
+    }
+
+    @Test
+    public void clearOverrideCountryCode_defaultCountryCodeAvailable_defaultCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+        mThreadNetworkCountryCode.setOverrideCountryCode(TEST_COUNTRY_CODE_CN);
+
+        mThreadNetworkCountryCode.clearOverrideCountryCode();
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(DEFAULT_COUNTRY_CODE);
+    }
+
+    @Test
+    public void setCountryCodeFailed_defaultCountryCodeAvailable_countryCodeIsNotUpdated() {
+        mThreadNetworkCountryCode.initialize();
+
+        mErrorSetCountryCode = true;
+        mThreadNetworkCountryCode.setOverrideCountryCode(TEST_COUNTRY_CODE_CN);
+
+        verify(mThreadNetworkControllerService)
+                .setCountryCode(eq(TEST_COUNTRY_CODE_CN), mOperationReceiverCaptor.capture());
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(DEFAULT_COUNTRY_CODE);
+    }
+}
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
new file mode 100644
index 0000000..c7e0eca
--- /dev/null
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
@@ -0,0 +1,140 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.contains;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.validateMockitoUsage;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import android.os.Binder;
+import android.os.Process;
+
+import androidx.test.filters.SmallTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.FileDescriptor;
+import java.io.PrintWriter;
+
+/** Unit tests for {@link ThreadNetworkShellCommand}. */
+@RunWith(AndroidJUnit4.class)
+@SmallTest
+public class ThreadNetworkShellCommandTest {
+    private static final String TAG = "ThreadNetworkShellCommandTTest";
+    @Mock ThreadNetworkService mThreadNetworkService;
+    @Mock ThreadNetworkCountryCode mThreadNetworkCountryCode;
+    @Mock PrintWriter mErrorWriter;
+    @Mock PrintWriter mOutputWriter;
+
+    ThreadNetworkShellCommand mThreadNetworkShellCommand;
+
+    @Before
+    public void setUp() throws Exception {
+        MockitoAnnotations.initMocks(this);
+
+        mThreadNetworkShellCommand = new ThreadNetworkShellCommand(mThreadNetworkCountryCode);
+        mThreadNetworkShellCommand.setPrintWriters(mOutputWriter, mErrorWriter);
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        validateMockitoUsage();
+    }
+
+    @Test
+    public void getCountryCode_executeInUnrootedShell_allowed() {
+        BinderUtil.setUid(Process.SHELL_UID);
+        when(mThreadNetworkCountryCode.getCountryCode()).thenReturn("US");
+
+        mThreadNetworkShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"get-country-code"});
+
+        verify(mOutputWriter).println(contains("US"));
+    }
+
+    @Test
+    public void forceSetCountryCodeEnabled_executeInUnrootedShell_notAllowed() {
+        BinderUtil.setUid(Process.SHELL_UID);
+
+        mThreadNetworkShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"force-country-code", "enabled", "US"});
+
+        verify(mThreadNetworkCountryCode, never()).setOverrideCountryCode(eq("US"));
+        verify(mErrorWriter).println(contains("force-country-code"));
+    }
+
+    @Test
+    public void forceSetCountryCodeEnabled_executeInRootedShell_allowed() {
+        BinderUtil.setUid(Process.ROOT_UID);
+
+        mThreadNetworkShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"force-country-code", "enabled", "US"});
+
+        verify(mThreadNetworkCountryCode).setOverrideCountryCode(eq("US"));
+    }
+
+    @Test
+    public void forceSetCountryCodeDisabled_executeInUnrootedShell_notAllowed() {
+        BinderUtil.setUid(Process.SHELL_UID);
+
+        mThreadNetworkShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"force-country-code", "disabled"});
+
+        verify(mThreadNetworkCountryCode, never()).setOverrideCountryCode(any());
+        verify(mErrorWriter).println(contains("force-country-code"));
+    }
+
+    @Test
+    public void forceSetCountryCodeDisabled_executeInRootedShell_allowed() {
+        BinderUtil.setUid(Process.ROOT_UID);
+
+        mThreadNetworkShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"force-country-code", "disabled"});
+
+        verify(mThreadNetworkCountryCode).clearOverrideCountryCode();
+    }
+}