Use java BpfMap in BpfNetMaps#replaceUidChain

Bug: 217624062
Test: atest BpfNetMapsTest
Change-Id: Ib2a2c2646834110a3eeeb786a4ea7a3f85718be8
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index 6ccd77e..28f0699 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -5949,7 +5949,7 @@
      *
      * @param chain target chain to replace.
      * @param uids The list of UIDs to be placed into chain.
-     * @throws IllegalStateException if replacing the firewall chain failed.
+     * @throws UnsupportedOperationException if called on pre-T devices.
      * @throws IllegalArgumentException if {@code chain} is not a valid chain.
      * @hide
      */
diff --git a/service/src/com/android/server/BpfNetMaps.java b/service/src/com/android/server/BpfNetMaps.java
index d7c5a06..0ff8810 100644
--- a/service/src/com/android/server/BpfNetMaps.java
+++ b/service/src/com/android/server/BpfNetMaps.java
@@ -46,6 +46,10 @@
 
 import java.io.FileDescriptor;
 import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
 
 /**
  * BpfNetMaps is responsible for providing traffic controller relevant functionality.
@@ -404,25 +408,46 @@
 
     /**
      * Replaces the contents of the specified UID-based firewall chain.
+     * Enables the chain for specified uids and disables the chain for non-specified uids.
      *
-     * The chain may be an allowlist chain or a denylist chain. A denylist chain contains DROP
-     * rules for the specified UIDs and a RETURN rule at the end. An allowlist chain contains RETURN
-     * rules for the system UID range (0 to {@code UID_APP} - 1), RETURN rules for the specified
-     * UIDs, and a DROP rule at the end. The chain will be created if it does not exist.
-     *
-     * @param chainName   The name of the chain to replace.
-     * @param isAllowlist Whether this is an allowlist or denylist chain.
+     * @param chain       Target chain.
      * @param uids        The list of UIDs to allow/deny.
-     * @return 0 if the chain was successfully replaced, errno otherwise.
+     * @throws UnsupportedOperationException if called on pre-T devices.
+     * @throws IllegalArgumentException if {@code chain} is not a valid chain.
      */
-    public int replaceUidChain(final String chainName, final boolean isAllowlist,
-            final int[] uids) {
-        synchronized (sUidOwnerMap) {
-            final int err = native_replaceUidChain(chainName, isAllowlist, uids);
-            if (err != 0) {
-                Log.e(TAG, "replaceUidChain failed: " + Os.strerror(-err));
+    public void replaceUidChain(final int chain, final int[] uids) {
+        throwIfPreT("replaceUidChain is not available on pre-T devices");
+
+        final long match;
+        try {
+            match = getMatchByFirewallChain(chain);
+        } catch (ServiceSpecificException e) {
+            // Throws IllegalArgumentException to keep the behavior of
+            // ConnectivityManager#replaceFirewallChain API
+            throw new IllegalArgumentException("Invalid firewall chain: " + chain);
+        }
+        final Set<Integer> uidSet = Arrays.stream(uids).boxed().collect(Collectors.toSet());
+        final Set<Integer> uidSetToRemoveRule = new HashSet<>();
+        try {
+            synchronized (sUidOwnerMap) {
+                sUidOwnerMap.forEach((uid, config) -> {
+                    // config could be null if there is a concurrent entry deletion.
+                    // http://b/220084230.
+                    if (config != null
+                            && !uidSet.contains((int) uid.val) && (config.rule & match) != 0) {
+                        uidSetToRemoveRule.add((int) uid.val);
+                    }
+                });
+
+                for (final int uid : uidSetToRemoveRule) {
+                    removeRule(uid, match, "replaceUidChain");
+                }
+                for (final int uid : uids) {
+                    addRule(uid, match, "replaceUidChain");
+                }
             }
-            return -err;
+        } catch (ErrnoException | ServiceSpecificException e) {
+            Log.e(TAG, "replaceUidChain failed: " + e);
         }
     }
 
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 7050b42..ae1f808 100644
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -11387,39 +11387,6 @@
     public void replaceFirewallChain(final int chain, final int[] uids) {
         enforceNetworkStackOrSettingsPermission();
 
-        try {
-            switch (chain) {
-                case ConnectivityManager.FIREWALL_CHAIN_DOZABLE:
-                    mBpfNetMaps.replaceUidChain("fw_dozable", true /* isAllowList */, uids);
-                    break;
-                case ConnectivityManager.FIREWALL_CHAIN_STANDBY:
-                    mBpfNetMaps.replaceUidChain("fw_standby", false /* isAllowList */, uids);
-                    break;
-                case ConnectivityManager.FIREWALL_CHAIN_POWERSAVE:
-                    mBpfNetMaps.replaceUidChain("fw_powersave", true /* isAllowList */, uids);
-                    break;
-                case ConnectivityManager.FIREWALL_CHAIN_RESTRICTED:
-                    mBpfNetMaps.replaceUidChain("fw_restricted", true /* isAllowList */, uids);
-                    break;
-                case ConnectivityManager.FIREWALL_CHAIN_LOW_POWER_STANDBY:
-                    mBpfNetMaps.replaceUidChain("fw_low_power_standby", true /* isAllowList */,
-                            uids);
-                    break;
-                case ConnectivityManager.FIREWALL_CHAIN_OEM_DENY_1:
-                    mBpfNetMaps.replaceUidChain("fw_oem_deny_1", false /* isAllowList */, uids);
-                    break;
-                case ConnectivityManager.FIREWALL_CHAIN_OEM_DENY_2:
-                    mBpfNetMaps.replaceUidChain("fw_oem_deny_2", false /* isAllowList */, uids);
-                    break;
-                case ConnectivityManager.FIREWALL_CHAIN_OEM_DENY_3:
-                    mBpfNetMaps.replaceUidChain("fw_oem_deny_3", false /* isAllowList */, uids);
-                    break;
-                default:
-                    throw new IllegalArgumentException("replaceFirewallChain with invalid chain: "
-                            + chain);
-            }
-        } catch (ServiceSpecificException e) {
-            throw new IllegalStateException(e);
-        }
+        mBpfNetMaps.replaceUidChain(chain, uids);
     }
 }
diff --git a/tests/unit/java/com/android/server/BpfNetMapsTest.java b/tests/unit/java/com/android/server/BpfNetMapsTest.java
index 0718952..61d9eea 100644
--- a/tests/unit/java/com/android/server/BpfNetMapsTest.java
+++ b/tests/unit/java/com/android/server/BpfNetMapsTest.java
@@ -649,4 +649,80 @@
         assertThrows(UnsupportedOperationException.class, () ->
                 mBpfNetMaps.setUidRule(FIREWALL_CHAIN_DOZABLE, TEST_UID, FIREWALL_RULE_ALLOW));
     }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.S_V2)
