Merge changes from topic "restricted-networking-mode"

* changes:
  Add Restricted Mode Firewall Chain
  Clean Up NetworkManagementService Tests
diff --git a/tests/net/java/com/android/server/NetworkManagementServiceTest.java b/tests/net/java/com/android/server/NetworkManagementServiceTest.java
index 968b307..ea763d2 100644
--- a/tests/net/java/com/android/server/NetworkManagementServiceTest.java
+++ b/tests/net/java/com/android/server/NetworkManagementServiceTest.java
@@ -16,6 +16,12 @@
 
 package com.android.server;
 
+import static android.util.DebugUtils.valueToString;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
@@ -29,15 +35,19 @@
 import android.net.INetd;
 import android.net.INetdUnsolicitedEventListener;
 import android.net.LinkAddress;
+import android.net.NetworkPolicyManager;
 import android.os.BatteryStats;
 import android.os.Binder;
 import android.os.IBinder;
+import android.os.Process;
+import android.os.RemoteException;
 import android.test.suitebuilder.annotation.SmallTest;
+import android.util.ArrayMap;
 
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.internal.app.IBatteryStats;
-import com.android.server.NetworkManagementService.SystemServices;
+import com.android.server.NetworkManagementService.Dependencies;
 import com.android.server.net.BaseNetworkObserver;
 
 import org.junit.After;
@@ -49,13 +59,14 @@
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+import java.util.function.BiFunction;
+
 /**
  * Tests for {@link NetworkManagementService}.
  */
 @RunWith(AndroidJUnit4.class)
 @SmallTest
 public class NetworkManagementServiceTest {
-
     private NetworkManagementService mNMService;
 
     @Mock private Context mContext;
@@ -66,7 +77,9 @@
     @Captor
     private ArgumentCaptor<INetdUnsolicitedEventListener> mUnsolListenerCaptor;
 
-    private final SystemServices mServices = new SystemServices() {
+    private final MockDependencies mDeps = new MockDependencies();
+
+    private final class MockDependencies extends Dependencies {
         @Override
         public IBinder getService(String name) {
             switch (name) {
@@ -76,14 +89,21 @@
                     throw new UnsupportedOperationException("Unknown service " + name);
             }
         }
+
         @Override
         public void registerLocalService(NetworkManagementInternal nmi) {
         }
+
         @Override
         public INetd getNetd() {
             return mNetdService;
         }
-    };
+
+        @Override
+        public int getCallingUid() {
+            return Process.SYSTEM_UID;
+        }
+    }
 
     @Before
     public void setUp() throws Exception {
@@ -91,7 +111,7 @@
         doNothing().when(mNetdService)
                 .registerUnsolicitedEventListener(mUnsolListenerCaptor.capture());
         // Start the service and wait until it connects to our socket.
-        mNMService = NetworkManagementService.create(mContext, mServices);
+        mNMService = NetworkManagementService.create(mContext, mDeps);
     }
 
     @After
@@ -192,4 +212,105 @@
         // Make sure nothing else was called.
         verifyNoMoreInteractions(observer);
     }
