Merge "Automatically direct the user to the captive portal in Wi-Fi Slice" into qt-dev
diff --git a/src/com/android/settings/wifi/slice/ConnectToWifiHandler.java b/src/com/android/settings/wifi/slice/ConnectToWifiHandler.java
index ee15820..f1b0b6f 100644
--- a/src/com/android/settings/wifi/slice/ConnectToWifiHandler.java
+++ b/src/com/android/settings/wifi/slice/ConnectToWifiHandler.java
@@ -43,17 +43,23 @@
                 WifiDialogActivity.KEY_ACCESS_POINT_STATE);

         if (network != null) {
+            // The portal is being opened now, so the pending click has been handled.
+            WifiScanWorker.clearClickedWifi();
             final ConnectivityManager cm = getSystemService(ConnectivityManager.class);
             // start captive portal app to sign in to network
             cm.startCaptivePortalApp(network);
         } else if (accessPointState != null) {
             connect(new AccessPoint(this, accessPointState));
         }
+
         finish();
     }

     @VisibleForTesting
     void connect(AccessPoint accessPoint) {
+        // Remember the chosen Wi-Fi so WifiScanWorker can auto-open its captive portal.
+        WifiScanWorker.saveClickedWifi(accessPoint);
+
         final WifiConnectListener connectListener = new WifiConnectListener(this);
         switch (WifiUtils.getConnectingType(accessPoint)) {
             case WifiUtils.CONNECT_TYPE_OSU_PROVISION:
diff --git a/src/com/android/settings/wifi/slice/WifiScanWorker.java b/src/com/android/settings/wifi/slice/WifiScanWorker.java
index b846228..e438443 100644
--- a/src/com/android/settings/wifi/slice/WifiScanWorker.java
+++ b/src/com/android/settings/wifi/slice/WifiScanWorker.java
@@ -20,6 +20,7 @@
 import static com.android.settings.wifi.slice.WifiSlice.DEFAULT_EXPANDED_ROW_COUNT;
 
 import android.content.Context;
+import android.content.Intent;
 import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
 import android.net.Network;
@@ -27,9 +28,12 @@
 import android.net.NetworkInfo;
 import android.net.NetworkRequest;
 import android.net.Uri;
+import android.net.wifi.WifiInfo;
 import android.os.Bundle;
 import android.os.Handler;
 import android.os.Looper;
+import android.os.UserHandle;
+import android.text.TextUtils;
 import android.util.Log;
 
 import androidx.annotation.VisibleForTesting;
@@ -55,22 +59,23 @@
     CaptivePortalNetworkCallback mCaptivePortalNetworkCallback;

     private final Context mContext;
+    private final ConnectivityManager mConnectivityManager;
+    private final WifiTracker mWifiTracker;

-    private WifiTracker mWifiTracker;
-    private ConnectivityManager mConnectivityManager;
+    // SSID of the access point last passed to ConnectToWifiHandler#connect, or null if cleared.
+    // Static because ConnectToWifiHandler updates it without holding a worker instance.
+    private static String sClickedWifiSsid;

     public WifiScanWorker(Context context, Uri uri) {
         super(context, uri);
         mContext = context;
         mConnectivityManager = context.getSystemService(ConnectivityManager.class);
+        mWifiTracker = new WifiTracker(mContext, this /* wifiListener */,
+                true /* includeSaved */, true /* includeScans */);
     }

     @Override
     protected void onSlicePinned() {
-        if (mWifiTracker == null) {
-            mWifiTracker = new WifiTracker(mContext, this /* wifiListener */,
-                    true /* includeSaved */, true /* includeScans */);
-        }
         mWifiTracker.onStart();
         onAccessPointsChanged();
     }
@@ -79,6 +82,8 @@
     protected void onSliceUnpinned() {
         mWifiTracker.onStop();
         unregisterCaptivePortalNetworkCallback();
+        // Forget the pending click so no portal is auto-opened after the slice is unpinned.
+        clearClickedWifi();
     }

     @Override
@@ -146,6 +150,31 @@
         return null;
     }

+    /** Remembers the SSID of {@code accessPoint} as the Wi-Fi the user clicked. */
+    static void saveClickedWifi(AccessPoint accessPoint) {
+        sClickedWifiSsid = accessPoint.getSsidStr();
+    }
+
+    /** Forgets the remembered Wi-Fi, if any. */
+    static void clearClickedWifi() {
+        sClickedWifiSsid = null;
+    }
+
+    /**
+     * Returns whether {@code info} is connected to the Wi-Fi saved by
+     * {@link #saveClickedWifi(AccessPoint)}.
+     *
+     * @param info current connection info; {@code null} is treated as "not clicked"
+     * @return {@code true} if the connected SSID is non-empty and equals the saved SSID
+     */
+    static boolean isWifiClicked(WifiInfo info) {
+        if (info == null) {
+            return false;
+        }
+        final String ssid = WifiInfo.removeDoubleQuotes(info.getSSID());
+        return !TextUtils.isEmpty(ssid) && TextUtils.equals(ssid, sClickedWifiSsid);
+    }
+
     public void registerCaptivePortalNetworkCallback(Network wifiNetwork) {
         if (wifiNetwork == null) {
             return;
@@ -191,7 +208,8 @@
         @Override
         public void onCapabilitiesChanged(Network network,
                 NetworkCapabilities networkCapabilities) {
-            if (!mNetwork.equals(network)) {
+            // Ignore updates for any network other than the one this callback tracks.
+            if (!isSameNetwork(network)) {
                 return;
             }

@@ -202,6 +219,18 @@

             mIsCaptivePortal = isCaptivePortal;
             notifySliceChange();
+
+            // Automatically start captive portal, but only for the Wi-Fi the user clicked
+            if (!mIsCaptivePortal
+                    || !isWifiClicked(mWifiTracker.getManager().getConnectionInfo())) {
+                return;
+            }
+
+            final Intent portalIntent = new Intent(mContext, ConnectToWifiHandler.class)
+                    .putExtra(ConnectivityManager.EXTRA_NETWORK, network)
+                    .addFlags(Intent.FLAG_ACTIVITY_NEW_TASK);
+            // Starting activity in the system process needs to specify a user
+            mContext.startActivityAsUser(portalIntent, UserHandle.CURRENT);
         }

         /**
diff --git a/tests/robotests/src/com/android/settings/wifi/slice/ConnectToWifiHandlerTest.java b/tests/robotests/src/com/android/settings/wifi/slice/ConnectToWifiHandlerTest.java
index b18102d..cea8365 100644
--- a/tests/robotests/src/com/android/settings/wifi/slice/ConnectToWifiHandlerTest.java
+++ b/tests/robotests/src/com/android/settings/wifi/slice/ConnectToWifiHandlerTest.java
@@ -27,7 +27,6 @@
 import android.net.wifi.WifiConfiguration.NetworkSelectionStatus;
 import android.net.wifi.WifiManager;
 
-import com.android.settings.testutils.shadow.ShadowConnectivityManager;
 import com.android.settings.testutils.shadow.ShadowWifiManager;
 import com.android.settingslib.wifi.AccessPoint;
 
@@ -41,13 +40,11 @@
 import org.robolectric.annotation.Config;

 @RunWith(RobolectricTestRunner.class)
-@Config(shadows = {
-        ShadowConnectivityManager.class,
-        ShadowWifiManager.class,
-})
+@Config(shadows = ShadowWifiManager.class)
 public class ConnectToWifiHandlerTest {

-    private static final String AP1_SSID = "\"ap1\"";
+    // Double-quoted, the form WifiConfiguration#SSID uses for ASCII SSIDs.
+    private static final String AP_SSID = "\"ap\"";
     private ConnectToWifiHandler mHandler;
     private WifiConfiguration mWifiConfig;
     @Mock
@@ -59,7 +55,8 @@

         mHandler = Robolectric.setupActivity(ConnectToWifiHandler.class);
         mWifiConfig = new WifiConfiguration();
-        mWifiConfig.SSID = AP1_SSID;
+        // mAccessPoint.getConfig() is stubbed on the next line to return this config.
+        mWifiConfig.SSID = AP_SSID;
         doReturn(mWifiConfig).when(mAccessPoint).getConfig();
     }

@@ -70,7 +66,8 @@

         mHandler.connect(mAccessPoint);

-        assertThat(ShadowWifiManager.get().savedWifiConfig.SSID).isEqualTo(AP1_SSID);
+        final WifiConfiguration savedConfig = ShadowWifiManager.get().savedWifiConfig;
+        assertThat(savedConfig.SSID).isEqualTo(AP_SSID);
     }

     @Test
@@ -91,7 +87,8 @@

         mHandler.connect(mAccessPoint);

-        assertThat(ShadowWifiManager.get().savedWifiConfig.SSID).isEqualTo(AP1_SSID);
+        final WifiConfiguration savedConfig = ShadowWifiManager.get().savedWifiConfig;
+        assertThat(savedConfig.SSID).isEqualTo(AP_SSID);
     }

     @Test
@@ -104,7 +100,8 @@

         mHandler.connect(mAccessPoint);

-        assertThat(ShadowWifiManager.get().savedWifiConfig.SSID).isEqualTo(AP1_SSID);
+        final WifiConfiguration savedConfig = ShadowWifiManager.get().savedWifiConfig;
+        assertThat(savedConfig.SSID).isEqualTo(AP_SSID);
     }

     @Test
diff --git a/tests/robotests/src/com/android/settings/wifi/slice/WifiScanWorkerTest.java b/tests/robotests/src/com/android/settings/wifi/slice/WifiScanWorkerTest.java
index 30e289b..19d3e40 100644
--- a/tests/robotests/src/com/android/settings/wifi/slice/WifiScanWorkerTest.java
+++ b/tests/robotests/src/com/android/settings/wifi/slice/WifiScanWorkerTest.java
@@ -20,36 +20,54 @@
 
 import static com.google.common.truth.Truth.assertThat;
 
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 
 import android.content.ContentResolver;
 import android.content.Context;
+import android.content.Intent;
 import android.net.ConnectivityManager;
 import android.net.Network;
 import android.net.NetworkInfo;
 import android.net.NetworkInfo.State;
+import android.net.wifi.WifiInfo;
 import android.net.wifi.WifiManager;
+import android.net.wifi.WifiSsid;
 import android.os.Bundle;
+import android.os.UserHandle;
 
 import androidx.slice.SliceProvider;
 import androidx.slice.widget.SliceLiveData;
 
+import com.android.settings.testutils.shadow.ShadowWifiManager;
 import com.android.settingslib.wifi.AccessPoint;
+import com.android.settingslib.wifi.WifiTracker;
 
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.robolectric.Robolectric;
 import org.robolectric.RobolectricTestRunner;
 import org.robolectric.RuntimeEnvironment;
+import org.robolectric.annotation.Config;
+import org.robolectric.annotation.Implementation;
+import org.robolectric.annotation.Implements;
 
 @RunWith(RobolectricTestRunner.class)
+@Config(shadows = {
+        ShadowWifiManager.class,
+        WifiScanWorkerTest.ShadowWifiTracker.class,
+})
 public class WifiScanWorkerTest {
 
     private static final String AP_NAME = "ap";
@@ -59,6 +77,8 @@
     private WifiManager mWifiManager;
     private ConnectivityManager mConnectivityManager;
     private WifiScanWorker mWifiScanWorker;
+    // Its connect() records the clicked Wi-Fi via WifiScanWorker.saveClickedWifi().
+    private ConnectToWifiHandler mConnectToWifiHandler;

     @Before
     public void setUp() {
@@ -73,6 +92,13 @@

         mConnectivityManager = mContext.getSystemService(ConnectivityManager.class);
         mWifiScanWorker = new WifiScanWorker(mContext, WIFI_SLICE_URI);
+        mConnectToWifiHandler = Robolectric.setupActivity(ConnectToWifiHandler.class);
+    }
+
+    /** Resets the static clicked-Wi-Fi state so it does not leak into the next test. */
+    @After
+    public void tearDown() {
+        WifiScanWorker.clearClickedWifi();
     }

     @Test
@@ -131,4 +156,88 @@

         verify(mResolver).notifyChange(WIFI_SLICE_URI, null);
     }
+
+    /** Returns a mocked access point whose {@code getSsidStr()} returns {@code ssid}. */
+    private AccessPoint createAccessPoint(String ssid) {
+        final AccessPoint accessPoint = mock(AccessPoint.class);
+        doReturn(ssid).when(accessPoint).getSsidStr();
+        return accessPoint;
+    }
+
+    /** Sets the connection info reported by the shadow WifiManager to {@code ssid}. */
+    private void setConnectionInfoSsid(String ssid) {
+        final WifiInfo wifiInfo = new WifiInfo();
+        wifiInfo.setSSID(WifiSsid.createFromAsciiEncoded(ssid));
+        ShadowWifiManager.get().setConnectionInfo(wifiInfo);
+    }
+
+    @Test
+    public void networkCallback_onCapabilitiesChanged_isClickedWifi_shouldStartActivity() {
+        final AccessPoint accessPoint = createAccessPoint("ap1");
+        setConnectionInfoSsid("ap1");
+        final Network network = mConnectivityManager.getActiveNetwork();
+        mWifiScanWorker.registerCaptivePortalNetworkCallback(network);
+
+        mConnectToWifiHandler.connect(accessPoint);
+        mWifiScanWorker.mCaptivePortalNetworkCallback.onCapabilitiesChanged(network,
+                WifiSliceTest.makeCaptivePortalNetworkCapabilities());
+
+        verify(mContext).startActivityAsUser(any(Intent.class), eq(UserHandle.CURRENT));
+    }
+
+    @Test
+    public void networkCallback_onCapabilitiesChanged_isNotClickedWifi_shouldNotStartActivity() {
+        final AccessPoint accessPoint = createAccessPoint("ap1");
+        setConnectionInfoSsid("ap2");
+        final Network network = mConnectivityManager.getActiveNetwork();
+        mWifiScanWorker.registerCaptivePortalNetworkCallback(network);
+
+        mConnectToWifiHandler.connect(accessPoint);
+        mWifiScanWorker.mCaptivePortalNetworkCallback.onCapabilitiesChanged(network,
+                WifiSliceTest.makeCaptivePortalNetworkCapabilities());
+
+        verify(mContext, never()).startActivityAsUser(any(Intent.class), eq(UserHandle.CURRENT));
+    }
+
+    @Test
+    public void networkCallback_onCapabilitiesChanged_neverClickWifi_shouldNotStartActivity() {
+        setConnectionInfoSsid("ap1");
+        final Network network = mConnectivityManager.getActiveNetwork();
+        mWifiScanWorker.registerCaptivePortalNetworkCallback(network);
+
+        mWifiScanWorker.mCaptivePortalNetworkCallback.onCapabilitiesChanged(network,
+                WifiSliceTest.makeCaptivePortalNetworkCapabilities());
+
+        verify(mContext, never()).startActivityAsUser(any(Intent.class), eq(UserHandle.CURRENT));
+    }
+
+    @Test
+    public void networkCallback_onCapabilitiesChanged_sliceIsUnpinned_shouldNotStartActivity() {
+        final AccessPoint accessPoint = createAccessPoint("ap1");
+        setConnectionInfoSsid("ap1");
+        final Network network = mConnectivityManager.getActiveNetwork();
+        mWifiScanWorker.registerCaptivePortalNetworkCallback(network);
+        final WifiScanWorker.CaptivePortalNetworkCallback callback =
+                mWifiScanWorker.mCaptivePortalNetworkCallback;
+
+        mWifiScanWorker.onSlicePinned();
+        mConnectToWifiHandler.connect(accessPoint);
+        mWifiScanWorker.onSliceUnpinned();
+        callback.onCapabilitiesChanged(network,
+                WifiSliceTest.makeCaptivePortalNetworkCapabilities());
+
+        verify(mContext, never()).startActivityAsUser(any(Intent.class), eq(UserHandle.CURRENT));
+    }
+
+    /**
+     * Makes {@link WifiTracker#onStart()} a no-op so that
+     * {@link WifiScanWorker#onSlicePinned()} can be called in these tests.
+     */
+    @Implements(WifiTracker.class)
+    public static class ShadowWifiTracker {
+        @Implementation
+        public void onStart() {
+            // do nothing
+        }
+    }
 }