Merge "Revert "Simplify Exception expressions in CSTest""
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index ae80d7e..a71d46c 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -2614,9 +2614,11 @@
                     final boolean valid = ((msg.arg1 & NETWORK_VALIDATION_RESULT_VALID) != 0);
                     final boolean wasValidated = nai.lastValidated;
                     final boolean wasDefault = isDefaultNetwork(nai);
-                    if (nai.everCaptivePortalDetected && !nai.captivePortalLoginNotified
-                            && valid) {
-                        nai.captivePortalLoginNotified = true;
+                    // Only show a connected notification if the network is pending validation
+                    // after the captive portal app was opened, and it has now validated.
+                    if (nai.captivePortalValidationPending && valid) {
+                        // User is now logged in, network validated.
+                        nai.captivePortalValidationPending = false;
                         showNetworkNotification(nai, NotificationType.LOGGED_IN);
                     }
 
@@ -2687,9 +2689,6 @@
                         final int oldScore = nai.getCurrentScore();
                         nai.lastCaptivePortalDetected = visible;
                         nai.everCaptivePortalDetected |= visible;
-                        if (visible) {
-                            nai.captivePortalLoginNotified = false;
-                        }
                         if (nai.lastCaptivePortalDetected &&
                             Settings.Global.CAPTIVE_PORTAL_MODE_AVOID == getCaptivePortalMode()) {
                             if (DBG) log("Avoiding captive portal network: " + nai.name());
@@ -3498,6 +3497,12 @@
                 new CaptivePortal(new CaptivePortalImpl(network).asBinder()));
         appIntent.setFlags(Intent.FLAG_ACTIVITY_BROUGHT_TO_FRONT | Intent.FLAG_ACTIVITY_NEW_TASK);
 
+        // This runs on a random binder thread, but getNetworkAgentInfoForNetwork is thread-safe,
+        // and captivePortalValidationPending is volatile.
+        final NetworkAgentInfo nai = getNetworkAgentInfoForNetwork(network);
+        if (nai != null) {
+            nai.captivePortalValidationPending = true;
+        }
         Binder.withCleanCallingIdentity(() ->
                 mContext.startActivityAsUser(appIntent, UserHandle.CURRENT));
     }
diff --git a/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java b/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java
index 864a793..5b04379 100644
--- a/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java
+++ b/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java
@@ -155,9 +155,9 @@
     // Whether a captive portal was found during the last network validation attempt.
     public boolean lastCaptivePortalDetected;
 
-    // Indicates the user was notified of a successful captive portal login since a portal was
-    // last detected.
-    public boolean captivePortalLoginNotified;
+    // Indicates the captive portal app was opened to show a login UI to the user, but the network
+    // has not validated yet.
+    public volatile boolean captivePortalValidationPending;
 
     // Set to true when partial connectivity was detected.
     public boolean partialConnectivity;
@@ -629,7 +629,7 @@
                 + "acceptUnvalidated{" + networkMisc.acceptUnvalidated + "} "
                 + "everCaptivePortalDetected{" + everCaptivePortalDetected + "} "
                 + "lastCaptivePortalDetected{" + lastCaptivePortalDetected + "} "
-                + "captivePortalLoginNotified{" + captivePortalLoginNotified + "} "
+                + "captivePortalValidationPending{" + captivePortalValidationPending + "} "
                 + "partialConnectivity{" + partialConnectivity + "} "
                 + "acceptPartialConnectivity{" + networkMisc.acceptPartialConnectivity + "} "
                 + "clat{" + clatd + "} "
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index a192945..b60dc2d 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -18,6 +18,7 @@
 
 import static android.content.pm.PackageManager.GET_PERMISSIONS;
 import static android.content.pm.PackageManager.MATCH_ANY_USER;
+import static android.net.ConnectivityManager.ACTION_CAPTIVE_PORTAL_SIGN_IN;
 import static android.net.ConnectivityManager.CONNECTIVITY_ACTION;
 import static android.net.ConnectivityManager.NETID_UNSET;
 import static android.net.ConnectivityManager.PRIVATE_DNS_MODE_OFF;
@@ -84,13 +85,13 @@
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Matchers.anyInt;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doNothing;
-import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -145,6 +146,7 @@
 import android.net.NetworkMisc;
 import android.net.NetworkRequest;
 import android.net.NetworkSpecifier;
+import android.net.NetworkStack;
 import android.net.NetworkStackClient;
 import android.net.NetworkState;
 import android.net.NetworkUtils;
@@ -159,6 +161,7 @@
 import android.net.util.MultinetworkPolicyTracker;
 import android.os.BadParcelableException;
 import android.os.Binder;
+import android.os.Bundle;
 import android.os.ConditionVariable;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -196,6 +199,7 @@
 import com.android.server.connectivity.IpConnectivityMetrics;
 import com.android.server.connectivity.MockableSystemProperties;
 import com.android.server.connectivity.Nat464Xlat;