+    public void testReplaceUidChain() throws Exception {
+        final int uid0 = TEST_UIDS[0];
+        final int uid1 = TEST_UIDS[1];
+
+        mBpfNetMaps.replaceUidChain(FIREWALL_CHAIN_DOZABLE, TEST_UIDS);
+
+        checkUidOwnerValue(uid0, NO_IIF, DOZABLE_MATCH);
+        checkUidOwnerValue(uid1, NO_IIF, DOZABLE_MATCH);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.S_V2)
+    public void testReplaceUidChainWithOtherMatch() throws Exception {
+        final int uid0 = TEST_UIDS[0];
+        final int uid1 = TEST_UIDS[1];
+        final long match0 = POWERSAVE_MATCH;
+        final long match1 = POWERSAVE_MATCH | RESTRICTED_MATCH;
+        mUidOwnerMap.updateEntry(new U32(uid0), new UidOwnerValue(NO_IIF, match0));
+        mUidOwnerMap.updateEntry(new U32(uid1), new UidOwnerValue(NO_IIF, match1));
+
+        mBpfNetMaps.replaceUidChain(FIREWALL_CHAIN_DOZABLE, new int[]{uid1});
+
+        checkUidOwnerValue(uid0, NO_IIF, match0);
+        checkUidOwnerValue(uid1, NO_IIF, match1 | DOZABLE_MATCH);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.S_V2)
+    public void testReplaceUidChainWithExistingIifMatch() throws Exception {
+        final int uid0 = TEST_UIDS[0];
+        final int uid1 = TEST_UIDS[1];
+        final long match0 = IIF_MATCH;
+        final long match1 = IIF_MATCH | POWERSAVE_MATCH | RESTRICTED_MATCH;
+        mUidOwnerMap.updateEntry(new U32(uid0), new UidOwnerValue(TEST_IF_INDEX, match0));
+        mUidOwnerMap.updateEntry(new U32(uid1), new UidOwnerValue(NULL_IIF, match1));
+
+        mBpfNetMaps.replaceUidChain(FIREWALL_CHAIN_DOZABLE, TEST_UIDS);
+
+        checkUidOwnerValue(uid0, TEST_IF_INDEX, match0 | DOZABLE_MATCH);
+        checkUidOwnerValue(uid1, NULL_IIF, match1 | DOZABLE_MATCH);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.S_V2)
+    public void testReplaceUidChainRemoveExistingMatch() throws Exception {
+        final int uid0 = TEST_UIDS[0];
+        final int uid1 = TEST_UIDS[1];
+        final long match0 = IIF_MATCH | DOZABLE_MATCH;
+        final long match1 = IIF_MATCH | POWERSAVE_MATCH | RESTRICTED_MATCH;
+        mUidOwnerMap.updateEntry(new U32(uid0), new UidOwnerValue(TEST_IF_INDEX, match0));
+        mUidOwnerMap.updateEntry(new U32(uid1), new UidOwnerValue(NULL_IIF, match1));
+
+        mBpfNetMaps.replaceUidChain(FIREWALL_CHAIN_DOZABLE, new int[]{uid1});
+
+        checkUidOwnerValue(uid0, TEST_IF_INDEX, match0 & ~DOZABLE_MATCH);
+        checkUidOwnerValue(uid1, NULL_IIF, match1 | DOZABLE_MATCH);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.S_V2)
+    public void testReplaceUidChainInvalidChain() {
+        final Class<IllegalArgumentException> expected = IllegalArgumentException.class;
+        assertThrows(expected, () -> mBpfNetMaps.replaceUidChain(-1 /* chain */, TEST_UIDS));
+        assertThrows(expected, () -> mBpfNetMaps.replaceUidChain(1000 /* chain */, TEST_UIDS));
+    }
+
+    @Test
+    @IgnoreAfter(Build.VERSION_CODES.S_V2)
+    public void testReplaceUidChainBeforeT() {
+        assertThrows(UnsupportedOperationException.class,
+                () -> mBpfNetMaps.replaceUidChain(FIREWALL_CHAIN_DOZABLE, TEST_UIDS));
+    }
+
 }
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 0919dfc..3264a36 100644
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -9592,24 +9592,23 @@
         }
     }
 
