Add WiFi country code to Thread country code sources

Bug: b/309357909
Test: Run `atest ThreadNetworkUnitTests`.
Change-Id: I087cc0fc2b6c21fd9d3d4607f30c194214aeb37e
diff --git a/thread/service/Android.bp b/thread/service/Android.bp
index b5fee95..69295cc 100644
--- a/thread/service/Android.bp
+++ b/thread/service/Android.bp
@@ -36,6 +36,7 @@
         "framework-connectivity-pre-jarjar",
         "framework-connectivity-t-pre-jarjar",
         "framework-location.stubs.module_lib",
+        "framework-wifi",
         "service-connectivity-pre-jarjar",
         "ServiceConnectivityResources",
     ],
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
index df2c56e..7845209 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
@@ -24,6 +24,8 @@
 import android.location.Location;
 import android.location.LocationManager;
 import android.net.thread.IOperationReceiver;
+import android.net.wifi.WifiManager;
+import android.net.wifi.WifiManager.ActiveCountryCodeChangedCallback;
 import android.os.Build;
 import android.util.Log;
 
@@ -42,7 +44,7 @@
 
 /**
  * Provide functions for making changes to Thread Network country code. This Country Code is from
- * location. This class sends Country Code to Thread Network native layer.
+ * location or WiFi configuration. This class sends Country Code to Thread Network native layer.
  *
  * <p>This class is thread-safe.
  */
@@ -66,12 +68,14 @@
                 COUNTRY_CODE_SOURCE_DEFAULT,
                 COUNTRY_CODE_SOURCE_LOCATION,
                 COUNTRY_CODE_SOURCE_OVERRIDE,
+                COUNTRY_CODE_SOURCE_WIFI,
             })
     private @interface CountryCodeSource {}
 
     private static final String COUNTRY_CODE_SOURCE_DEFAULT = "Default";
     private static final String COUNTRY_CODE_SOURCE_LOCATION = "Location";
     private static final String COUNTRY_CODE_SOURCE_OVERRIDE = "Override";
+    private static final String COUNTRY_CODE_SOURCE_WIFI = "Wifi";
     private static final CountryCodeInfo DEFAULT_COUNTRY_CODE_INFO =
             new CountryCodeInfo(DEFAULT_COUNTRY_CODE, COUNTRY_CODE_SOURCE_DEFAULT);
 
@@ -79,10 +83,12 @@
     private final LocationManager mLocationManager;
     @Nullable private final Geocoder mGeocoder;
     private final ThreadNetworkControllerService mThreadNetworkControllerService;
+    private final WifiManager mWifiManager;
 
     @Nullable private CountryCodeInfo mCurrentCountryCodeInfo;
     @Nullable private CountryCodeInfo mLocationCountryCodeInfo;
     @Nullable private CountryCodeInfo mOverrideCountryCodeInfo;
+    @Nullable private CountryCodeInfo mWifiCountryCodeInfo;
 
     /** Container class to store Thread country code information. */
     private static final class CountryCodeInfo {
@@ -135,16 +141,19 @@
             LocationManager locationManager,
             ThreadNetworkControllerService threadNetworkControllerService,
             @Nullable Geocoder geocoder,
-            ConnectivityResources resources) {
+            ConnectivityResources resources,
+            WifiManager wifiManager) {
         mLocationManager = locationManager;
         mThreadNetworkControllerService = threadNetworkControllerService;
         mGeocoder = geocoder;
         mResources = resources;
+        mWifiManager = wifiManager;
     }
 
     /** Sets up this country code module to listen to location country code changes. */
     public synchronized void initialize() {
         registerGeocoderCountryCodeCallback();
+        registerWifiCountryCodeCallback();
         updateCountryCode(false /* forceUpdate */);
     }
 
@@ -193,13 +202,41 @@
                 this::geocodeListener);
     }
 