+import com.android.server.connectivity.NetworkNotificationManager.NotificationType;
 import com.android.server.connectivity.ProxyTracker;
 import com.android.server.connectivity.Tethering;
 import com.android.server.connectivity.Vpn;
@@ -290,6 +294,7 @@
     @Mock NetworkStackClient mNetworkStack;
     @Mock PackageManager mPackageManager;
     @Mock UserManager mUserManager;
+    @Mock NotificationManager mNotificationManager;
 
     private ArgumentCaptor<ResolverParamsParcel> mResolverParamsParcelCaptor =
             ArgumentCaptor.forClass(ResolverParamsParcel.class);
@@ -359,7 +364,7 @@
         @Override
         public Object getSystemService(String name) {
             if (Context.CONNECTIVITY_SERVICE.equals(name)) return mCm;
-            if (Context.NOTIFICATION_SERVICE.equals(name)) return mock(NotificationManager.class);
+            if (Context.NOTIFICATION_SERVICE.equals(name)) return mNotificationManager;
             if (Context.NETWORK_STACK_SERVICE.equals(name)) return mNetworkStack;
             if (Context.USER_SERVICE.equals(name)) return mUserManager;
             return super.getSystemService(name);
@@ -379,7 +384,17 @@
         public PackageManager getPackageManager() {
             return mPackageManager;
         }
-   }
+
+        @Override
+        public void enforceCallingOrSelfPermission(String permission, String message) {
+            // The mainline permission can only be held if signed with the network stack
+            // certificate, so skip testing for this permission.
+            if (NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK.equals(permission)) return;
+            // All other permissions should be held by the test or unnecessary: check as normal to
+            // make sure the code does not rely on unexpected permissions.
+            super.enforceCallingOrSelfPermission(permission, message);
+        }
+    }
 
     public void waitForIdle(int timeoutMsAsInt) {
         long timeoutMs = timeoutMsAsInt;
@@ -2864,6 +2879,9 @@
 
         // Expect NET_CAPABILITY_VALIDATED onAvailable callback.
         validatedCallback.expectAvailableDoubleValidatedCallbacks(mWiFiNetworkAgent);
+        // Expect no notification to be shown when captive portal disappears by itself
+        verify(mNotificationManager, never()).notifyAsUser(
+                anyString(), eq(NotificationType.LOGGED_IN.eventId), any(), any());
 
         // Break network connectivity.
         // Expect NET_CAPABILITY_VALIDATED onLost callback.
@@ -2893,6 +2911,8 @@
         // Check that calling startCaptivePortalApp does nothing.
         final int fastTimeoutMs = 100;
         mCm.startCaptivePortalApp(wifiNetwork);
+        waitForIdle();
+        verify(mWiFiNetworkAgent.mNetworkMonitor, never()).launchCaptivePortalApp();
         mServiceContext.expectNoStartActivityIntent(fastTimeoutMs);
 
         // Turn into a captive portal.
@@ -2903,14 +2923,26 @@
 
         // Check that startCaptivePortalApp sends the expected command to NetworkMonitor.
         mCm.startCaptivePortalApp(wifiNetwork);
-        verify(mWiFiNetworkAgent.mNetworkMonitor, timeout(TIMEOUT_MS).times(1))
-                .launchCaptivePortalApp();
+        waitForIdle();
+        verify(mWiFiNetworkAgent.mNetworkMonitor).launchCaptivePortalApp();
+
+        // NetworkMonitor uses startCaptivePortalApp(Network, Bundle) (startCaptivePortalAppInternal)
+        final Bundle testBundle = new Bundle();
+        final String testKey = "testkey";
+        final String testValue = "testvalue";
+        testBundle.putString(testKey, testValue);
+        mCm.startCaptivePortalApp(wifiNetwork, testBundle);
+        final Intent signInIntent = mServiceContext.expectStartActivityIntent(TIMEOUT_MS);
+        assertEquals(ACTION_CAPTIVE_PORTAL_SIGN_IN, signInIntent.getAction());
+        assertEquals(testValue, signInIntent.getStringExtra(testKey));
 
         // Report that the captive portal is dismissed, and check that callbacks are fired
         mWiFiNetworkAgent.setNetworkValid();
         mWiFiNetworkAgent.mNetworkMonitor.forceReevaluation(Process.myUid());
         validatedCallback.expectAvailableCallbacksValidated(mWiFiNetworkAgent);
         captivePortalCallback.expectCallback(CallbackState.LOST, mWiFiNetworkAgent);
+        verify(mNotificationManager, times(1)).notifyAsUser(anyString(),
+                eq(NotificationType.LOGGED_IN.eventId), any(), eq(UserHandle.ALL));
 
         mCm.unregisterNetworkCallback(validatedCallback);
         mCm.unregisterNetworkCallback(captivePortalCallback);
diff --git a/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
index adaef40..51dcc3c 100644
--- a/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -37,6 +37,7 @@
 import static android.net.NetworkStats.SET_FOREGROUND;
 import static android.net.NetworkStats.STATS_PER_IFACE;
 import static android.net.NetworkStats.STATS_PER_UID;
+import static android.net.NetworkStats.TAG_ALL;
 import static android.net.NetworkStats.TAG_NONE;
 import static android.net.NetworkStats.UID_ALL;
 import static android.net.NetworkStatsHistory.FIELD_ALL;
@@ -154,6 +155,7 @@
     private File mStatsDir;
 
     private @Mock INetworkManagementService mNetManager;
+    private @Mock NetworkStatsFactory mStatsFactory;
     private @Mock NetworkStatsSettings mSettings;
     private @Mock IBinder mBinder;
     private @Mock AlarmManager mAlarmManager;
@@ -189,8 +191,8 @@
 
         mService = new NetworkStatsService(
                 mServiceContext, mNetManager, mAlarmManager, wakeLock, mClock,
-                TelephonyManager.getDefault(), mSettings, new NetworkStatsObservers(),
-                mStatsDir, getBaseDir(mStatsDir));
+                TelephonyManager.getDefault(), mSettings, mStatsFactory,
+                new NetworkStatsObservers(),  mStatsDir, getBaseDir(mStatsDir));
         mHandlerThread = new HandlerThread("HandlerThread");
         mHandlerThread.start();
         Handler.Callback callback = new NetworkStatsService.HandlerCallback(mService);
