Merge "Add support for Multiple enterprise slice"
diff --git a/TEST_MAPPING b/TEST_MAPPING
index 9f2ea35..ffbd1fc 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -23,6 +23,9 @@
     },
     {
       "name": "TetheringIntegrationTests"
+    },
+    {
+      "name": "traffic_controller_unit_test"
     }
   ],
   "postsubmit": [
@@ -35,6 +38,10 @@
     },
     {
       "name": "libclat_test"
+    },
+    {
+      "name": "traffic_controller_unit_test",
+      "keywords": ["netd-device-kernel-4.9", "netd-device-kernel-4.14"]
     }
   ],
   "mainline-presubmit": [
@@ -51,6 +58,9 @@
     },
     {
       "name": "ConnectivityCoverageTests[CaptivePortalLoginGoogle.apk+NetworkStackGoogle.apk+com.google.android.resolv.apex+com.google.android.tethering.apex]"
+    },
+    {
+      "name": "traffic_controller_unit_test[CaptivePortalLoginGoogle.apk+NetworkStackGoogle.apk+com.google.android.resolv.apex+com.google.android.tethering.apex]"
     }
   ],
   "mainline-postsubmit": [
diff --git a/Tethering/apex/Android.bp b/Tethering/apex/Android.bp
index 72c83fa..611f100 100644
--- a/Tethering/apex/Android.bp
+++ b/Tethering/apex/Android.bp
@@ -51,7 +51,8 @@
         first: {
             jni_libs: [
                 "libservice-connectivity",
-                "libcom_android_connectivity_com_android_net_module_util_jni"
+                "libcom_android_connectivity_com_android_net_module_util_jni",
+                "libtraffic_controller_jni",
             ],
         },
         both: {
diff --git a/bpf_progs/Android.bp b/bpf_progs/Android.bp
index bb9f5ead6..e228df0 100644
--- a/bpf_progs/Android.bp
+++ b/bpf_progs/Android.bp
@@ -43,6 +43,7 @@
         // calls to JNI in libservices.core.
         "//frameworks/base/services/core/jni",
         "//packages/modules/Connectivity/Tethering",
+        "//packages/modules/Connectivity/service/native",
         "//packages/modules/Connectivity/tests/unit/jni",
         // TODO: remove system/netd/* when all BPF code is moved out of Netd.
         "//system/netd/libnetdbpf",
diff --git a/service/native/Android.bp b/service/native/Android.bp
new file mode 100644
index 0000000..5816318
--- /dev/null
+++ b/service/native/Android.bp
@@ -0,0 +1,102 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+cc_library {
+    name: "libtraffic_controller",
+    defaults: ["netd_defaults"],
+    srcs: [
+        "TrafficController.cpp",
+    ],
+    header_libs: [
+        "bpf_connectivity_headers",
+        "bpf_headers",
+        "bpf_syscall_wrappers",
+    ],
+    static_libs: [
+        "libnetdutils",
+        // TrafficController uses INetd constants (see include/Common.h), so it
+        // depends on netd_aidl_interface-lateststable-ndk.
+        "netd_aidl_interface-lateststable-ndk",
+    ],
+    shared_libs: [
+        // TODO: Find a good way to remove libbase.
+        "libbase",
+        "libcutils",
+        "libutils",
+        "liblog",
+    ],
+    export_include_dirs: ["include"],
+    sanitize: {
+        cfi: true,
+    },
+    apex_available: [
+        "com.android.tethering",
+    ],
+    min_sdk_version: "30",
+}
+
+cc_library_shared {
+    name: "libtraffic_controller_jni",
+    cflags: [
+        "-Wall",
+        "-Werror",
+        "-Wno-unused-parameter",
+        "-Wthread-safety",
+    ],
+    srcs: [
+        "jni/*.cpp",
+    ],
+    header_libs: [
+        "bpf_connectivity_headers",
+    ],
+    static_libs: [
+        "libnetdutils",
+        "libtraffic_controller",
+        "netd_aidl_interface-lateststable-ndk",
+    ],
+    shared_libs: [
+        "libbase",
+        "liblog",
+        "libutils",
+        "libnativehelper",
+    ],
+    apex_available: [
+        "com.android.tethering",
+    ],
+    min_sdk_version: "30",
+}
+
+cc_test {
+    name: "traffic_controller_unit_test",
+    test_suites: ["general-tests"],
+    require_root: true,
+    local_include_dirs: ["include"],
+    header_libs: [
+        "bpf_connectivity_headers",
+    ],
+    srcs: [
+        "TrafficControllerTest.cpp",
+    ],
+    static_libs: [
+        "libbase",
+        "libgmock",
+        "liblog",
+        "libnetdutils",
+        "libtraffic_controller",
+        "libutils",
+        "netd_aidl_interface-lateststable-ndk",
+    ],
+}
diff --git a/service/native/TrafficController.cpp b/service/native/TrafficController.cpp
index c24a41b..cac545d 100644
--- a/service/native/TrafficController.cpp
+++ b/service/native/TrafficController.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2022 The Android Open Source Project
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@
 #include <sys/types.h>
 #include <sys/utsname.h>
 #include <sys/wait.h>
+#include <map>
 #include <mutex>
 #include <unordered_set>
 #include <vector>
@@ -38,12 +39,12 @@
 #include <android-base/strings.h>
 #include <android-base/unique_fd.h>
 #include <netdutils/StatusOr.h>
-
 #include <netdutils/Syscalls.h>
 #include <netdutils/Utils.h>
+#include <private/android_filesystem_config.h>
+
 #include "TrafficController.h"
 #include "bpf/BpfMap.h"
-
 #include "netdutils/DumpWriter.h"
 
 namespace android {
@@ -186,18 +187,12 @@
 }
 
 Status TrafficController::start() {
-    /* When netd restarts from a crash without total system reboot, the program
-     * is still attached to the cgroup, detach it so the program can be freed
-     * and we can load and attach new program into the target cgroup.
-     *
-     * TODO: Scrape existing socket when run-time restart and clean up the map
-     * if the socket no longer exist
-     */
-
     RETURN_IF_NOT_OK(initMaps());
 
     // Fetch the list of currently-existing interfaces. At this point NetlinkHandler is
     // already running, so it will call addInterface() when any new interface appears.
+    // TODO: Clean up addInterface() once interface monitoring is in
+    // NetworkStatsService.
     std::map<std::string, uint32_t> ifacePairs;
     ASSIGN_OR_RETURN(ifacePairs, getIfaceList());
     for (const auto& ifacePair:ifacePairs) {
@@ -406,18 +401,16 @@
     return netdutils::status::ok;
 }
 
-Status TrafficController::updateUidOwnerMap(const std::vector<uint32_t>& appUids,
+Status TrafficController::updateUidOwnerMap(const uint32_t uid,
                                             UidOwnerMatchType matchType, IptOp op) {
     std::lock_guard guard(mMutex);
-    for (uint32_t uid : appUids) {
-        if (op == IptOpDelete) {
-            RETURN_IF_NOT_OK(removeRule(uid, matchType));
-        } else if (op == IptOpInsert) {
-            RETURN_IF_NOT_OK(addRule(uid, matchType));
-        } else {
-            // Cannot happen.
-            return statusFromErrno(EINVAL, StringPrintf("invalid IptOp: %d, %d", op, matchType));
-        }
+    if (op == IptOpDelete) {
+        RETURN_IF_NOT_OK(removeRule(uid, matchType));
+    } else if (op == IptOpInsert) {
+        RETURN_IF_NOT_OK(addRule(uid, matchType));
+    } else {
+        // Cannot happen.
+        return statusFromErrno(EINVAL, StringPrintf("invalid IptOp: %d, %d", op, matchType));
     }
     return netdutils::status::ok;
 }
@@ -752,7 +745,7 @@
         dw.println("mCookieTagMap print end with error: %s", res.error().message().c_str());
     }
 
-    // Print UidCounterSetMap Content
+    // Print UidCounterSetMap content.
     dumpBpfMap("mUidCounterSetMap", dw, "");
     const auto printUidInfo = [&dw](const uint32_t& key, const uint8_t& value,
                                     const BpfMap<uint32_t, uint8_t>&) {
@@ -764,7 +757,7 @@
         dw.println("mUidCounterSetMap print end with error: %s", res.error().message().c_str());
     }
 
-    // Print AppUidStatsMap content
+    // Print AppUidStatsMap content.
     std::string appUidStatsHeader = StringPrintf("uid rxBytes rxPackets txBytes txPackets");
     dumpBpfMap("mAppUidStatsMap:", dw, appUidStatsHeader);
     auto printAppUidStatsInfo = [&dw](const uint32_t& key, const StatsValue& value,
@@ -778,7 +771,7 @@
         dw.println("mAppUidStatsMap print end with error: %s", res.error().message().c_str());
     }
 
-    // Print uidStatsMap content
+    // Print uidStatsMap content.
     std::string statsHeader = StringPrintf("ifaceIndex ifaceName tag_hex uid_int cnt_set rxBytes"
                                            " rxPackets txBytes txPackets");
     dumpBpfMap("mStatsMapA", dw, statsHeader);
diff --git a/service/native/TrafficControllerTest.cpp b/service/native/TrafficControllerTest.cpp
index f5d5911..f401636 100644
--- a/service/native/TrafficControllerTest.cpp
+++ b/service/native/TrafficControllerTest.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright 2017 The Android Open Source Project
+ * Copyright 2022 The Android Open Source Project
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@
 
 #include <android-base/stringprintf.h>
 #include <android-base/strings.h>
+#include <binder/Status.h>
 
 #include <netdutils/MockSyscalls.h>
 
@@ -43,6 +44,7 @@
 namespace android {
 namespace net {
 
+using android::netdutils::Status;
 using base::Result;
 using netdutils::isOk;
 
@@ -273,6 +275,17 @@
         EXPECT_EQ((uint64_t)1, appStatsResult.value().rxPackets);
         EXPECT_EQ((uint64_t)100, appStatsResult.value().rxBytes);
     }
+
+    // Applies |op| to each uid in turn, stopping at (and returning) the first failure.
+    Status updateUidOwnerMaps(const std::vector<uint32_t>& appUids,
+                              UidOwnerMatchType matchType, TrafficController::IptOp op) {
+        Status ret(0);
+        for (auto uid : appUids) {
+            ret = mTc.updateUidOwnerMap(uid, matchType, op);
+            if (!isOk(ret)) break;
+        }
+        return ret;
+    }
 };
 
 TEST_F(TrafficControllerTest, TestSetCounterSet) {
@@ -478,66 +491,62 @@
 
 TEST_F(TrafficControllerTest, TestDenylistUidMatch) {
     std::vector<uint32_t> appUids = {1000, 1001, 10012};
-    ASSERT_TRUE(isOk(
-            mTc.updateUidOwnerMap(appUids, PENALTY_BOX_MATCH, TrafficController::IptOpInsert)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps(appUids, PENALTY_BOX_MATCH,
+                                        TrafficController::IptOpInsert)));
     expectUidOwnerMapValues(appUids, PENALTY_BOX_MATCH, 0);
-    ASSERT_TRUE(isOk(
-            mTc.updateUidOwnerMap(appUids, PENALTY_BOX_MATCH, TrafficController::IptOpDelete)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps(appUids, PENALTY_BOX_MATCH,
+                                        TrafficController::IptOpDelete)));
     expectMapEmpty(mFakeUidOwnerMap);
 }
 
 TEST_F(TrafficControllerTest, TestAllowlistUidMatch) {
     std::vector<uint32_t> appUids = {1000, 1001, 10012};
-    ASSERT_TRUE(
-            isOk(mTc.updateUidOwnerMap(appUids, HAPPY_BOX_MATCH, TrafficController::IptOpInsert)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps(appUids, HAPPY_BOX_MATCH, TrafficController::IptOpInsert)));
     expectUidOwnerMapValues(appUids, HAPPY_BOX_MATCH, 0);
-    ASSERT_TRUE(
-            isOk(mTc.updateUidOwnerMap(appUids, HAPPY_BOX_MATCH, TrafficController::IptOpDelete)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps(appUids, HAPPY_BOX_MATCH, TrafficController::IptOpDelete)));
     expectMapEmpty(mFakeUidOwnerMap);
 }
 
 TEST_F(TrafficControllerTest, TestReplaceMatchUid) {
     std::vector<uint32_t> appUids = {1000, 1001, 10012};
     // Add appUids to the denylist and expect that their values are all PENALTY_BOX_MATCH.
-    ASSERT_TRUE(isOk(
-            mTc.updateUidOwnerMap(appUids, PENALTY_BOX_MATCH, TrafficController::IptOpInsert)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps(appUids, PENALTY_BOX_MATCH,
+                                        TrafficController::IptOpInsert)));
     expectUidOwnerMapValues(appUids, PENALTY_BOX_MATCH, 0);
 
     // Add the same UIDs to the allowlist and expect that we get PENALTY_BOX_MATCH |
     // HAPPY_BOX_MATCH.
-    ASSERT_TRUE(
-            isOk(mTc.updateUidOwnerMap(appUids, HAPPY_BOX_MATCH, TrafficController::IptOpInsert)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps(appUids, HAPPY_BOX_MATCH, TrafficController::IptOpInsert)));
     expectUidOwnerMapValues(appUids, HAPPY_BOX_MATCH | PENALTY_BOX_MATCH, 0);
 
     // Remove the same UIDs from the allowlist and check the PENALTY_BOX_MATCH is still there.
-    ASSERT_TRUE(
-            isOk(mTc.updateUidOwnerMap(appUids, HAPPY_BOX_MATCH, TrafficController::IptOpDelete)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps(appUids, HAPPY_BOX_MATCH, TrafficController::IptOpDelete)));
     expectUidOwnerMapValues(appUids, PENALTY_BOX_MATCH, 0);
 
     // Remove the same UIDs from the denylist and check the map is empty.
-    ASSERT_TRUE(isOk(
-            mTc.updateUidOwnerMap(appUids, PENALTY_BOX_MATCH, TrafficController::IptOpDelete)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps(appUids, PENALTY_BOX_MATCH,
+                                        TrafficController::IptOpDelete)));
     ASSERT_FALSE(mFakeUidOwnerMap.getFirstKey().ok());
 }
 
 TEST_F(TrafficControllerTest, TestDeleteWrongMatchSilentlyFails) {
     std::vector<uint32_t> appUids = {1000, 1001, 10012};
     // If the uid does not exist in the map, trying to delete a rule about it will fail.
-    ASSERT_FALSE(
-            isOk(mTc.updateUidOwnerMap(appUids, HAPPY_BOX_MATCH, TrafficController::IptOpDelete)));
+    ASSERT_FALSE(isOk(updateUidOwnerMaps(appUids, HAPPY_BOX_MATCH,
+                                         TrafficController::IptOpDelete)));
     expectMapEmpty(mFakeUidOwnerMap);
 
     // Add denylist rules for appUids.
-    ASSERT_TRUE(
-            isOk(mTc.updateUidOwnerMap(appUids, HAPPY_BOX_MATCH, TrafficController::IptOpInsert)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps(appUids, HAPPY_BOX_MATCH,
+                                        TrafficController::IptOpInsert)));
     expectUidOwnerMapValues(appUids, HAPPY_BOX_MATCH, 0);
 
     // Delete (non-existent) denylist rules for appUids, and check that this silently does
     // nothing if the uid is in the map but does not have denylist match. This is required because
     // NetworkManagementService will try to remove a uid from denylist after adding it to the
     // allowlist and if the remove fails it will not update the uid status.
-    ASSERT_TRUE(isOk(
-            mTc.updateUidOwnerMap(appUids, PENALTY_BOX_MATCH, TrafficController::IptOpDelete)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps(appUids, PENALTY_BOX_MATCH,
+                                        TrafficController::IptOpDelete)));
     expectUidOwnerMapValues(appUids, HAPPY_BOX_MATCH, 0);
 }
 
@@ -586,8 +595,8 @@
 
 TEST_F(TrafficControllerTest, TestUidInterfaceFilteringRulesCoexistWithExistingMatches) {
     // Set up existing PENALTY_BOX_MATCH rules
-    ASSERT_TRUE(isOk(mTc.updateUidOwnerMap({1000, 1001, 10012}, PENALTY_BOX_MATCH,
-                                           TrafficController::IptOpInsert)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps({1000, 1001, 10012}, PENALTY_BOX_MATCH,
+                                        TrafficController::IptOpInsert)));
     expectUidOwnerMapValues({1000, 1001, 10012}, PENALTY_BOX_MATCH, 0);
 
     // Add some partially-overlapping uid owner rules and check result
@@ -598,8 +607,8 @@
     expectUidOwnerMapValues({10013, 10014}, IIF_MATCH, iif1);
 
     // Removing some PENALTY_BOX_MATCH rules should not change uid interface rule
-    ASSERT_TRUE(isOk(mTc.updateUidOwnerMap({1001, 10012}, PENALTY_BOX_MATCH,
-                                           TrafficController::IptOpDelete)));
+    ASSERT_TRUE(isOk(updateUidOwnerMaps({1001, 10012}, PENALTY_BOX_MATCH,
+                                        TrafficController::IptOpDelete)));
     expectUidOwnerMapValues({1000}, PENALTY_BOX_MATCH, 0);
     expectUidOwnerMapValues({10012, 10013, 10014}, IIF_MATCH, iif1);
 
diff --git a/service/native/include/Common.h b/service/native/include/Common.h
new file mode 100644
index 0000000..7c0b797
--- /dev/null
+++ b/service/native/include/Common.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+// TODO: deduplicate with the constants in NetdConstants.h.
+#include <aidl/android/net/INetd.h>
+
+using aidl::android::net::INetd;
+
+enum FirewallRule { ALLOW = INetd::FIREWALL_RULE_ALLOW, DENY = INetd::FIREWALL_RULE_DENY };
+
+// ALLOWLIST means the firewall denies all by default; uids must be explicitly ALLOWed.
+// DENYLIST means the firewall allows all by default; uids must be explicitly DENYed.
+enum FirewallType { ALLOWLIST = INetd::FIREWALL_ALLOWLIST, DENYLIST = INetd::FIREWALL_DENYLIST };
+
+// Firewall child chains, mirroring INetd::FIREWALL_CHAIN_* plus INVALID_CHAIN.
+enum ChildChain {
+    NONE = INetd::FIREWALL_CHAIN_NONE,
+    DOZABLE = INetd::FIREWALL_CHAIN_DOZABLE,
+    STANDBY = INetd::FIREWALL_CHAIN_STANDBY,
+    POWERSAVE = INetd::FIREWALL_CHAIN_POWERSAVE,
+    RESTRICTED = INetd::FIREWALL_CHAIN_RESTRICTED,
+    INVALID_CHAIN  // Has no INetd counterpart.
+};
diff --git a/service/native/include/TrafficController.h b/service/native/include/TrafficController.h
index 3e98b68..c050871 100644
--- a/service/native/include/TrafficController.h
+++ b/service/native/include/TrafficController.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2022 The Android Open Source Project
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,12 +14,11 @@
  * limitations under the License.
  */
 
-#ifndef NETD_SERVER_TRAFFIC_CONTROLLER_H
-#define NETD_SERVER_TRAFFIC_CONTROLLER_H
+#pragma once
 
-#include <linux/bpf.h>
+#include <set>
+#include <Common.h>
 
-#include "NetlinkListener.h"
 #include "android-base/thread_annotations.h"
 #include "bpf/BpfMap.h"
 #include "bpf_shared.h"
@@ -31,6 +30,8 @@
 namespace android {
 namespace net {
 
+using netdutils::StatusOr;
+
 class TrafficController {
   public:
     /*
@@ -38,9 +39,6 @@
      */
     netdutils::Status start();
 
-    /*
-     * Similiar as above, no external lock required.
-     */
     int setCounterSet(int counterSetNum, uid_t uid, uid_t callingUid) EXCLUDES(mMutex);
 
     /*
@@ -84,7 +82,7 @@
             EXCLUDES(mMutex);
     netdutils::Status removeUidInterfaceRules(const std::vector<int32_t>& uids) EXCLUDES(mMutex);
 
-    netdutils::Status updateUidOwnerMap(const std::vector<uint32_t>& appStrUids,
+    netdutils::Status updateUidOwnerMap(const uint32_t uid,
                                         UidOwnerMatchType matchType, IptOp op) EXCLUDES(mMutex);
     static const String16 DUMP_KEYWORD;
 
@@ -187,21 +185,6 @@
     netdutils::Status addRule(uint32_t uid, UidOwnerMatchType match, uint32_t iif = 0)
             REQUIRES(mMutex);
 
-    // mMutex guards all accesses to mConfigurationMap, mUidOwnerMap, mUidPermissionMap,
-    // mStatsMapA, mStatsMapB and mPrivilegedUser. It is designed to solve the following
-    // problems:
-    // 1. Prevent concurrent access and modification to mConfigurationMap, mUidOwnerMap,
-    //    mUidPermissionMap, and mPrivilegedUser. These data members are controlled by netd but can
-    //    be modified from different threads. TrafficController provides several APIs directly
-    //    called by the binder RPC, and different binder threads can concurrently access these data
-    //    members mentioned above. Some of the data members such as mUidPermissionMap and
-    //    mPrivilegedUsers are also accessed from a different thread when tagging sockets or
-    //    setting the counterSet through FwmarkServer
-    // 2. Coordinate the deletion of uid stats in mStatsMapA and mStatsMapB. The system server
-    //    always call into netd to ask for a live stats map change before it pull and clean up the
-    //    stats from the inactive map. The mMutex will block netd from accessing the stats map when
-    //    the mConfigurationMap is updating the current stats map so netd will not accidentally
-    //    read the map that system_server is cleaning up.
     std::mutex mMutex;
 
     netdutils::Status initMaps() EXCLUDES(mMutex);
@@ -218,5 +201,3 @@
 
 }  // namespace net
 }  // namespace android
-
-#endif  // NETD_SERVER_TRAFFIC_CONTROLLER_H
diff --git a/service/native/jni/com_android_server_BpfNetMaps.cpp b/service/native/jni/com_android_server_BpfNetMaps.cpp
new file mode 100644
index 0000000..7ab4d46
--- /dev/null
+++ b/service/native/jni/com_android_server_BpfNetMaps.cpp
@@ -0,0 +1,266 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "TrafficControllerJni"
+
+#include <jni.h>
+#include <nativehelper/JNIHelp.h>
+#include <nativehelper/ScopedUtfChars.h>
+#include <nativehelper/ScopedPrimitiveArray.h>
+#include <net/if.h>
+#include <vector>
+
+#include "TrafficController.h"
+#include "android-base/logging.h"
+#include "bpf_shared.h"
+#include "utils/Log.h"
+
+using android::net::TrafficController;
+using android::netdutils::Status;
+
+using UidOwnerMatchType::PENALTY_BOX_MATCH;
+using UidOwnerMatchType::HAPPY_BOX_MATCH;
+
+static android::net::TrafficController mTc;
+
+namespace android {
+
+static void native_init(JNIEnv* env, jclass clazz) {  // static native, so 2nd arg is jclass
+    const Status status = mTc.start();
+    if (!isOk(status)) {
+        ALOGE("%s failed, error code = %d", __func__, status.code());
+    }
+}
+
+static jint native_addNaughtyApp(JNIEnv* env, jobject clazz, jint uid) {
+    // Deny |uid| via PENALTY_BOX_MATCH. NOTE(review): abs() silently accepts negative uids.
+    Status status = mTc.updateUidOwnerMap(static_cast<uint32_t>(abs(uid)), PENALTY_BOX_MATCH,
+                                          TrafficController::IptOp::IptOpInsert);
+    if (!isOk(status)) {
+        ALOGE("%s failed, error code = %d", __func__, status.code());
+    }
+    return (jint)status.code();
+}
+
+static jint native_removeNaughtyApp(JNIEnv* env, jobject clazz, jint uid) {
+    // Drop PENALTY_BOX_MATCH (denylist) for |uid|. NOTE(review): abs() accepts negative uids.
+    Status status = mTc.updateUidOwnerMap(static_cast<uint32_t>(abs(uid)), PENALTY_BOX_MATCH,
+                                          TrafficController::IptOp::IptOpDelete);
+    if (!isOk(status)) {
+        ALOGE("%s failed, error code = %d", __func__, status.code());
+    }
+    return (jint)status.code();
+}
+
+static jint native_addNiceApp(JNIEnv* env, jobject clazz, jint uid) {
+    // Allow |uid| via HAPPY_BOX_MATCH. NOTE(review): abs() silently accepts negative uids.
+    Status status = mTc.updateUidOwnerMap(static_cast<uint32_t>(abs(uid)), HAPPY_BOX_MATCH,
+                                          TrafficController::IptOp::IptOpInsert);
+    if (!isOk(status)) {
+        ALOGE("%s failed, error code = %d", __func__, status.code());
+    }
+    return (jint)status.code();
+}
+
+static jint native_removeNiceApp(JNIEnv* env, jobject clazz, jint uid) {
+    // Drop HAPPY_BOX_MATCH (allowlist) for |uid|. NOTE(review): abs() accepts negative uids.
+    Status status = mTc.updateUidOwnerMap(static_cast<uint32_t>(abs(uid)), HAPPY_BOX_MATCH,
+                                          TrafficController::IptOp::IptOpDelete);
+    if (!isOk(status)) {
+        ALOGD("%s failed, error code = %d", __func__, status.code());
+    }
+    return (jint)status.code();
+}
+
+static jint native_setChildChain(JNIEnv* env, jobject clazz, jint childChain, jboolean enable) {
+    const ChildChain chain = static_cast<ChildChain>(childChain);  // INetd chain value.
+    const int res = mTc.toggleUidOwnerMap(chain, enable);
+    if (res != 0) {
+        ALOGE("%s failed, error code = %d", __func__, res);
+    }
+    return static_cast<jint>(res);
+}
+
+static jint native_replaceUidChain(JNIEnv* env, jobject clazz, jstring name, jboolean isAllowlist,
+                                   jintArray jUids) {
+    const ScopedUtfChars chainNameUtf8(env, name);
+    if (chainNameUtf8.c_str() == nullptr) {
+        return -EINVAL;
+    }
+    const std::string chainName(chainNameUtf8.c_str());
+
+    ScopedIntArrayRO uids(env, jUids);
+    if (uids.get() == nullptr) {
+        return -EINVAL;
+    }
+
+    // Copy the Java int[] as [get(), get() + size) so no element past the end is indexed.
+    std::vector<int32_t> data(uids.get(), uids.get() + uids.size());
+    int res = mTc.replaceUidOwnerMap(chainName, isAllowlist, data);
+    if (res) {
+        ALOGE("%s failed, error code = %d", __func__, res);
+    }
+    return (jint)res;
+}
+
+static FirewallType getFirewallType(ChildChain chain) {
+    switch (chain) {
+        case DOZABLE:
+        case POWERSAVE:
+        case RESTRICTED:
+            // These chains deny all by default; uids must be explicitly allowed.
+            return ALLOWLIST;
+        case STANDBY:
+            // This chain allows all by default; uids must be explicitly denied.
+            return DENYLIST;
+        case NONE:
+        default:
+            return DENYLIST;
+    }
+}
+
+static jint native_setUidRule(JNIEnv* env, jobject clazz, jint childChain, jint uid,
+                              jint firewallRule) {
+    const auto chain = static_cast<ChildChain>(childChain);
+    const auto rule = static_cast<FirewallRule>(firewallRule);
+    const FirewallType type = getFirewallType(chain);
+
+    const int res = mTc.changeUidOwnerRule(chain, uid, rule, type);
+    if (res != 0) {
+        ALOGE("%s failed, error code = %d", __func__, res);
+    }
+    return static_cast<jint>(res);
+}
+
+static jint native_addUidInterfaceRules(JNIEnv* env, jobject clazz, jstring ifName,
+                                        jintArray jUids) {
+    const ScopedUtfChars ifNameUtf8(env, ifName);
+    if (ifNameUtf8.c_str() == nullptr) {
+        return -EINVAL;
+    }
+    // NOTE(review): if_nametoindex() yields 0 for unknown names; confirm callee rejects 0.
+    const int ifIndex = if_nametoindex(ifNameUtf8.c_str());
+
+    ScopedIntArrayRO uids(env, jUids);
+    if (uids.get() == nullptr) {
+        return -EINVAL;
+    }
+
+    // Copy the Java int[] as [get(), get() + size) so no element past the end is indexed.
+    std::vector<int32_t> data(uids.get(), uids.get() + uids.size());
+    Status status = mTc.addUidInterfaceRules(ifIndex, data);
+    if (!isOk(status)) {
+        ALOGE("%s failed, error code = %d", __func__, status.code());
+    }
+    return (jint)status.code();
+}
+
+static jint native_removeUidInterfaceRules(JNIEnv* env, jobject clazz, jintArray jUids) {
+    ScopedIntArrayRO uids(env, jUids);
+    if (uids.get() == nullptr) {
+        return -EINVAL;
+    }
+
+    // Copy the Java int[] as [get(), get() + size) so no element past the end is indexed.
+    std::vector<int32_t> data(uids.get(), uids.get() + uids.size());
+    Status status = mTc.removeUidInterfaceRules(data);
+    if (!isOk(status)) {
+        ALOGE("%s failed, error code = %d", __func__, status.code());
+    }
+    return (jint)status.code();
+}
+
+static jint native_swapActiveStatsMap(JNIEnv* env, jobject clazz) {
+    // Swap the active stats map; the netdutils status code is handed back to Java.
+    const Status status = mTc.swapActiveStatsMap();
+    if (!isOk(status)) ALOGD("%s failed, error code = %d", __func__, status.code());
+
+    return static_cast<jint>(status.code());
+}
+
+static void native_setPermissionForUids(JNIEnv* env, jobject clazz, jint permission,
+                                        jintArray jUids) {
+    ScopedIntArrayRO uids(env, jUids);
+    if (uids.get() == nullptr) return;
+    static_assert(sizeof(*(uids.get())) == sizeof(uid_t));
+    // jint and uid_t have the same width (asserted above), so the elements can be read as uid_t.
+    const uid_t* const begin = reinterpret_cast<const uid_t*>(uids.get());
+    std::vector<uid_t> data(begin, begin + uids.size());
+    mTc.setPermissionForUids(permission, data);
+}
+
+static jint native_setCounterSet(JNIEnv* env, jobject clazz, jint setNum, jint uid) {
+    // NOTE(review): passes this process's uid, not the binder caller's; confirm Java checks.
+    const int res = mTc.setCounterSet(setNum, static_cast<uid_t>(uid), getuid());
+    if (res != 0) {
+        ALOGE("%s failed, error code = %d", __func__, res);
+    }
+    return static_cast<jint>(res);
+}
+
+static jint native_deleteTagData(JNIEnv* env, jobject clazz, jint tagNum, jint uid) {
+    // NOTE(review): passes this process's uid, not the binder caller's; confirm Java checks.
+    const int res = mTc.deleteTagData(tagNum, static_cast<uid_t>(uid), getuid());
+    if (res != 0) {
+        ALOGE("%s failed, error code = %d", __func__, res);
+    }
+    return static_cast<jint>(res);
+}
+
+/*
+ * JNI registration.
+ */
+// clang-format off
+static const JNINativeMethod gMethods[] = {
+    /* Java method name, JNI type signature, native function pointer */
+    {"native_init", "()V",
+    (void*)native_init},
+    {"native_addNaughtyApp", "(I)I",
+    (void*)native_addNaughtyApp},
+    {"native_removeNaughtyApp", "(I)I",
+    (void*)native_removeNaughtyApp},
+    {"native_addNiceApp", "(I)I",
+    (void*)native_addNiceApp},
+    {"native_removeNiceApp", "(I)I",
+    (void*)native_removeNiceApp},
+    {"native_setChildChain", "(IZ)I",
+    (void*)native_setChildChain},
+    {"native_replaceUidChain", "(Ljava/lang/String;Z[I)I",
+    (void*)native_replaceUidChain},
+    {"native_setUidRule", "(III)I",
+    (void*)native_setUidRule},
+    {"native_addUidInterfaceRules", "(Ljava/lang/String;[I)I",
+    (void*)native_addUidInterfaceRules},
+    {"native_removeUidInterfaceRules", "([I)I",
+    (void*)native_removeUidInterfaceRules},
+    {"native_swapActiveStatsMap", "()I",
+    (void*)native_swapActiveStatsMap},
+    {"native_setPermissionForUids", "(I[I)V",
+    (void*)native_setPermissionForUids},
+    {"native_setCounterSet", "(II)I",
+    (void*)native_setCounterSet},
+    {"native_deleteTagData", "(II)I",
+    (void*)native_deleteTagData},
+};
+// clang-format on
+
+int register_com_android_server_BpfNetMaps(JNIEnv* env) {
+    static const char* const kClassName = "com/android/server/BpfNetMaps";
+    return jniRegisterNativeMethods(env, kClassName,
+                                    gMethods, NELEM(gMethods));
+}
+
+}; // namespace android
diff --git a/service/native/jni/onload.cpp b/service/native/jni/onload.cpp
new file mode 100644
index 0000000..df7c77b
--- /dev/null
+++ b/service/native/jni/onload.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "TrafficControllerJni"
+
+#include <jni.h>
+#include <nativehelper/JNIHelp.h>
+
+#include "utils/Log.h"
+
+namespace android {
+
+int register_com_android_server_BpfNetMaps(JNIEnv* env);
+
+extern "C" jint JNI_OnLoad(JavaVM* vm, void* /* reserved */) {
+    JNIEnv* env = nullptr;
+    if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
+        ALOGE("%s: ERROR: GetEnv failed", __func__);
+        return JNI_ERR;
+    }
+    // Register the native methods of com.android.server.BpfNetMaps.
+    if (register_com_android_server_BpfNetMaps(env) < 0) {
+        return JNI_ERR;
+    }
+    return JNI_VERSION_1_6;
+}
+
+}; // namespace android
diff --git a/service/src/com/android/server/BpfNetMaps.java b/service/src/com/android/server/BpfNetMaps.java
index bc63eef..a6909c0 100644
--- a/service/src/com/android/server/BpfNetMaps.java
+++ b/service/src/com/android/server/BpfNetMaps.java
@@ -28,6 +28,13 @@
 public class BpfNetMaps {
     private static final String TAG = "BpfNetMaps";
 
+    // Runs once when this class is initialized: load libtraffic_controller_jni, whose
+    // JNI_OnLoad registers BpfNetMaps' native methods, then call native_init().
+    static {
+        System.loadLibrary("traffic_controller_jni");
+        native_init();
+    }
+
   /**
     * Add naughty app bandwidth rule for specific app
     *
@@ -239,6 +246,8 @@
         return -err;
     }
 
+    // Native-side initialization hook; invoked once from the static initializer.
+    private static native void native_init();
     private native int native_addNaughtyApp(int uid);
     private native int native_removeNaughtyApp(int uid);
     private native int native_addNiceApp(int uid);
diff --git a/tests/unit/java/android/net/NetworkIdentityTest.kt b/tests/unit/java/android/net/NetworkIdentityTest.kt
index 6ad8b06..ec0420e 100644
--- a/tests/unit/java/android/net/NetworkIdentityTest.kt
+++ b/tests/unit/java/android/net/NetworkIdentityTest.kt
@@ -17,7 +17,11 @@
 package android.net
 
 import android.content.Context
+import android.net.ConnectivityManager.MAX_NETWORK_TYPE
+import android.net.ConnectivityManager.TYPE_ETHERNET
 import android.net.ConnectivityManager.TYPE_MOBILE
+import android.net.ConnectivityManager.TYPE_NONE
+import android.net.ConnectivityManager.TYPE_WIFI
 import android.net.NetworkIdentity.OEM_NONE
 import android.net.NetworkIdentity.OEM_PAID
 import android.net.NetworkIdentity.OEM_PRIVATE
@@ -30,10 +34,12 @@
 import org.junit.runner.RunWith
 import org.mockito.Mockito.mock
 import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
 import kotlin.test.assertFalse
 import kotlin.test.assertTrue
 
 private const val TEST_IMSI = "testimsi"
+private const val TEST_WIFI_KEY = "testwifikey"
 
 @RunWith(DevSdkIgnoreRunner::class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
@@ -126,6 +132,92 @@
         assertEquals(identFromLegacyBuild, identFromSnapshot)
         assertEquals(identFromConstructor, identFromSnapshot)
 
-        // TODO: Add test cases for wifiNetworkKey and ratType.
+        // Assert non-wifi can't have wifi network key.
+        assertFailsWith<IllegalArgumentException> {
+            NetworkIdentity.Builder()
+                    .setType(TYPE_ETHERNET)
+                    .setWifiNetworkKey(TEST_WIFI_KEY)
+                    .build()
+        }
+
+        // Assert non-mobile can't have ratType.
+        assertFailsWith<IllegalArgumentException> {
+            NetworkIdentity.Builder()
+                    .setType(TYPE_WIFI)
+                    .setRatType(TelephonyManager.NETWORK_TYPE_LTE)
+                    .build()
+        }
+    }
+
+    @Test
+    fun testBuilder_type() {
+        // Assert illegal type values cannot make an identity.
+        listOf(Integer.MIN_VALUE, TYPE_NONE - 1, MAX_NETWORK_TYPE + 1, Integer.MAX_VALUE)
+                .forEach { type ->
+                    assertFailsWith<IllegalArgumentException> {
+                        NetworkIdentity.Builder().setType(type).build()
+                    }
+                }
+
+        // Verify legitimate type values can make an identity (expected value first).
+        for (type in TYPE_NONE..MAX_NETWORK_TYPE) {
+            NetworkIdentity.Builder().setType(type).build().also {
+                assertEquals(type, it.type)
+            }
+        }
+    }
+
+    @Test
+    fun testBuilder_ratType() {
+        // Assert illegal ratTypes cannot make an identity.
+        listOf(Integer.MIN_VALUE, NetworkTemplate.NETWORK_TYPE_ALL,
+                TelephonyManager.NETWORK_TYPE_UNKNOWN - 1, Integer.MAX_VALUE)
+                .forEach { ratType ->
+                    assertFailsWith<IllegalArgumentException> {
+                        NetworkIdentity.Builder()
+                                .setType(TYPE_MOBILE)
+                                .setRatType(ratType)
+                                .build()
+                    }
+                }
+
+        // Verify legitimate ratTypes (all known types plus UNKNOWN) can make an identity.
+        TelephonyManager.getAllNetworkTypes().toMutableList().also {
+            it.add(TelephonyManager.NETWORK_TYPE_UNKNOWN)
+        }.forEach { rat ->
+            NetworkIdentity.Builder()
+                    .setType(TYPE_MOBILE)
+                    .setRatType(rat)
+                    .build().also {
+                        assertEquals(rat, it.ratType)
+                    }
+        }
+    }
+
+    @Test
+    fun testBuilder_oemManaged() {
+        // Assert illegal oemManaged values cannot make an identity.
+        listOf(Integer.MIN_VALUE, NetworkTemplate.OEM_MANAGED_ALL, NetworkTemplate.OEM_MANAGED_YES,
+                Integer.MAX_VALUE)
+                .forEach { oemManaged ->
+                    assertFailsWith<IllegalArgumentException> {
+                        NetworkIdentity.Builder()
+                                .setType(TYPE_MOBILE)
+                                .setOemManaged(oemManaged)
+                                .build()
+                    }
+                }
+
+        // Verify legitimate oem managed values can make an identity.
+        listOf(NetworkTemplate.OEM_MANAGED_NO, NetworkTemplate.OEM_MANAGED_PAID,
+                NetworkTemplate.OEM_MANAGED_PRIVATE, NetworkTemplate.OEM_MANAGED_PAID or
+                NetworkTemplate.OEM_MANAGED_PRIVATE)
+                .forEach { oemManaged ->
+                    NetworkIdentity.Builder()
+                            .setOemManaged(oemManaged)
+                            .build().also {
+                                assertEquals(oemManaged, it.oemManaged)
+                            }
+                }
     }
 }
diff --git a/tests/unit/java/android/net/NetworkStatsHistoryTest.java b/tests/unit/java/android/net/NetworkStatsHistoryTest.java
index c5f8c00..c170605 100644
--- a/tests/unit/java/android/net/NetworkStatsHistoryTest.java
+++ b/tests/unit/java/android/net/NetworkStatsHistoryTest.java
@@ -56,6 +56,7 @@
 import java.io.ByteArrayOutputStream;
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
+import java.util.List;
 import java.util.Random;
 
 @RunWith(DevSdkIgnoreRunner.class)
@@ -532,6 +533,42 @@
         assertEquals(512L + 4096L, stats.getTotalBytes());
     }
 
+    // NetworkStatsHistory.Builder must keep the bucket duration it was given and expose the
+    // added entries through getEntries() in insertion order.
+    @Test
+    public void testBuilder() {
+        final NetworkStatsHistory.Entry entry1 = new NetworkStatsHistory.Entry(10, 30, 40,
+                4, 50, 5, 60);
+        final NetworkStatsHistory.Entry entry2 = new NetworkStatsHistory.Entry(30, 15, 3,
+                41, 7, 1, 0);
+        final NetworkStatsHistory.Entry entry3 = new NetworkStatsHistory.Entry(7, 301, 11,
+                14, 31, 2, 80);
+
+        final NetworkStatsHistory statsEmpty =
+                new NetworkStatsHistory.Builder(HOUR_IN_MILLIS, 10).build();
+        assertEquals(0, statsEmpty.getEntries().size());
+        assertEquals(HOUR_IN_MILLIS, statsEmpty.getBucketDuration());
+
+        final NetworkStatsHistory statsSingle =
+                new NetworkStatsHistory.Builder(HOUR_IN_MILLIS, 8)
+                .addEntry(entry1)
+                .build();
+        assertEquals(1, statsSingle.getEntries().size());
+        assertEquals(HOUR_IN_MILLIS, statsSingle.getBucketDuration());
+        assertEquals(entry1, statsSingle.getEntries().get(0));
+
+        final NetworkStatsHistory statsMultiple =
+                new NetworkStatsHistory.Builder(SECOND_IN_MILLIS, 0)
+                .addEntry(entry1).addEntry(entry2).addEntry(entry3)
+                .build();
+        final List<NetworkStatsHistory.Entry> entries = statsMultiple.getEntries();
+        assertEquals(3, entries.size());
+        assertEquals(SECOND_IN_MILLIS, statsMultiple.getBucketDuration());
+        assertEquals(entry1, entries.get(0));
+        assertEquals(entry2, entries.get(1));
+        assertEquals(entry3, entries.get(2));
+    }
+
     private static void assertIndexBeforeAfter(
             NetworkStatsHistory stats, int before, int after, long time) {
         assertEquals("unexpected before", before, stats.getIndexBefore(time));
diff --git a/tests/unit/java/android/net/NetworkTemplateTest.kt b/tests/unit/java/android/net/NetworkTemplateTest.kt
index 0c3bee3..048597f 100644
--- a/tests/unit/java/android/net/NetworkTemplateTest.kt
+++ b/tests/unit/java/android/net/NetworkTemplateTest.kt
@@ -95,7 +95,7 @@
         oemManaged: Int = OEM_NONE,
         metered: Boolean = true
     ): NetworkStateSnapshot {
-        `when`(mockWifiInfo.getCurrentNetworkKey()).thenReturn(wifiKey)
+        `when`(mockWifiInfo.getNetworkKey()).thenReturn(wifiKey)
         val lp = LinkProperties()
         val caps = NetworkCapabilities().apply {
             setCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED, !metered)
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
index 5e1699a..d7bbf50 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -266,7 +266,7 @@
         mServiceContext = new MockContext(context);
         when(mLocationPermissionChecker.checkCallersLocationPermission(
                 any(), any(), anyInt(), anyBoolean(), any())).thenReturn(true);
-        when(sWifiInfo.getCurrentNetworkKey()).thenReturn(TEST_WIFI_NETWORK_KEY);
+        when(sWifiInfo.getNetworkKey()).thenReturn(TEST_WIFI_NETWORK_KEY);
         mStatsDir = TestIoUtils.createTemporaryDirectory(getClass().getSimpleName());
 
         PowerManager powerManager = (PowerManager) mServiceContext.getSystemService(