+    private synchronized void registerWifiCountryCodeCallback() {
+        if (mWifiManager != null) {
+            mWifiManager.registerActiveCountryCodeChangedCallback(
+                    r -> r.run(), new WifiCountryCodeCallback());
+        }
+    }
+
+    private class WifiCountryCodeCallback implements ActiveCountryCodeChangedCallback {
+        @Override
+        public void onActiveCountryCodeChanged(String countryCode) {
+            Log.d(TAG, "Wifi country code is changed to " + countryCode);
+            synchronized ("ThreadNetworkCountryCode.this") {
+                mWifiCountryCodeInfo = new CountryCodeInfo(countryCode, COUNTRY_CODE_SOURCE_WIFI);
+                updateCountryCode(false /* forceUpdate */);
+            }
+        }
+
+        @Override
+        public void onCountryCodeInactive() {
+            Log.d(TAG, "Wifi country code is inactived");
+            synchronized ("ThreadNetworkCountryCode.this") {
+                mWifiCountryCodeInfo = null;
+                updateCountryCode(false /* forceUpdate */);
+            }
+        }
+    }
+
     /**
      * Priority order of country code sources (we stop at the first known country code source):
      *
      * <ul>
      *   <li>1. Override country code - Country code forced via shell command (local/automated
      *       testing)
-     *   <li>2. Location Country code - Country code retrieved from LocationManager passive location
+     *   <li>2. Wifi country code - Current country code retrieved via wifi (via 80211.ad).
+     *   <li>3. Location Country code - Country code retrieved from LocationManager passive location
      *       provider.
      * </ul>
      *
@@ -210,6 +247,10 @@
             return mOverrideCountryCodeInfo;
         }
 
+        if (mWifiCountryCodeInfo != null) {
+            return mWifiCountryCodeInfo;
+        }
+
         if (mLocationCountryCodeInfo != null) {
             return mLocationCountryCodeInfo;
         }
@@ -303,6 +344,7 @@
     public synchronized void dump(FileDescriptor fd, PrintWriter pw, String[] args) {
         pw.println("---- Dump of ThreadNetworkCountryCode begin ----");
         pw.println("mOverrideCountryCodeInfo: " + mOverrideCountryCodeInfo);
+        pw.println("mWifiCountryCodeInfo: " + mWifiCountryCodeInfo);
         pw.println("mLocationCountryCodeInfo: " + mLocationCountryCodeInfo);
         pw.println("mCurrentCountryCodeInfo: " + mCurrentCountryCodeInfo);
         pw.println("---- Dump of ThreadNetworkCountryCode end ------");
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index 95c7256..80b8842 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -25,6 +25,7 @@
 import android.location.LocationManager;
 import android.net.thread.IThreadNetworkController;
 import android.net.thread.IThreadNetworkManager;
+import android.net.wifi.WifiManager;
 import android.os.Binder;
 import android.os.ParcelFileDescriptor;
 
@@ -64,7 +65,8 @@
                             mContext.getSystemService(LocationManager.class),
                             mControllerService,
                             Geocoder.isPresent() ? new Geocoder(mContext) : null,
-                            new ConnectivityResources(mContext));
+                            new ConnectivityResources(mContext),
+                            mContext.getSystemService(WifiManager.class));
             mCountryCode.initialize();
 
             mShellCommand = new ThreadNetworkShellCommand(mCountryCode);
diff --git a/thread/tests/unit/Android.bp b/thread/tests/unit/Android.bp
index c7887bc..2523915 100644
--- a/thread/tests/unit/Android.bp
+++ b/thread/tests/unit/Android.bp
@@ -47,6 +47,7 @@
         "android.test.base",
         "android.test.runner",
         "ServiceConnectivityResources",
+        "framework-wifi",
     ],
     jarjar_rules: ":connectivity-jarjar-rules",
     jni_libs: [
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
index a0eff6c..d7082fe 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
@@ -32,6 +32,7 @@
 import static org.mockito.Mockito.clearInvocations;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -45,6 +46,8 @@
 import android.location.LocationListener;
 import android.location.LocationManager;
 import android.net.thread.IOperationReceiver;
+import android.net.wifi.WifiManager;
+import android.net.wifi.WifiManager.ActiveCountryCodeChangedCallback;
 
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
@@ -78,6 +81,7 @@
     @Mock Location mLocation;
     @Mock Resources mResources;
     @Mock ConnectivityResources mConnectivityResources;
+    @Mock WifiManager mWifiManager;
 
     private ThreadNetworkCountryCode mThreadNetworkCountryCode;
     private boolean mErrorSetCountryCode;
@@ -85,6 +89,7 @@
     @Captor private ArgumentCaptor<LocationListener> mLocationListenerCaptor;
     @Captor private ArgumentCaptor<Geocoder.GeocodeListener> mGeocodeListenerCaptor;
     @Captor private ArgumentCaptor<IOperationReceiver> mOperationReceiverCaptor;
+    @Captor private ArgumentCaptor<ActiveCountryCodeChangedCallback> mWifiCountryCodeReceiverCaptor;
 
     @Before
     public void setUp() throws Exception {
@@ -118,7 +123,8 @@
                         mLocationManager,
                         mThreadNetworkControllerService,
                         mGeocoder,
-                        mConnectivityResources);
+                        mConnectivityResources,
+                        mWifiManager);
     }
 
     private static Address newAddress(String countryCode) {
@@ -162,6 +168,55 @@
     }
 
     @Test
+    public void wifiCountryCode_bothWifiAndLocationAreAvailable_wifiCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+        verify(mLocationManager)
+                .requestLocationUpdates(
+                        anyString(), anyLong(), anyFloat(), mLocationListenerCaptor.capture());
+        mLocationListenerCaptor.getValue().onLocationChanged(mLocation);
+        verify(mGeocoder)
+                .getFromLocation(
+                        anyDouble(), anyDouble(), anyInt(), mGeocodeListenerCaptor.capture());
+        Address mockAddress = mock(Address.class);
+        when(mockAddress.getCountryCode()).thenReturn(TEST_COUNTRY_CODE_US);
+        List<Address> addresses = List.of(mockAddress);
+        mGeocodeListenerCaptor.getValue().onGeocode(addresses);
+
+        verify(mWifiManager)
+                .registerActiveCountryCodeChangedCallback(
+                        any(), mWifiCountryCodeReceiverCaptor.capture());
+        mWifiCountryCodeReceiverCaptor.getValue().onActiveCountryCodeChanged(TEST_COUNTRY_CODE_CN);
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(TEST_COUNTRY_CODE_CN);
+    }
+
+    @Test
+    public void wifiCountryCode_wifiCountryCodeIsActive_wifiCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+
+        verify(mWifiManager)
+                .registerActiveCountryCodeChangedCallback(
+                        any(), mWifiCountryCodeReceiverCaptor.capture());
+        mWifiCountryCodeReceiverCaptor.getValue().onActiveCountryCodeChanged(TEST_COUNTRY_CODE_US);
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(TEST_COUNTRY_CODE_US);
+    }
+
+    @Test
+    public void wifiCountryCode_wifiCountryCodeIsInactive_defaultCountryCodeIsUsed() {
+        mThreadNetworkCountryCode.initialize();
+        verify(mWifiManager)
+                .registerActiveCountryCodeChangedCallback(
+                        any(), mWifiCountryCodeReceiverCaptor.capture());
+        mWifiCountryCodeReceiverCaptor.getValue().onActiveCountryCodeChanged(TEST_COUNTRY_CODE_US);
+
+        mWifiCountryCodeReceiverCaptor.getValue().onCountryCodeInactive();
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode())
+                .isEqualTo(ThreadNetworkCountryCode.DEFAULT_COUNTRY_CODE);
+    }
+
+    @Test
     public void updateCountryCode_noForceUpdateDefaultCountryCode_noCountryCodeIsUpdated() {
         mThreadNetworkCountryCode.initialize();
         clearInvocations(mThreadNetworkControllerService);