@@ -205,7 +207,7 @@
 
         mService.systemReady();
         // Verify that system ready fetches realtime stats
-        verify(mNetManager).getNetworkStatsUidDetail(UID_ALL, INTERFACES_ALL);
+        verify(mStatsFactory).readNetworkStatsDetail(UID_ALL, INTERFACES_ALL, TAG_ALL);
 
         mSession = mService.openSession();
         assertNotNull("openSession() failed", mSession);
@@ -696,7 +698,7 @@
         incrementCurrentTime(HOUR_IN_MILLIS);
         expectDefaultSettings();
         expectNetworkStatsSummary(buildEmptyStats());
-        when(mNetManager.getNetworkStatsUidDetail(eq(UID_ALL), any()))
+        when(mStatsFactory.readNetworkStatsDetail(eq(UID_ALL), any(), eq(TAG_ALL)))
                 .thenReturn(new NetworkStats(getElapsedRealtime(), 1)
                         .addValues(uidStats));
         when(mNetManager.getNetworkStatsTethering(STATS_PER_UID))
@@ -712,9 +714,10 @@
         //
         // Additionally, we should have one call from the above call to mService#getDetailedUidStats
         // with the augmented ifaceFilter
-        verify(mNetManager, times(2)).getNetworkStatsUidDetail(UID_ALL, INTERFACES_ALL);
-        verify(mNetManager, times(1)).getNetworkStatsUidDetail(
-                eq(UID_ALL), eq(NetworkStatsFactory.augmentWithStackedInterfaces(ifaceFilter)));
+        verify(mStatsFactory, times(2)).readNetworkStatsDetail(UID_ALL, INTERFACES_ALL, TAG_ALL);
+        verify(mStatsFactory, times(1)).readNetworkStatsDetail(
+                eq(UID_ALL), eq(NetworkStatsFactory.augmentWithStackedInterfaces(ifaceFilter)),
+                eq(TAG_ALL));
         assertTrue(ArrayUtils.contains(stats.getUniqueIfaces(), TEST_IFACE));
         assertTrue(ArrayUtils.contains(stats.getUniqueIfaces(), stackedIface));
         assertEquals(2, stats.size());
@@ -1062,11 +1065,11 @@
     }
 
     private void expectNetworkStatsSummaryDev(NetworkStats summary) throws Exception {
-        when(mNetManager.getNetworkStatsSummaryDev()).thenReturn(summary);
+        when(mStatsFactory.readNetworkStatsSummaryDev()).thenReturn(summary);
     }
 
     private void expectNetworkStatsSummaryXt(NetworkStats summary) throws Exception {
-        when(mNetManager.getNetworkStatsSummaryXt()).thenReturn(summary);
+        when(mStatsFactory.readNetworkStatsSummaryXt()).thenReturn(summary);
     }
 
     private void expectNetworkStatsTethering(int how, NetworkStats stats)
@@ -1080,7 +1083,8 @@
 
     private void expectNetworkStatsUidDetail(NetworkStats detail, NetworkStats tetherStats)
             throws Exception {
-        when(mNetManager.getNetworkStatsUidDetail(UID_ALL, INTERFACES_ALL)).thenReturn(detail);
+        when(mStatsFactory.readNetworkStatsDetail(UID_ALL, INTERFACES_ALL, TAG_ALL))
+                .thenReturn(detail);
 
         // also include tethering details, since they are folded into UID
         when(mNetManager.getNetworkStatsTethering(STATS_PER_UID)).thenReturn(tetherStats);