+
+    @Test
+    public void testFirewallEnabled() {
+        // Toggle on, then off; the getter must mirror the latest setter call each time.
+        for (final boolean enabled : new boolean[] {true, false}) {
+            mNMService.setFirewallEnabled(enabled);
+            assertEquals(enabled, mNMService.isFirewallEnabled());
+        }
+    }
+
+    private static final int TEST_UID = 111; // uid that the restriction tests below act on
+
+    @Test
+    public void testNetworkRestrictedDefault() {
+        assertFalse(mNMService.isNetworkRestricted(TEST_UID)); // this test configured nothing
+    }
+
+    @Test
+    public void testMeteredNetworkRestrictions() throws RemoteException {
+        // Stub the mocked netd so that toggling data saver returns true.
+        doReturn(true).when(mNetdService).bandwidthEnableDataSaver(anyBoolean());
+
+        mNMService.setUidMeteredNetworkDenylist(TEST_UID, true);
+        final boolean whileDenylisted = mNMService.isNetworkRestricted(TEST_UID);
+        assertTrue("Should be true since mobile data usage is restricted", whileDenylisted);
+
+        mNMService.setDataSaverModeEnabled(true);
+        verify(mNetdService).bandwidthEnableDataSaver(true);
+        mNMService.setUidMeteredNetworkDenylist(TEST_UID, false);
+        final boolean whileDataSaverOn = mNMService.isNetworkRestricted(TEST_UID);
+        assertTrue("Should be true since data saver is on and the uid is not allowlisted",
+                whileDataSaverOn);
+
+        mNMService.setUidMeteredNetworkAllowlist(TEST_UID, true);
+        final boolean whileAllowlisted = mNMService.isNetworkRestricted(TEST_UID);
+        assertFalse("Should be false since data saver is on and the uid is allowlisted",
+                whileAllowlisted);
+
+        // Drop the allowlist entry and switch data saver off again.
+        mNMService.setUidMeteredNetworkAllowlist(TEST_UID, false);
+        mNMService.setDataSaverModeEnabled(false);
+        verify(mNetdService).bandwidthEnableDataSaver(false);
+        final boolean afterCleanup = mNMService.isNetworkRestricted(TEST_UID);
+        assertFalse("Network should not be restricted when data saver is off", afterCleanup);
+    }
+
+    /**
+     * Walks every firewall chain and, while that chain is enabled, checks that
+     * {@link NetworkManagementService#isNetworkRestricted} reports the expected answer for
+     * each rule that is assigned to {@link #TEST_UID}.
+     */
+    @Test
+    public void testFirewallChains() {
+        // Chains under test, in the order they are exercised.
+        final int[] chains = {
+                INetd.FIREWALL_CHAIN_STANDBY,
+                INetd.FIREWALL_CHAIN_POWERSAVE,
+                INetd.FIREWALL_CHAIN_DOZABLE,
+                INetd.FIREWALL_CHAIN_RESTRICTED
+        };
+
+        // Rules applied to TEST_UID on each chain, in this order.
+        final int[] states = {
+                INetd.FIREWALL_RULE_ALLOW,
+                INetd.FIREWALL_RULE_DENY,
+                NetworkPolicyManager.FIREWALL_RULE_DEFAULT
+        };
+
+        // Expected isNetworkRestricted() answer, keyed by chain and then by the uid's rule.
+        final ArrayMap<Integer, ArrayMap<Integer, Boolean>> expected = new ArrayMap<>();
+        for (int chain : chains) {
+            // An explicit rule gives the same answer on every chain; only the outcome for a
+            // uid left at the default rule differs, and standby alone leaves it unrestricted.
+            final boolean restrictedByDefault = chain != INetd.FIREWALL_CHAIN_STANDBY;
+            final ArrayMap<Integer, Boolean> byRule = new ArrayMap<>();
+            byRule.put(NetworkPolicyManager.FIREWALL_RULE_DEFAULT, restrictedByDefault);
+            byRule.put(INetd.FIREWALL_RULE_ALLOW, false);
+            byRule.put(INetd.FIREWALL_RULE_DENY, true);
+            expected.put(chain, byRule);
+        }
+
+        // Names the chain/rule pair in the failure message.
+        final BiFunction<Integer, Integer, String> describe = (chain, state) ->
+                String.format("Unexpected value for chain: %s and state: %s",
+                        valueToString(INetd.class, "FIREWALL_CHAIN_", chain),
+                        valueToString(INetd.class, "FIREWALL_RULE_", state));
+
+        // Each chain is switched on only for the duration of its own checks, then switched
+        // off again before the next chain is tried.
+        for (int chain : chains) {
+            final ArrayMap<Integer, Boolean> byRule = expected.get(chain);
+            mNMService.setFirewallChainEnabled(chain, true);
+            for (int state : states) {
+                mNMService.setFirewallUidRule(chain, TEST_UID, state);
+                final boolean restricted = mNMService.isNetworkRestricted(TEST_UID);
+                assertEquals(describe.apply(chain, state),
+                        byRule.get(state), restricted);
+            }
+            mNMService.setFirewallChainEnabled(chain, false);
+        }
+    }
 }