-    private void doTestReplaceFirewallChain(final int chain, final String chainName,
-            final boolean allowList) {
+    private void doTestReplaceFirewallChain(final int chain) {
         final int[] uids = new int[] {1001, 1002};
         mCm.replaceFirewallChain(chain, uids);
-        verify(mBpfNetMaps).replaceUidChain(chainName, allowList, uids);
+        verify(mBpfNetMaps).replaceUidChain(chain, uids);
         reset(mBpfNetMaps);
     }
 
     @Test @IgnoreUpTo(SC_V2)
     public void testReplaceFirewallChain() {
-        doTestReplaceFirewallChain(FIREWALL_CHAIN_DOZABLE, "fw_dozable", true);
-        doTestReplaceFirewallChain(FIREWALL_CHAIN_STANDBY, "fw_standby", false);
-        doTestReplaceFirewallChain(FIREWALL_CHAIN_POWERSAVE, "fw_powersave",  true);
-        doTestReplaceFirewallChain(FIREWALL_CHAIN_RESTRICTED, "fw_restricted", true);
-        doTestReplaceFirewallChain(FIREWALL_CHAIN_LOW_POWER_STANDBY, "fw_low_power_standby", true);
-        doTestReplaceFirewallChain(FIREWALL_CHAIN_OEM_DENY_1, "fw_oem_deny_1", false);
-        doTestReplaceFirewallChain(FIREWALL_CHAIN_OEM_DENY_2, "fw_oem_deny_2", false);
-        doTestReplaceFirewallChain(FIREWALL_CHAIN_OEM_DENY_3, "fw_oem_deny_3", false);
+        doTestReplaceFirewallChain(FIREWALL_CHAIN_DOZABLE);
+        doTestReplaceFirewallChain(FIREWALL_CHAIN_STANDBY);
+        doTestReplaceFirewallChain(FIREWALL_CHAIN_POWERSAVE);
+        doTestReplaceFirewallChain(FIREWALL_CHAIN_RESTRICTED);
+        doTestReplaceFirewallChain(FIREWALL_CHAIN_LOW_POWER_STANDBY);
+        doTestReplaceFirewallChain(FIREWALL_CHAIN_OEM_DENY_1);
+        doTestReplaceFirewallChain(FIREWALL_CHAIN_OEM_DENY_2);
+        doTestReplaceFirewallChain(FIREWALL_CHAIN_OEM_DENY_3);
     }
 
     @Test @IgnoreUpTo(SC_V2)
@@ -9620,8 +9619,6 @@
                 () -> mCm.setUidFirewallRule(-1 /* chain */, uid, FIREWALL_RULE_ALLOW));
         assertThrows(expected,
                 () -> mCm.setUidFirewallRule(100 /* chain */, uid, FIREWALL_RULE_ALLOW));
-        assertThrows(expected, () -> mCm.replaceFirewallChain(-1 /* chain */, new int[]{uid}));
-        assertThrows(expected, () -> mCm.replaceFirewallChain(100 /* chain */, new int[]{uid}));
     }
 
     @Test @IgnoreUpTo(SC_V2)