Merge changes I3286aeb1,I060509de into main
* changes:
Add Wi-Fi P2P mdns tests
Setup Wi-Fi P2P connection
diff --git a/Tethering/Android.bp b/Tethering/Android.bp
index b4426a6..d04660d 100644
--- a/Tethering/Android.bp
+++ b/Tethering/Android.bp
@@ -69,7 +69,7 @@
"android.hardware.tetheroffload.control-V1.0-java",
"android.hardware.tetheroffload.control-V1.1-java",
"android.hidl.manager-V1.2-java",
- "net-utils-tethering",
+ "net-utils-connectivity-apks",
"netd-client",
"tetheringstatsprotos",
],
diff --git a/Tethering/src/com/android/networkstack/tethering/metrics/TetheringMetrics.java b/Tethering/src/com/android/networkstack/tethering/metrics/TetheringMetrics.java
index 136dfb1..ac4d8b1 100644
--- a/Tethering/src/com/android/networkstack/tethering/metrics/TetheringMetrics.java
+++ b/Tethering/src/com/android/networkstack/tethering/metrics/TetheringMetrics.java
@@ -22,6 +22,12 @@
import static android.net.NetworkCapabilities.TRANSPORT_LOWPAN;
import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
import static android.net.NetworkCapabilities.TRANSPORT_WIFI_AWARE;
+import static android.net.NetworkStats.DEFAULT_NETWORK_YES;
+import static android.net.NetworkStats.METERED_YES;
+import static android.net.NetworkTemplate.MATCH_BLUETOOTH;
+import static android.net.NetworkTemplate.MATCH_ETHERNET;
+import static android.net.NetworkTemplate.MATCH_MOBILE;
+import static android.net.NetworkTemplate.MATCH_WIFI;
import static android.net.TetheringManager.TETHERING_BLUETOOTH;
import static android.net.TetheringManager.TETHERING_ETHERNET;
import static android.net.TetheringManager.TETHERING_NCM;
@@ -48,6 +54,7 @@
import android.annotation.Nullable;
import android.content.Context;
import android.net.NetworkCapabilities;
+import android.net.NetworkTemplate;
import android.stats.connectivity.DownstreamType;
import android.stats.connectivity.ErrorCode;
import android.stats.connectivity.UpstreamType;
@@ -419,4 +426,46 @@
}
return false;
}
+
+ /**
+ * Build NetworkTemplate for the given upstream type.
+ *
+ * <p> NetworkTemplate.Builder API was introduced in Android T.
+ *
+ * @param type the upstream type
+ * @return A NetworkTemplate object with a corresponding match rule or null if tethering
+ * metrics' data usage cannot be collected for a given upstream type.
+ */
+ @Nullable
+ public static NetworkTemplate buildNetworkTemplateForUpstreamType(@NonNull UpstreamType type) {
+ if (!isUsageSupportedForUpstreamType(type)) return null;
+
+ switch (type) {
+ case UT_CELLULAR:
+ // TODO: Handle the DUN connection, which is not a default network.
+ return new NetworkTemplate.Builder(MATCH_MOBILE)
+ .setMeteredness(METERED_YES)
+ .setDefaultNetworkStatus(DEFAULT_NETWORK_YES)
+ .build();
+ case UT_WIFI:
+ return new NetworkTemplate.Builder(MATCH_WIFI)
+ .setMeteredness(METERED_YES)
+ .setDefaultNetworkStatus(DEFAULT_NETWORK_YES)
+ .build();
+ case UT_BLUETOOTH:
+ return new NetworkTemplate.Builder(MATCH_BLUETOOTH)
+ .setMeteredness(METERED_YES)
+ .setDefaultNetworkStatus(DEFAULT_NETWORK_YES)
+ .build();
+ case UT_ETHERNET:
+ return new NetworkTemplate.Builder(MATCH_ETHERNET)
+ .setMeteredness(METERED_YES)
+ .setDefaultNetworkStatus(DEFAULT_NETWORK_YES)
+ .build();
+ default:
+ Log.e(TAG, "Unsupported UpstreamType: " + type.name());
+ break;
+ }
+ return null;
+ }
}
diff --git a/Tethering/src/com/android/networkstack/tethering/util/VersionedBroadcastListener.java b/Tethering/src/com/android/networkstack/tethering/util/VersionedBroadcastListener.java
index c9e75c0..5eb1551 100644
--- a/Tethering/src/com/android/networkstack/tethering/util/VersionedBroadcastListener.java
+++ b/Tethering/src/com/android/networkstack/tethering/util/VersionedBroadcastListener.java
@@ -62,8 +62,8 @@
if (DBG) Log.d(mTag, "startListening");
if (mReceiver != null) return;
- mReceiver = new Receiver(mTag, mGenerationNumber, mCallback);
- mContext.registerReceiver(mReceiver, mFilter, null, mHandler);
+ mReceiver = new Receiver(mTag, mGenerationNumber, mCallback, mHandler);
+ mContext.registerReceiver(mReceiver, mFilter);
}
/** Stop listening to intent broadcast. */
@@ -77,30 +77,35 @@
}
private static class Receiver extends BroadcastReceiver {
- public final String tag;
- public final AtomicInteger atomicGenerationNumber;
- public final Consumer<Intent> callback;
+ final String mTag;
+ final AtomicInteger mAtomicGenerationNumber;
+ final Consumer<Intent> mCallback;
// Used to verify this receiver is still current.
- public final int generationNumber;
+ final int mGenerationNumber;
+ private final Handler mHandler;
- Receiver(String tag, AtomicInteger atomicGenerationNumber, Consumer<Intent> callback) {
- this.tag = tag;
- this.atomicGenerationNumber = atomicGenerationNumber;
- this.callback = callback;
- generationNumber = atomicGenerationNumber.incrementAndGet();
+ Receiver(String tag, AtomicInteger atomicGenerationNumber, Consumer<Intent> callback,
+ Handler handler) {
+ mTag = tag;
+ mAtomicGenerationNumber = atomicGenerationNumber;
+ mCallback = callback;
+ mGenerationNumber = atomicGenerationNumber.incrementAndGet();
+ mHandler = handler;
}
@Override
public void onReceive(Context context, Intent intent) {
- final int currentGenerationNumber = atomicGenerationNumber.get();
+ mHandler.post(() -> {
+ final int currentGenerationNumber = mAtomicGenerationNumber.get();
- if (DBG) {
- Log.d(tag, "receiver generationNumber=" + generationNumber
- + ", current generationNumber=" + currentGenerationNumber);
- }
- if (generationNumber != currentGenerationNumber) return;
+ if (DBG) {
+ Log.d(mTag, "receiver generationNumber=" + mGenerationNumber
+ + ", current generationNumber=" + currentGenerationNumber);
+ }
+ if (mGenerationNumber != currentGenerationNumber) return;
- callback.accept(intent);
+ mCallback.accept(intent);
+ });
}
}
}
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/metrics/TetheringMetricsTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/metrics/TetheringMetricsTest.java
index 7cef9cb..fbc2893 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/metrics/TetheringMetricsTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/metrics/TetheringMetricsTest.java
@@ -22,6 +22,10 @@
import static android.net.NetworkCapabilities.TRANSPORT_LOWPAN;
import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
import static android.net.NetworkCapabilities.TRANSPORT_WIFI_AWARE;
+import static android.net.NetworkTemplate.MATCH_BLUETOOTH;
+import static android.net.NetworkTemplate.MATCH_ETHERNET;
+import static android.net.NetworkTemplate.MATCH_MOBILE;
+import static android.net.NetworkTemplate.MATCH_WIFI;
import static android.net.TetheringManager.TETHERING_BLUETOOTH;
import static android.net.TetheringManager.TETHERING_ETHERNET;
import static android.net.TetheringManager.TETHERING_NCM;
@@ -47,11 +51,14 @@
import static android.net.TetheringManager.TETHER_ERROR_UNTETHER_IFACE_ERROR;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.verify;
import android.content.Context;
import android.net.NetworkCapabilities;
+import android.net.NetworkTemplate;
+import android.os.Build;
import android.stats.connectivity.DownstreamType;
import android.stats.connectivity.ErrorCode;
import android.stats.connectivity.UpstreamType;
@@ -62,8 +69,11 @@
import com.android.networkstack.tethering.UpstreamNetworkState;
import com.android.networkstack.tethering.metrics.TetheringMetrics.Dependencies;
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
@@ -72,12 +82,15 @@
@RunWith(AndroidJUnit4.class)
@SmallTest
public final class TetheringMetricsTest {
+ @Rule public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
+
private static final String TEST_CALLER_PKG = "com.test.caller.pkg";
private static final String SETTINGS_PKG = "com.android.settings";
private static final String SYSTEMUI_PKG = "com.android.systemui";
private static final String GMS_PKG = "com.google.android.gms";
private static final long TEST_START_TIME = 1670395936033L;
private static final long SECOND_IN_MILLIS = 1_000L;
+ private static final int MATCH_NONE = -1;
@Mock private Context mContext;
@Mock private Dependencies mDeps;
@@ -392,4 +405,27 @@
runUsageSupportedForUpstreamTypeTest(UpstreamType.UT_LOWPAN, false /* isSupported */);
runUsageSupportedForUpstreamTypeTest(UpstreamType.UT_UNKNOWN, false /* isSupported */);
}
+
+ private void runBuildNetworkTemplateForUpstreamType(final UpstreamType upstreamType,
+ final int matchRule) {
+ final NetworkTemplate template =
+ TetheringMetrics.buildNetworkTemplateForUpstreamType(upstreamType);
+ if (matchRule == MATCH_NONE) {
+ assertNull(template);
+ } else {
+ assertEquals(matchRule, template.getMatchRule());
+ }
+ }
+
+ @Test
+ @IgnoreUpTo(Build.VERSION_CODES.S_V2)
+ public void testBuildNetworkTemplateForUpstreamType() {
+ runBuildNetworkTemplateForUpstreamType(UpstreamType.UT_CELLULAR, MATCH_MOBILE);
+ runBuildNetworkTemplateForUpstreamType(UpstreamType.UT_WIFI, MATCH_WIFI);
+ runBuildNetworkTemplateForUpstreamType(UpstreamType.UT_BLUETOOTH, MATCH_BLUETOOTH);
+ runBuildNetworkTemplateForUpstreamType(UpstreamType.UT_ETHERNET, MATCH_ETHERNET);
+ runBuildNetworkTemplateForUpstreamType(UpstreamType.UT_WIFI_AWARE, MATCH_NONE);
+ runBuildNetworkTemplateForUpstreamType(UpstreamType.UT_LOWPAN, MATCH_NONE);
+ runBuildNetworkTemplateForUpstreamType(UpstreamType.UT_UNKNOWN, MATCH_NONE);
+ }
}
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/util/VersionedBroadcastListenerTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/util/VersionedBroadcastListenerTest.java
index b7dc66e..ed4f3da 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/util/VersionedBroadcastListenerTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/util/VersionedBroadcastListenerTest.java
@@ -16,6 +16,8 @@
package com.android.networkstack.tethering.util;
+import static com.android.testutils.HandlerUtils.waitForIdle;
+
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.reset;
@@ -23,7 +25,7 @@
import android.content.Intent;
import android.content.IntentFilter;
import android.os.Handler;
-import android.os.Looper;
+import android.os.HandlerThread;
import android.os.UserHandle;
import androidx.test.filters.SmallTest;
@@ -33,7 +35,6 @@
import org.junit.After;
import org.junit.Before;
-import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
@@ -44,9 +45,11 @@
public class VersionedBroadcastListenerTest {
private static final String TAG = VersionedBroadcastListenerTest.class.getSimpleName();
private static final String ACTION_TEST = "action.test.happy.broadcasts";
+ private static final long TEST_TIMEOUT_MS = 10_000L;
@Mock private Context mContext;
private BroadcastInterceptingContext mServiceContext;
+ private HandlerThread mHandlerThread;
private Handler mHandler;
private VersionedBroadcastListener mListener;
private int mCallbackCount;
@@ -61,18 +64,13 @@
}
}
- @BeforeClass
- public static void setUpBeforeClass() throws Exception {
- if (Looper.myLooper() == null) {
- Looper.prepare();
- }
- }
-
@Before public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
reset(mContext);
+ mHandlerThread = new HandlerThread(TAG);
+ mHandlerThread.start();
+ mHandler = new Handler(mHandlerThread.getLooper());
mServiceContext = new MockContext(mContext);
- mHandler = new Handler(Looper.myLooper());
mCallbackCount = 0;
final IntentFilter filter = new IntentFilter();
filter.addAction(ACTION_TEST);
@@ -85,11 +83,15 @@
mListener.stopListening();
mListener = null;
}
+ mHandlerThread.quitSafely();
+ mHandlerThread.join(TEST_TIMEOUT_MS);
}
private void sendBroadcast() {
final Intent intent = new Intent(ACTION_TEST);
mServiceContext.sendStickyBroadcastAsUser(intent, UserHandle.ALL);
+ // Sending the broadcast is synchronous, but the receiver just posts on the handler
+ waitForIdle(mHandler, TEST_TIMEOUT_MS);
}
@Test
diff --git a/bpf_progs/netd.c b/bpf_progs/netd.c
index f29bbf9..f08b007 100644
--- a/bpf_progs/netd.c
+++ b/bpf_progs/netd.c
@@ -524,22 +524,12 @@
return match;
}
-// This program is optional, and enables tracing on Android U+, 5.8+ on user builds.
-DEFINE_BPF_PROG_EXT("cgroupskb/ingress/stats$trace_user", AID_ROOT, AID_SYSTEM,
- bpf_cgroup_ingress_trace_user, KVER_5_8, KVER_INF,
- BPFLOADER_MAINLINE_U_VERSION, BPFLOADER_MAX_VER, OPTIONAL,
- "fs_bpf_netd_readonly", "",
- IGNORE_ON_ENG, LOAD_ON_USER, IGNORE_ON_USERDEBUG)
-(struct __sk_buff* skb) {
- return bpf_traffic_account(skb, INGRESS, TRACE_ON, KVER_5_8);
-}
-
-// This program is required, and enables tracing on Android U+, 5.8+, userdebug/eng.
+// Tracing on Android U+, kernel 5.8+
DEFINE_BPF_PROG_EXT("cgroupskb/ingress/stats$trace", AID_ROOT, AID_SYSTEM,
bpf_cgroup_ingress_trace, KVER_5_8, KVER_INF,
BPFLOADER_MAINLINE_U_VERSION, BPFLOADER_MAX_VER, MANDATORY,
"fs_bpf_netd_readonly", "",
- LOAD_ON_ENG, IGNORE_ON_USER, LOAD_ON_USERDEBUG)
+ LOAD_ON_ENG, LOAD_ON_USER, LOAD_ON_USERDEBUG)
(struct __sk_buff* skb) {
return bpf_traffic_account(skb, INGRESS, TRACE_ON, KVER_5_8);
}
@@ -556,22 +546,12 @@
return bpf_traffic_account(skb, INGRESS, TRACE_OFF, KVER_NONE);
}
-// This program is optional, and enables tracing on Android U+, 5.8+ on user builds.
-DEFINE_BPF_PROG_EXT("cgroupskb/egress/stats$trace_user", AID_ROOT, AID_SYSTEM,
- bpf_cgroup_egress_trace_user, KVER_5_8, KVER_INF,
- BPFLOADER_MAINLINE_U_VERSION, BPFLOADER_MAX_VER, OPTIONAL,
- "fs_bpf_netd_readonly", "",
- IGNORE_ON_ENG, LOAD_ON_USER, IGNORE_ON_USERDEBUG)
-(struct __sk_buff* skb) {
- return bpf_traffic_account(skb, EGRESS, TRACE_ON, KVER_5_8);
-}
-
-// This program is required, and enables tracing on Android U+, 5.8+, userdebug/eng.
+// Tracing on Android U+, kernel 5.8+
DEFINE_BPF_PROG_EXT("cgroupskb/egress/stats$trace", AID_ROOT, AID_SYSTEM,
bpf_cgroup_egress_trace, KVER_5_8, KVER_INF,
BPFLOADER_MAINLINE_U_VERSION, BPFLOADER_MAX_VER, MANDATORY,
"fs_bpf_netd_readonly", "",
- LOAD_ON_ENG, IGNORE_ON_USER, LOAD_ON_USERDEBUG)
+ LOAD_ON_ENG, LOAD_ON_USER, LOAD_ON_USERDEBUG)
(struct __sk_buff* skb) {
return bpf_traffic_account(skb, EGRESS, TRACE_ON, KVER_5_8);
}
diff --git a/common/FlaggedApi.bp b/common/FlaggedApi.bp
index 21be1d3..fef9ac3 100644
--- a/common/FlaggedApi.bp
+++ b/common/FlaggedApi.bp
@@ -22,6 +22,16 @@
visibility: ["//packages/modules/Connectivity:__subpackages__"],
}
+java_aconfig_library {
+ name: "com.android.net.flags-aconfig-java",
+ aconfig_declarations: "com.android.net.flags-aconfig",
+ defaults: ["framework-minus-apex-aconfig-java-defaults"],
+ min_sdk_version: "30",
+ apex_available: [
+ "com.android.tethering",
+ ],
+}
+
aconfig_declarations {
name: "com.android.net.thread.flags-aconfig",
package: "com.android.net.thread.flags",
@@ -37,3 +47,20 @@
srcs: ["nearby_flags.aconfig"],
visibility: ["//packages/modules/Connectivity:__subpackages__"],
}
+
+aconfig_declarations {
+ name: "com.android.networksecurity.flags-aconfig",
+ package: "com.android.net.ct.flags",
+ container: "com.android.tethering",
+ srcs: ["networksecurity_flags.aconfig"],
+ visibility: ["//packages/modules/Connectivity:__subpackages__"],
+}
+
+java_aconfig_library {
+ name: "networksecurity_flags_java_lib",
+ aconfig_declarations: "com.android.networksecurity.flags-aconfig",
+ min_sdk_version: "30",
+ defaults: ["framework-minus-apex-aconfig-java-defaults"],
+ apex_available: ["com.android.tethering"],
+ visibility: ["//packages/modules/Connectivity:__subpackages__"],
+}
diff --git a/common/flags.aconfig b/common/flags.aconfig
index b320b61..1b0da4e 100644
--- a/common/flags.aconfig
+++ b/common/flags.aconfig
@@ -123,3 +123,11 @@
description: "Flag for introducing TETHERING_VIRTUAL type"
bug: "340376953"
}
+
+flag {
+ name: "netstats_add_entries"
+ is_exported: true
+ namespace: "android_core_networking"
+ description: "Flag for NetworkStats#addEntries API"
+ bug: "335680025"
+}
diff --git a/common/networksecurity_flags.aconfig b/common/networksecurity_flags.aconfig
new file mode 100644
index 0000000..ef8ffcd
--- /dev/null
+++ b/common/networksecurity_flags.aconfig
@@ -0,0 +1,9 @@
+package: "com.android.net.ct.flags"
+container: "com.android.tethering"
+flag {
+ name: "certificate_transparency_service"
+ is_exported: true
+ namespace: "network_security"
+ description: "Enable service for certificate transparency log list data"
+ bug: "319829948"
+}
diff --git a/framework-t/Android.bp b/framework-t/Android.bp
index f076f5b..ac78d09 100644
--- a/framework-t/Android.bp
+++ b/framework-t/Android.bp
@@ -137,6 +137,10 @@
// framework-connectivity-pre-jarjar match at runtime.
jarjar_rules: ":framework-connectivity-jarjar-rules",
stub_only_libs: [
+ // static_libs is not used to compile stubs. So libs which have
+ // been included in static_libs might still need to
+ // be in stub_only_libs to be usable when generating the API stubs.
+ "com.android.net.flags-aconfig-java",
// Use prebuilt framework-connectivity stubs to avoid circular dependencies
"sdk_module-lib_current_framework-connectivity",
],
diff --git a/framework-t/api/system-current.txt b/framework-t/api/system-current.txt
index 52de1a3..2354882 100644
--- a/framework-t/api/system-current.txt
+++ b/framework-t/api/system-current.txt
@@ -310,6 +310,7 @@
public final class NetworkStats implements java.lang.Iterable<android.net.NetworkStats.Entry> android.os.Parcelable {
ctor public NetworkStats(long, int);
method @NonNull public android.net.NetworkStats add(@NonNull android.net.NetworkStats);
+ method @FlaggedApi("com.android.net.flags.netstats_add_entries") @NonNull public android.net.NetworkStats addEntries(@NonNull java.util.List<android.net.NetworkStats.Entry>);
method @NonNull public android.net.NetworkStats addEntry(@NonNull android.net.NetworkStats.Entry);
method public android.net.NetworkStats clone();
method public int describeContents();
diff --git a/framework-t/src/android/net/NetworkStats.java b/framework-t/src/android/net/NetworkStats.java
index e9a3f58..a2c4fc3 100644
--- a/framework-t/src/android/net/NetworkStats.java
+++ b/framework-t/src/android/net/NetworkStats.java
@@ -18,6 +18,7 @@
import static com.android.net.module.util.NetworkStatsUtils.multiplySafeByRational;
+import android.annotation.FlaggedApi;
import android.annotation.IntDef;
import android.annotation.NonNull;
import android.annotation.Nullable;
@@ -33,6 +34,7 @@
import com.android.internal.annotations.VisibleForTesting;
import com.android.modules.utils.build.SdkLevel;
+import com.android.net.flags.Flags;
import com.android.net.module.util.CollectionUtils;
import libcore.util.EmptyArray;
@@ -845,6 +847,21 @@
}
/**
+ * Adds multiple entries to a copy of this NetworkStats instance.
+ *
+ * @param entries The entries to add.
+ * @return A new NetworkStats instance with the added entries.
+ */
+ @FlaggedApi(Flags.FLAG_NETSTATS_ADD_ENTRIES)
+ public @NonNull NetworkStats addEntries(@NonNull final List<Entry> entries) {
+ final NetworkStats newStats = this.clone();
+ for (final Entry entry : Objects.requireNonNull(entries)) {
+ newStats.combineValues(entry);
+ }
+ return newStats;
+ }
+
+ /**
* Add given values with an existing row, or create a new row if
* {@link #findIndex(String, int, int, int, int, int, int)} is unable to find match. Can
* also be used to subtract values from existing rows.
diff --git a/framework/Android.bp b/framework/Android.bp
index 282ba4b..4c4f792 100644
--- a/framework/Android.bp
+++ b/framework/Android.bp
@@ -63,6 +63,7 @@
":framework-connectivity-sources",
":net-utils-framework-common-srcs",
":framework-connectivity-api-shared-srcs",
+ ":framework-networksecurity-sources",
],
aidl: {
generate_get_transaction_name: true,
@@ -86,12 +87,14 @@
"framework-wifi.stubs.module_lib",
],
static_libs: [
+ "com.android.net.flags-aconfig-java",
// Not using the latest stable version because all functions in the latest version of
// mdns_aidl_interface are deprecated.
"mdns_aidl_interface-V1-java",
"modules-utils-backgroundthread",
"modules-utils-build",
"modules-utils-preconditions",
+ "networksecurity_flags_java_lib",
"framework-connectivity-javastream-protos",
],
impl_only_static_libs: [
@@ -206,6 +209,7 @@
},
aconfig_declarations: [
"com.android.net.flags-aconfig",
+ "com.android.networksecurity.flags-aconfig",
],
}
@@ -306,6 +310,7 @@
srcs: [
":framework-connectivity-sources",
":framework-connectivity-tiramisu-updatable-sources",
+ ":framework-networksecurity-sources",
":framework-nearby-java-sources",
":framework-thread-sources",
],
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index a6a967b..8cf6e04 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -4163,6 +4163,8 @@
*/
@FilteredCallback(methodId = METHOD_ONAVAILABLE_5ARGS,
calledByCallbackId = CALLBACK_AVAILABLE,
+ // If this list is modified, ConnectivityService#addAvailableStateUpdateCallbacks
+ // needs to be updated too.
mayCall = { METHOD_ONAVAILABLE_4ARGS,
METHOD_ONLOCALNETWORKINFOCHANGED,
METHOD_ONBLOCKEDSTATUSCHANGED_INT })
@@ -4193,6 +4195,8 @@
*/
@FilteredCallback(methodId = METHOD_ONAVAILABLE_4ARGS,
calledByCallbackId = CALLBACK_TRANSITIVE_CALLS_ONLY,
+ // If this list is modified, ConnectivityService#addAvailableStateUpdateCallbacks
+ // needs to be updated too.
mayCall = { METHOD_ONAVAILABLE_1ARG,
METHOD_ONNETWORKSUSPENDED,
METHOD_ONCAPABILITIESCHANGED,
diff --git a/netbpfload/Android.bp b/netbpfload/Android.bp
index 908bb13..b8c0ce7 100644
--- a/netbpfload/Android.bp
+++ b/netbpfload/Android.bp
@@ -48,10 +48,7 @@
"libbase",
"liblog",
],
- srcs: [
- "loader.cpp",
- "NetBpfLoad.cpp",
- ],
+ srcs: ["NetBpfLoad.cpp"],
apex_available: [
"com.android.tethering",
"//apex_available:platform",
diff --git a/netbpfload/NetBpfLoad.cpp b/netbpfload/NetBpfLoad.cpp
index 0d4a5c4..afb44cc 100644
--- a/netbpfload/NetBpfLoad.cpp
+++ b/netbpfload/NetBpfLoad.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017-2023 The Android Open Source Project
+ * Copyright (C) 2018-2024 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,50 +14,1218 @@
* limitations under the License.
*/
-#ifndef LOG_TAG
#define LOG_TAG "NetBpfLoad"
-#endif
#include <arpa/inet.h>
+#include <cstdlib>
#include <dirent.h>
#include <elf.h>
+#include <errno.h>
#include <error.h>
#include <fcntl.h>
+#include <fstream>
#include <inttypes.h>
+#include <iostream>
#include <linux/bpf.h>
+#include <linux/elf.h>
#include <linux/unistd.h>
+#include <log/log.h>
#include <net/if.h>
+#include <optional>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
-#include <unistd.h>
-
+#include <string>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
+#include <sys/utsname.h>
+#include <sys/wait.h>
+#include <sysexits.h>
+#include <unistd.h>
+#include <unordered_map>
+#include <vector>
-#include <android/api-level.h>
+#include <android-base/cmsg.h>
+#include <android-base/file.h>
#include <android-base/logging.h>
#include <android-base/macros.h>
#include <android-base/properties.h>
#include <android-base/stringprintf.h>
#include <android-base/strings.h>
#include <android-base/unique_fd.h>
-#include <log/log.h>
+#include <android/api-level.h>
#include "BpfSyscallWrappers.h"
#include "bpf/BpfUtils.h"
-#include "loader.h"
+#include "bpf/bpf_map_def.h"
+
+using android::base::EndsWith;
+using android::base::StartsWith;
+using android::base::unique_fd;
+using std::ifstream;
+using std::ios;
+using std::optional;
+using std::string;
+using std::vector;
namespace android {
namespace bpf {
-using base::StartsWith;
-using base::EndsWith;
-using std::string;
-using std::vector;
+// Bpf programs may specify per-program & per-map selinux_context and pin_subdir.
+//
+// The BpfLoader needs to convert these bpf.o specified strings into an enum
+// for internal use (to check that valid values were specified for the specific
+// location of the bpf.o file).
+//
+// It also needs to map selinux_context's into pin_subdir's.
+// This is because of how selinux_context is actually implemented via pin+rename.
+//
+// Thus 'domain' enumerates all selinux_context's/pin_subdir's that the BpfLoader
+// is aware of. Thus there currently needs to be a 1:1 mapping between the two.
+//
+enum class domain : int {
+ unspecified = 0, // means just use the default for that specific pin location
+ tethering, // (S+) fs_bpf_tethering /sys/fs/bpf/tethering
+ net_private, // (T+) fs_bpf_net_private /sys/fs/bpf/net_private
+ net_shared, // (T+) fs_bpf_net_shared /sys/fs/bpf/net_shared
+ netd_readonly, // (T+) fs_bpf_netd_readonly /sys/fs/bpf/netd_readonly
+ netd_shared, // (T+) fs_bpf_netd_shared /sys/fs/bpf/netd_shared
+};
+
+static constexpr domain AllDomains[] = {
+ domain::unspecified,
+ domain::tethering,
+ domain::net_private,
+ domain::net_shared,
+ domain::netd_readonly,
+ domain::netd_shared,
+};
+
+static constexpr bool specified(domain d) {
+ return d != domain::unspecified;
+}
+
+struct Location {
+ const char* const dir = "";
+ const char* const prefix = "";
+};
+
+// Returns the build type string (from ro.build.type).
+const std::string& getBuildType() {
+ static std::string t = android::base::GetProperty("ro.build.type", "unknown");
+ return t;
+}
+
+// The following functions classify the 3 Android build types.
+inline bool isEng() {
+ return getBuildType() == "eng";
+}
+
+inline bool isUser() {
+ return getBuildType() == "user";
+}
+
+inline bool isUserdebug() {
+ return getBuildType() == "userdebug";
+}
+
+#define BPF_FS_PATH "/sys/fs/bpf/"
+
+// Size of the BPF log buffer for verifier logging
+#define BPF_LOAD_LOG_SZ 0xfffff
+
+static unsigned int page_size = static_cast<unsigned int>(getpagesize());
+
+constexpr const char* lookupSelinuxContext(const domain d) {
+ switch (d) {
+ case domain::unspecified: return "";
+ case domain::tethering: return "fs_bpf_tethering";
+ case domain::net_private: return "fs_bpf_net_private";
+ case domain::net_shared: return "fs_bpf_net_shared";
+ case domain::netd_readonly: return "fs_bpf_netd_readonly";
+ case domain::netd_shared: return "fs_bpf_netd_shared";
+ }
+}
+
+domain getDomainFromSelinuxContext(const char s[BPF_SELINUX_CONTEXT_CHAR_ARRAY_SIZE]) {
+ for (domain d : AllDomains) {
+ // Not sure how to enforce this at compile time, so abort() bpfloader at boot instead
+ if (strlen(lookupSelinuxContext(d)) >= BPF_SELINUX_CONTEXT_CHAR_ARRAY_SIZE) abort();
+ if (!strncmp(s, lookupSelinuxContext(d), BPF_SELINUX_CONTEXT_CHAR_ARRAY_SIZE)) return d;
+ }
+ ALOGE("unrecognized selinux_context '%-32s'", s);
+ // Note: we *can* just abort() here as we only load bpf .o files shipped
+ // in the same mainline module / apex as NetBpfLoad itself.
+ abort();
+}
+
+constexpr const char* lookupPinSubdir(const domain d, const char* const unspecified = "") {
+ switch (d) {
+ case domain::unspecified: return unspecified;
+ case domain::tethering: return "tethering/";
+ case domain::net_private: return "net_private/";
+ case domain::net_shared: return "net_shared/";
+ case domain::netd_readonly: return "netd_readonly/";
+ case domain::netd_shared: return "netd_shared/";
+ }
+};
+
+domain getDomainFromPinSubdir(const char s[BPF_PIN_SUBDIR_CHAR_ARRAY_SIZE]) {
+ for (domain d : AllDomains) {
+ // Not sure how to enforce this at compile time, so abort() bpfloader at boot instead
+ if (strlen(lookupPinSubdir(d)) >= BPF_PIN_SUBDIR_CHAR_ARRAY_SIZE) abort();
+ if (!strncmp(s, lookupPinSubdir(d), BPF_PIN_SUBDIR_CHAR_ARRAY_SIZE)) return d;
+ }
+ ALOGE("unrecognized pin_subdir '%-32s'", s);
+ // Note: we *can* just abort() here as we only load bpf .o files shipped
+ // in the same mainline module / apex as NetBpfLoad itself.
+ abort();
+}
+
+static string pathToObjName(const string& path) {
+ // extract everything after the final slash, i.e. this is the filename 'foo@1.o' or 'bar.o'
+ string filename = android::base::Split(path, "/").back();
+ // strip off everything from the final period onwards (strip '.o' suffix), i.e. 'foo@1' or 'bar'
+ string name = filename.substr(0, filename.find_last_of('.'));
+ // strip any potential @1 suffix, this will leave us with just 'foo' or 'bar'
+ // this can be used to provide duplicate programs (mux based on the bpfloader version)
+ return name.substr(0, name.find_last_of('@'));
+}
+
+typedef struct {
+ const char* name;
+ enum bpf_prog_type type;
+ enum bpf_attach_type attach_type;
+} sectionType;
+
+/*
+ * Map section name prefixes to program types, the section name will be:
+ * SECTION(<prefix>/<name-of-program>)
+ * For example:
+ * SECTION("tracepoint/sched_switch_func") where sched_switch_func
+ * is the name of the program, and tracepoint is the type.
+ *
+ * However, be aware that you should not be directly using the SECTION() macro.
+ * Instead use the DEFINE_(BPF|XDP)_(PROG|MAP)... & LICENSE/CRITICAL macros.
+ *
+ * Programs shipped inside the tethering apex should be limited to networking stuff,
+ * as KPROBE, PERF_EVENT, TRACEPOINT are dangerous to use from mainline updatable code,
+ * since they are less stable abi/api and may conflict with platform uses of bpf.
+ */
+sectionType sectionNameTypes[] = {
+ {"bind4/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_INET4_BIND},
+ {"bind6/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_INET6_BIND},
+ {"cgroupskb/", BPF_PROG_TYPE_CGROUP_SKB},
+ {"cgroupsock/", BPF_PROG_TYPE_CGROUP_SOCK},
+ {"cgroupsockcreate/", BPF_PROG_TYPE_CGROUP_SOCK, BPF_CGROUP_INET_SOCK_CREATE},
+ {"cgroupsockrelease/", BPF_PROG_TYPE_CGROUP_SOCK, BPF_CGROUP_INET_SOCK_RELEASE},
+ {"connect4/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_INET4_CONNECT},
+ {"connect6/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_INET6_CONNECT},
+ {"egress/", BPF_PROG_TYPE_CGROUP_SKB, BPF_CGROUP_INET_EGRESS},
+ {"getsockopt/", BPF_PROG_TYPE_CGROUP_SOCKOPT, BPF_CGROUP_GETSOCKOPT},
+ {"ingress/", BPF_PROG_TYPE_CGROUP_SKB, BPF_CGROUP_INET_INGRESS},
+ {"postbind4/", BPF_PROG_TYPE_CGROUP_SOCK, BPF_CGROUP_INET4_POST_BIND},
+ {"postbind6/", BPF_PROG_TYPE_CGROUP_SOCK, BPF_CGROUP_INET6_POST_BIND},
+ {"recvmsg4/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_UDP4_RECVMSG},
+ {"recvmsg6/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_UDP6_RECVMSG},
+ {"schedact/", BPF_PROG_TYPE_SCHED_ACT},
+ {"schedcls/", BPF_PROG_TYPE_SCHED_CLS},
+ {"sendmsg4/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_UDP4_SENDMSG},
+ {"sendmsg6/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_UDP6_SENDMSG},
+ {"setsockopt/", BPF_PROG_TYPE_CGROUP_SOCKOPT, BPF_CGROUP_SETSOCKOPT},
+ {"skfilter/", BPF_PROG_TYPE_SOCKET_FILTER},
+ {"sockops/", BPF_PROG_TYPE_SOCK_OPS, BPF_CGROUP_SOCK_OPS},
+ {"sysctl", BPF_PROG_TYPE_CGROUP_SYSCTL, BPF_CGROUP_SYSCTL},
+ {"xdp/", BPF_PROG_TYPE_XDP},
+};
+
+typedef struct {
+ enum bpf_prog_type type;
+ enum bpf_attach_type attach_type;
+ string name;
+ vector<char> data;
+ vector<char> rel_data;
+ optional<struct bpf_prog_def> prog_def;
+
+ unique_fd prog_fd; /* fd after loading */
+} codeSection;
+
+static int readElfHeader(ifstream& elfFile, Elf64_Ehdr* eh) {
+ elfFile.seekg(0);
+ if (elfFile.fail()) return -1;
+
+ if (!elfFile.read((char*)eh, sizeof(*eh))) return -1;
+
+ return 0;
+}
+
+/* Reads all section header tables into an Shdr array */
+static int readSectionHeadersAll(ifstream& elfFile, vector<Elf64_Shdr>& shTable) {
+ Elf64_Ehdr eh;
+ int ret = 0;
+
+ ret = readElfHeader(elfFile, &eh);
+ if (ret) return ret;
+
+ elfFile.seekg(eh.e_shoff);
+ if (elfFile.fail()) return -1;
+
+ /* Read shdr table entries */
+ shTable.resize(eh.e_shnum);
+
+ if (!elfFile.read((char*)shTable.data(), (eh.e_shnum * eh.e_shentsize))) return -ENOMEM;
+
+ return 0;
+}
+
+/* Read a section by its index - for ex to get sec hdr strtab blob */
+static int readSectionByIdx(ifstream& elfFile, int id, vector<char>& sec) {
+ vector<Elf64_Shdr> shTable;
+ int ret = readSectionHeadersAll(elfFile, shTable);
+ if (ret) return ret;
+
+ elfFile.seekg(shTable[id].sh_offset);
+ if (elfFile.fail()) return -1;
+
+ sec.resize(shTable[id].sh_size);
+ if (!elfFile.read(sec.data(), shTable[id].sh_size)) return -1;
+
+ return 0;
+}
+
+/* Read whole section header string table */
+static int readSectionHeaderStrtab(ifstream& elfFile, vector<char>& strtab) {
+ Elf64_Ehdr eh;
+ int ret = readElfHeader(elfFile, &eh);
+ if (ret) return ret;
+
+ ret = readSectionByIdx(elfFile, eh.e_shstrndx, strtab);
+ if (ret) return ret;
+
+ return 0;
+}
+
+/* Get name from offset in strtab */
+static int getSymName(ifstream& elfFile, int nameOff, string& name) {
+ int ret;
+ vector<char> secStrTab;
+
+ ret = readSectionHeaderStrtab(elfFile, secStrTab);
+ if (ret) return ret;
+
+ if (nameOff >= (int)secStrTab.size()) return -1;
+
+ name = string((char*)secStrTab.data() + nameOff);
+ return 0;
+}
+
+/* Reads a full section by name - example to get the GPL license */
+static int readSectionByName(const char* name, ifstream& elfFile, vector<char>& data) {
+ vector<char> secStrTab;
+ vector<Elf64_Shdr> shTable;
+ int ret;
+
+ ret = readSectionHeadersAll(elfFile, shTable);
+ if (ret) return ret;
+
+ ret = readSectionHeaderStrtab(elfFile, secStrTab);
+ if (ret) return ret;
+
+ for (int i = 0; i < (int)shTable.size(); i++) {
+ char* secname = secStrTab.data() + shTable[i].sh_name;
+ if (!secname) continue;
+
+ if (!strcmp(secname, name)) {
+ vector<char> dataTmp;
+ dataTmp.resize(shTable[i].sh_size);
+
+ elfFile.seekg(shTable[i].sh_offset);
+ if (elfFile.fail()) return -1;
+
+ if (!elfFile.read((char*)dataTmp.data(), shTable[i].sh_size)) return -1;
+
+ data = dataTmp;
+ return 0;
+ }
+ }
+ return -2;
+}
+
+unsigned int readSectionUint(const char* name, ifstream& elfFile, unsigned int defVal) {
+ vector<char> theBytes;
+ int ret = readSectionByName(name, elfFile, theBytes);
+ if (ret) {
+ ALOGD("Couldn't find section %s (defaulting to %u [0x%x]).", name, defVal, defVal);
+ return defVal;
+ } else if (theBytes.size() < sizeof(unsigned int)) {
+ ALOGE("Section %s too short (defaulting to %u [0x%x]).", name, defVal, defVal);
+ return defVal;
+ } else {
+ // decode first 4 bytes as LE32 uint, there will likely be more bytes due to alignment.
+ unsigned int value = static_cast<unsigned char>(theBytes[3]);
+ value <<= 8;
+ value += static_cast<unsigned char>(theBytes[2]);
+ value <<= 8;
+ value += static_cast<unsigned char>(theBytes[1]);
+ value <<= 8;
+ value += static_cast<unsigned char>(theBytes[0]);
+ ALOGI("Section %s value is %u [0x%x]", name, value, value);
+ return value;
+ }
+}
+
+static int readSectionByType(ifstream& elfFile, int type, vector<char>& data) {
+ int ret;
+ vector<Elf64_Shdr> shTable;
+
+ ret = readSectionHeadersAll(elfFile, shTable);
+ if (ret) return ret;
+
+ for (int i = 0; i < (int)shTable.size(); i++) {
+ if ((int)shTable[i].sh_type != type) continue;
+
+ vector<char> dataTmp;
+ dataTmp.resize(shTable[i].sh_size);
+
+ elfFile.seekg(shTable[i].sh_offset);
+ if (elfFile.fail()) return -1;
+
+ if (!elfFile.read((char*)dataTmp.data(), shTable[i].sh_size)) return -1;
+
+ data = dataTmp;
+ return 0;
+ }
+ return -2;
+}
+
+static bool symCompare(Elf64_Sym a, Elf64_Sym b) {
+ return (a.st_value < b.st_value);
+}
+
+static int readSymTab(ifstream& elfFile, int sort, vector<Elf64_Sym>& data) {
+ int ret, numElems;
+ Elf64_Sym* buf;
+ vector<char> secData;
+
+ ret = readSectionByType(elfFile, SHT_SYMTAB, secData);
+ if (ret) return ret;
+
+ buf = (Elf64_Sym*)secData.data();
+ numElems = (secData.size() / sizeof(Elf64_Sym));
+ data.assign(buf, buf + numElems);
+
+ if (sort) std::sort(data.begin(), data.end(), symCompare);
+ return 0;
+}
+
+static enum bpf_prog_type getSectionType(string& name) {
+ for (auto& snt : sectionNameTypes)
+ if (StartsWith(name, snt.name)) return snt.type;
+
+ return BPF_PROG_TYPE_UNSPEC;
+}
+
+/*
+static string getSectionName(enum bpf_prog_type type)
+{
+ for (auto& snt : sectionNameTypes)
+ if (snt.type == type)
+ return string(snt.name);
+
+ return "UNKNOWN SECTION NAME " + std::to_string(type);
+}
+*/
+
+static int readProgDefs(ifstream& elfFile, vector<struct bpf_prog_def>& pd,
+ size_t sizeOfBpfProgDef) {
+ vector<char> pdData;
+ int ret = readSectionByName("progs", elfFile, pdData);
+ if (ret) return ret;
+
+ if (pdData.size() % sizeOfBpfProgDef) {
+ ALOGE("readProgDefs failed due to improper sized progs section, %zu %% %zu != 0",
+ pdData.size(), sizeOfBpfProgDef);
+ return -1;
+ };
+
+ int progCount = pdData.size() / sizeOfBpfProgDef;
+ pd.resize(progCount);
+ size_t trimmedSize = std::min(sizeOfBpfProgDef, sizeof(struct bpf_prog_def));
+
+ const char* dataPtr = pdData.data();
+ for (auto& p : pd) {
+ // First we zero initialize
+ memset(&p, 0, sizeof(p));
+ // Then we set non-zero defaults
+ p.bpfloader_max_ver = DEFAULT_BPFLOADER_MAX_VER; // v1.0
+ // Then we copy over the structure prefix from the ELF file.
+ memcpy(&p, dataPtr, trimmedSize);
+ // Move to next struct in the ELF file
+ dataPtr += sizeOfBpfProgDef;
+ }
+ return 0;
+}
+
+static int getSectionSymNames(ifstream& elfFile, const string& sectionName, vector<string>& names,
+ optional<unsigned> symbolType = std::nullopt) {
+ int ret;
+ string name;
+ vector<Elf64_Sym> symtab;
+ vector<Elf64_Shdr> shTable;
+
+ ret = readSymTab(elfFile, 1 /* sort */, symtab);
+ if (ret) return ret;
+
+ /* Get index of section */
+ ret = readSectionHeadersAll(elfFile, shTable);
+ if (ret) return ret;
+
+ int sec_idx = -1;
+ for (int i = 0; i < (int)shTable.size(); i++) {
+ ret = getSymName(elfFile, shTable[i].sh_name, name);
+ if (ret) return ret;
+
+ if (!name.compare(sectionName)) {
+ sec_idx = i;
+ break;
+ }
+ }
+
+ /* No section found with matching name*/
+ if (sec_idx == -1) {
+ ALOGW("No %s section could be found in elf object", sectionName.c_str());
+ return -1;
+ }
+
+ for (int i = 0; i < (int)symtab.size(); i++) {
+ if (symbolType.has_value() && ELF_ST_TYPE(symtab[i].st_info) != symbolType) continue;
+
+ if (symtab[i].st_shndx == sec_idx) {
+ string s;
+ ret = getSymName(elfFile, symtab[i].st_name, s);
+ if (ret) return ret;
+ names.push_back(s);
+ }
+ }
+
+ return 0;
+}
+
+/* Read a section by its index - for ex to get sec hdr strtab blob */
+static int readCodeSections(ifstream& elfFile, vector<codeSection>& cs, size_t sizeOfBpfProgDef) {
+ vector<Elf64_Shdr> shTable;
+ int entries, ret = 0;
+
+ ret = readSectionHeadersAll(elfFile, shTable);
+ if (ret) return ret;
+ entries = shTable.size();
+
+ vector<struct bpf_prog_def> pd;
+ ret = readProgDefs(elfFile, pd, sizeOfBpfProgDef);
+ if (ret) return ret;
+ vector<string> progDefNames;
+ ret = getSectionSymNames(elfFile, "progs", progDefNames);
+ if (!pd.empty() && ret) return ret;
+
+ for (int i = 0; i < entries; i++) {
+ string name;
+ codeSection cs_temp;
+ cs_temp.type = BPF_PROG_TYPE_UNSPEC;
+
+ ret = getSymName(elfFile, shTable[i].sh_name, name);
+ if (ret) return ret;
+
+ enum bpf_prog_type ptype = getSectionType(name);
+
+ if (ptype == BPF_PROG_TYPE_UNSPEC) continue;
+
+ // This must be done before '/' is replaced with '_'.
+ for (auto& snt : sectionNameTypes)
+ if (StartsWith(name, snt.name)) cs_temp.attach_type = snt.attach_type;
+
+ string oldName = name;
+
+ // convert all slashes to underscores
+ std::replace(name.begin(), name.end(), '/', '_');
+
+ cs_temp.type = ptype;
+ cs_temp.name = name;
+
+ ret = readSectionByIdx(elfFile, i, cs_temp.data);
+ if (ret) return ret;
+ ALOGV("Loaded code section %d (%s)", i, name.c_str());
+
+ vector<string> csSymNames;
+ ret = getSectionSymNames(elfFile, oldName, csSymNames, STT_FUNC);
+ if (ret || !csSymNames.size()) return ret;
+ for (size_t i = 0; i < progDefNames.size(); ++i) {
+ if (!progDefNames[i].compare(csSymNames[0] + "_def")) {
+ cs_temp.prog_def = pd[i];
+ break;
+ }
+ }
+
+ /* Check for rel section */
+ if (cs_temp.data.size() > 0 && i < entries) {
+ ret = getSymName(elfFile, shTable[i + 1].sh_name, name);
+ if (ret) return ret;
+
+ if (name == (".rel" + oldName)) {
+ ret = readSectionByIdx(elfFile, i + 1, cs_temp.rel_data);
+ if (ret) return ret;
+ ALOGV("Loaded relo section %d (%s)", i, name.c_str());
+ }
+ }
+
+ if (cs_temp.data.size() > 0) {
+ cs.push_back(std::move(cs_temp));
+ ALOGV("Adding section %d to cs list", i);
+ }
+ }
+ return 0;
+}
+
+static int getSymNameByIdx(ifstream& elfFile, int index, string& name) {
+ vector<Elf64_Sym> symtab;
+ int ret = 0;
+
+ ret = readSymTab(elfFile, 0 /* !sort */, symtab);
+ if (ret) return ret;
+
+ if (index >= (int)symtab.size()) return -1;
+
+ return getSymName(elfFile, symtab[index].st_name, name);
+}
+
+static bool mapMatchesExpectations(const unique_fd& fd, const string& mapName,
+ const struct bpf_map_def& mapDef, const enum bpf_map_type type) {
+ // bpfGetFd... family of functions require at minimum a 4.14 kernel,
+ // so on 4.9-T kernels just pretend the map matches our expectations.
+ // Additionally we'll get almost equivalent test coverage on newer devices/kernels.
+ // This is because the primary failure mode we're trying to detect here
+ // is either a source code misconfiguration (which is likely kernel independent)
+ // or a newly introduced kernel feature/bug (which is unlikely to get backported to 4.9).
+ if (!isAtLeastKernelVersion(4, 14, 0)) return true;
+
+ // Assuming fd is a valid Bpf Map file descriptor then
+ // all the following should always succeed on a 4.14+ kernel.
+ // If they somehow do fail, they'll return -1 (and set errno),
+ // which should then cause (among others) a key_size mismatch.
+ int fd_type = bpfGetFdMapType(fd);
+ int fd_key_size = bpfGetFdKeySize(fd);
+ int fd_value_size = bpfGetFdValueSize(fd);
+ int fd_max_entries = bpfGetFdMaxEntries(fd);
+ int fd_map_flags = bpfGetFdMapFlags(fd);
+
+ // DEVMAPs are readonly from the bpf program side's point of view, as such
+ // the kernel in kernel/bpf/devmap.c dev_map_init_map() will set the flag
+ int desired_map_flags = (int)mapDef.map_flags;
+ if (type == BPF_MAP_TYPE_DEVMAP || type == BPF_MAP_TYPE_DEVMAP_HASH)
+ desired_map_flags |= BPF_F_RDONLY_PROG;
+
+ // The .h file enforces that this is a power of two, and page size will
+ // also always be a power of two, so this logic is actually enough to
+ // force it to be a multiple of the page size, as required by the kernel.
+ unsigned int desired_max_entries = mapDef.max_entries;
+ if (type == BPF_MAP_TYPE_RINGBUF) {
+ if (desired_max_entries < page_size) desired_max_entries = page_size;
+ }
+
+ // The following checks should *never* trigger, if one of them somehow does,
+ // it probably means a bpf .o file has been changed/replaced at runtime
+ // and bpfloader was manually rerun (normally it should only run *once*
+ // early during the boot process).
+ // Another possibility is that something is misconfigured in the code:
+ // most likely a shared map is declared twice differently.
+ // But such a change should never be checked into the source tree...
+ if ((fd_type == type) &&
+ (fd_key_size == (int)mapDef.key_size) &&
+ (fd_value_size == (int)mapDef.value_size) &&
+ (fd_max_entries == (int)desired_max_entries) &&
+ (fd_map_flags == desired_map_flags)) {
+ return true;
+ }
+
+ ALOGE("bpf map name %s mismatch: desired/found: "
+ "type:%d/%d key:%u/%d value:%u/%d entries:%u/%d flags:%u/%d",
+ mapName.c_str(), type, fd_type, mapDef.key_size, fd_key_size, mapDef.value_size,
+ fd_value_size, mapDef.max_entries, fd_max_entries, desired_map_flags, fd_map_flags);
+ return false;
+}
+
+static int createMaps(const char* elfPath, ifstream& elfFile, vector<unique_fd>& mapFds,
+ const char* prefix, const size_t sizeOfBpfMapDef,
+ const unsigned int bpfloader_ver) {
+ int ret;
+ vector<char> mdData;
+ vector<struct bpf_map_def> md;
+ vector<string> mapNames;
+ string objName = pathToObjName(string(elfPath));
+
+ ret = readSectionByName("maps", elfFile, mdData);
+ if (ret == -2) return 0; // no maps to read
+ if (ret) return ret;
+
+ if (mdData.size() % sizeOfBpfMapDef) {
+ ALOGE("createMaps failed due to improper sized maps section, %zu %% %zu != 0",
+ mdData.size(), sizeOfBpfMapDef);
+ return -1;
+ };
+
+ int mapCount = mdData.size() / sizeOfBpfMapDef;
+ md.resize(mapCount);
+ size_t trimmedSize = std::min(sizeOfBpfMapDef, sizeof(struct bpf_map_def));
+
+ const char* dataPtr = mdData.data();
+ for (auto& m : md) {
+ // First we zero initialize
+ memset(&m, 0, sizeof(m));
+ // Then we set non-zero defaults
+ m.bpfloader_max_ver = DEFAULT_BPFLOADER_MAX_VER; // v1.0
+ m.max_kver = 0xFFFFFFFFu; // matches KVER_INF from bpf_helpers.h
+ // Then we copy over the structure prefix from the ELF file.
+ memcpy(&m, dataPtr, trimmedSize);
+ // Move to next struct in the ELF file
+ dataPtr += sizeOfBpfMapDef;
+ }
+
+ ret = getSectionSymNames(elfFile, "maps", mapNames);
+ if (ret) return ret;
+
+ unsigned kvers = kernelVersion();
+
+ for (int i = 0; i < (int)mapNames.size(); i++) {
+ if (md[i].zero != 0) abort();
+
+ if (bpfloader_ver < md[i].bpfloader_min_ver) {
+ ALOGI("skipping map %s which requires bpfloader min ver 0x%05x", mapNames[i].c_str(),
+ md[i].bpfloader_min_ver);
+ mapFds.push_back(unique_fd());
+ continue;
+ }
+
+ if (bpfloader_ver >= md[i].bpfloader_max_ver) {
+ ALOGI("skipping map %s which requires bpfloader max ver 0x%05x", mapNames[i].c_str(),
+ md[i].bpfloader_max_ver);
+ mapFds.push_back(unique_fd());
+ continue;
+ }
+
+ if (kvers < md[i].min_kver) {
+ ALOGI("skipping map %s which requires kernel version 0x%x >= 0x%x",
+ mapNames[i].c_str(), kvers, md[i].min_kver);
+ mapFds.push_back(unique_fd());
+ continue;
+ }
+
+ if (kvers >= md[i].max_kver) {
+ ALOGI("skipping map %s which requires kernel version 0x%x < 0x%x",
+ mapNames[i].c_str(), kvers, md[i].max_kver);
+ mapFds.push_back(unique_fd());
+ continue;
+ }
+
+ if ((md[i].ignore_on_eng && isEng()) || (md[i].ignore_on_user && isUser()) ||
+ (md[i].ignore_on_userdebug && isUserdebug())) {
+ ALOGI("skipping map %s which is ignored on %s builds", mapNames[i].c_str(),
+ getBuildType().c_str());
+ mapFds.push_back(unique_fd());
+ continue;
+ }
+
+ if ((isArm() && isKernel32Bit() && md[i].ignore_on_arm32) ||
+ (isArm() && isKernel64Bit() && md[i].ignore_on_aarch64) ||
+ (isX86() && isKernel32Bit() && md[i].ignore_on_x86_32) ||
+ (isX86() && isKernel64Bit() && md[i].ignore_on_x86_64) ||
+ (isRiscV() && md[i].ignore_on_riscv64)) {
+ ALOGI("skipping map %s which is ignored on %s", mapNames[i].c_str(),
+ describeArch());
+ mapFds.push_back(unique_fd());
+ continue;
+ }
+
+ enum bpf_map_type type = md[i].type;
+ if (type == BPF_MAP_TYPE_DEVMAP && !isAtLeastKernelVersion(4, 14, 0)) {
+ // On Linux Kernels older than 4.14 this map type doesn't exist, but it can kind
+ // of be approximated: ARRAY has the same userspace api, though it is not usable
+ // by the same ebpf programs. However, that's okay because the bpf_redirect_map()
+ // helper doesn't exist on 4.9-T anyway (so the bpf program would fail to load,
+ // and thus needs to be tagged as 4.14+ either way), so there's nothing useful you
+ // could do with a DEVMAP anyway (that isn't already provided by an ARRAY)...
+ // Hence using an ARRAY instead of a DEVMAP simply makes life easier for userspace.
+ type = BPF_MAP_TYPE_ARRAY;
+ }
+ if (type == BPF_MAP_TYPE_DEVMAP_HASH && !isAtLeastKernelVersion(5, 4, 0)) {
+ // On Linux Kernels older than 5.4 this map type doesn't exist, but it can kind
+ // of be approximated: HASH has the same userspace visible api.
+ // However it cannot be used by ebpf programs in the same way.
+ // Since bpf_redirect_map() only requires 4.14, a program using a DEVMAP_HASH map
+ // would fail to load (due to trying to redirect to a HASH instead of DEVMAP_HASH).
+ // One must thus tag any BPF_MAP_TYPE_DEVMAP_HASH + bpf_redirect_map() using
+ // programs as being 5.4+...
+ type = BPF_MAP_TYPE_HASH;
+ }
+
+ // The .h file enforces that this is a power of two, and page size will
+ // also always be a power of two, so this logic is actually enough to
+ // force it to be a multiple of the page size, as required by the kernel.
+ unsigned int max_entries = md[i].max_entries;
+ if (type == BPF_MAP_TYPE_RINGBUF) {
+ if (max_entries < page_size) max_entries = page_size;
+ }
+
+ domain selinux_context = getDomainFromSelinuxContext(md[i].selinux_context);
+ if (specified(selinux_context)) {
+ ALOGI("map %s selinux_context [%-32s] -> %d -> '%s' (%s)", mapNames[i].c_str(),
+ md[i].selinux_context, static_cast<int>(selinux_context),
+ lookupSelinuxContext(selinux_context), lookupPinSubdir(selinux_context));
+ }
+
+ domain pin_subdir = getDomainFromPinSubdir(md[i].pin_subdir);
+ if (specified(pin_subdir)) {
+ ALOGI("map %s pin_subdir [%-32s] -> %d -> '%s'", mapNames[i].c_str(), md[i].pin_subdir,
+ static_cast<int>(pin_subdir), lookupPinSubdir(pin_subdir));
+ }
+
+ // Format of pin location is /sys/fs/bpf/<pin_subdir|prefix>map_<objName>_<mapName>
+ // except that maps shared across .o's have empty <objName>
+ // Note: <objName> refers to the extension-less basename of the .o file (without @ suffix).
+ string mapPinLoc = string(BPF_FS_PATH) + lookupPinSubdir(pin_subdir, prefix) + "map_" +
+ (md[i].shared ? "" : objName) + "_" + mapNames[i];
+ bool reuse = false;
+ unique_fd fd;
+ int saved_errno;
+
+ if (access(mapPinLoc.c_str(), F_OK) == 0) {
+ fd.reset(mapRetrieveRO(mapPinLoc.c_str()));
+ saved_errno = errno;
+ ALOGD("bpf_create_map reusing map %s, ret: %d", mapNames[i].c_str(), fd.get());
+ reuse = true;
+ } else {
+ union bpf_attr req = {
+ .map_type = type,
+ .key_size = md[i].key_size,
+ .value_size = md[i].value_size,
+ .max_entries = max_entries,
+ .map_flags = md[i].map_flags,
+ };
+ if (isAtLeastKernelVersion(4, 15, 0))
+ strlcpy(req.map_name, mapNames[i].c_str(), sizeof(req.map_name));
+ fd.reset(bpf(BPF_MAP_CREATE, req));
+ saved_errno = errno;
+ ALOGD("bpf_create_map name %s, ret: %d", mapNames[i].c_str(), fd.get());
+ }
+
+ if (!fd.ok()) return -saved_errno;
+
+ // When reusing a pinned map, we need to check the map type/sizes/etc match, but for
+ // safety (since reuse code path is rare) run these checks even if we just created it.
+ // We assume failure is due to pinned map mismatch, hence the 'NOT UNIQUE' return code.
+ if (!mapMatchesExpectations(fd, mapNames[i], md[i], type)) return -ENOTUNIQ;
+
+ if (!reuse) {
+ if (specified(selinux_context)) {
+ string createLoc = string(BPF_FS_PATH) + lookupPinSubdir(selinux_context) +
+ "tmp_map_" + objName + "_" + mapNames[i];
+ ret = bpfFdPin(fd, createLoc.c_str());
+ if (ret) {
+ int err = errno;
+ ALOGE("create %s -> %d [%d:%s]", createLoc.c_str(), ret, err, strerror(err));
+ return -err;
+ }
+ ret = renameat2(AT_FDCWD, createLoc.c_str(),
+ AT_FDCWD, mapPinLoc.c_str(), RENAME_NOREPLACE);
+ if (ret) {
+ int err = errno;
+ ALOGE("rename %s %s -> %d [%d:%s]", createLoc.c_str(), mapPinLoc.c_str(), ret,
+ err, strerror(err));
+ return -err;
+ }
+ } else {
+ ret = bpfFdPin(fd, mapPinLoc.c_str());
+ if (ret) {
+ int err = errno;
+ ALOGE("pin %s -> %d [%d:%s]", mapPinLoc.c_str(), ret, err, strerror(err));
+ return -err;
+ }
+ }
+ ret = chmod(mapPinLoc.c_str(), md[i].mode);
+ if (ret) {
+ int err = errno;
+ ALOGE("chmod(%s, 0%o) = %d [%d:%s]", mapPinLoc.c_str(), md[i].mode, ret, err,
+ strerror(err));
+ return -err;
+ }
+ ret = chown(mapPinLoc.c_str(), (uid_t)md[i].uid, (gid_t)md[i].gid);
+ if (ret) {
+ int err = errno;
+ ALOGE("chown(%s, %u, %u) = %d [%d:%s]", mapPinLoc.c_str(), md[i].uid, md[i].gid,
+ ret, err, strerror(err));
+ return -err;
+ }
+ }
+
+ int mapId = bpfGetFdMapId(fd);
+ if (mapId == -1) {
+ ALOGE("bpfGetFdMapId failed, ret: %d [%d]", mapId, errno);
+ } else {
+ ALOGI("map %s id %d", mapPinLoc.c_str(), mapId);
+ }
+
+ mapFds.push_back(std::move(fd));
+ }
+
+ return ret;
+}
+
+/* For debugging, dump all instructions */
+static void dumpIns(char* ins, int size) {
+ for (int row = 0; row < size / 8; row++) {
+ ALOGE("%d: ", row);
+ for (int j = 0; j < 8; j++) {
+ ALOGE("%3x ", ins[(row * 8) + j]);
+ }
+ ALOGE("\n");
+ }
+}
+
+/* For debugging, dump all code sections from cs list */
+static void dumpAllCs(vector<codeSection>& cs) {
+ for (int i = 0; i < (int)cs.size(); i++) {
+ ALOGE("Dumping cs %d, name %s", int(i), cs[i].name.c_str());
+ dumpIns((char*)cs[i].data.data(), cs[i].data.size());
+ ALOGE("-----------");
+ }
+}
+
+static void applyRelo(void* insnsPtr, Elf64_Addr offset, int fd) {
+ int insnIndex;
+ struct bpf_insn *insn, *insns;
+
+ insns = (struct bpf_insn*)(insnsPtr);
+
+ insnIndex = offset / sizeof(struct bpf_insn);
+ insn = &insns[insnIndex];
+
+ // Occasionally might be useful for relocation debugging, but pretty spammy
+ if (0) {
+ ALOGV("applying relo to instruction at byte offset: %llu, "
+ "insn offset %d, insn %llx",
+ (unsigned long long)offset, insnIndex, *(unsigned long long*)insn);
+ }
+
+ if (insn->code != (BPF_LD | BPF_IMM | BPF_DW)) {
+ ALOGE("Dumping all instructions till ins %d", insnIndex);
+ ALOGE("invalid relo for insn %d: code 0x%x", insnIndex, insn->code);
+ dumpIns((char*)insnsPtr, (insnIndex + 3) * 8);
+ return;
+ }
+
+ insn->imm = fd;
+ insn->src_reg = BPF_PSEUDO_MAP_FD;
+}
+
+static void applyMapRelo(ifstream& elfFile, vector<unique_fd> &mapFds, vector<codeSection>& cs) {
+ vector<string> mapNames;
+
+ int ret = getSectionSymNames(elfFile, "maps", mapNames);
+ if (ret) return;
+
+ for (int k = 0; k != (int)cs.size(); k++) {
+ Elf64_Rel* rel = (Elf64_Rel*)(cs[k].rel_data.data());
+ int n_rel = cs[k].rel_data.size() / sizeof(*rel);
+
+ for (int i = 0; i < n_rel; i++) {
+ int symIndex = ELF64_R_SYM(rel[i].r_info);
+ string symName;
+
+ ret = getSymNameByIdx(elfFile, symIndex, symName);
+ if (ret) return;
+
+ /* Find the map fd and apply relo */
+ for (int j = 0; j < (int)mapNames.size(); j++) {
+ if (!mapNames[j].compare(symName)) {
+ applyRelo(cs[k].data.data(), rel[i].r_offset, mapFds[j]);
+ break;
+ }
+ }
+ }
+ }
+}
+
+static int loadCodeSections(const char* elfPath, vector<codeSection>& cs, const string& license,
+ const char* prefix, const unsigned int bpfloader_ver) {
+ unsigned kvers = kernelVersion();
+
+ if (!kvers) {
+ ALOGE("unable to get kernel version");
+ return -EINVAL;
+ }
+
+ string objName = pathToObjName(string(elfPath));
+
+ for (int i = 0; i < (int)cs.size(); i++) {
+ unique_fd& fd = cs[i].prog_fd;
+ int ret;
+ string name = cs[i].name;
+
+ if (!cs[i].prog_def.has_value()) {
+ ALOGE("[%d] '%s' missing program definition! bad bpf.o build?", i, name.c_str());
+ return -EINVAL;
+ }
+
+ unsigned min_kver = cs[i].prog_def->min_kver;
+ unsigned max_kver = cs[i].prog_def->max_kver;
+ ALOGD("cs[%d].name:%s min_kver:%x .max_kver:%x (kvers:%x)", i, name.c_str(), min_kver,
+ max_kver, kvers);
+ if (kvers < min_kver) continue;
+ if (kvers >= max_kver) continue;
+
+ unsigned bpfMinVer = cs[i].prog_def->bpfloader_min_ver;
+ unsigned bpfMaxVer = cs[i].prog_def->bpfloader_max_ver;
+ domain selinux_context = getDomainFromSelinuxContext(cs[i].prog_def->selinux_context);
+ domain pin_subdir = getDomainFromPinSubdir(cs[i].prog_def->pin_subdir);
+
+ ALOGD("cs[%d].name:%s requires bpfloader version [0x%05x,0x%05x)", i, name.c_str(),
+ bpfMinVer, bpfMaxVer);
+ if (bpfloader_ver < bpfMinVer) continue;
+ if (bpfloader_ver >= bpfMaxVer) continue;
+
+ if ((cs[i].prog_def->ignore_on_eng && isEng()) ||
+ (cs[i].prog_def->ignore_on_user && isUser()) ||
+ (cs[i].prog_def->ignore_on_userdebug && isUserdebug())) {
+ ALOGD("cs[%d].name:%s is ignored on %s builds", i, name.c_str(),
+ getBuildType().c_str());
+ continue;
+ }
+
+ if ((isArm() && isKernel32Bit() && cs[i].prog_def->ignore_on_arm32) ||
+ (isArm() && isKernel64Bit() && cs[i].prog_def->ignore_on_aarch64) ||
+ (isX86() && isKernel32Bit() && cs[i].prog_def->ignore_on_x86_32) ||
+ (isX86() && isKernel64Bit() && cs[i].prog_def->ignore_on_x86_64) ||
+ (isRiscV() && cs[i].prog_def->ignore_on_riscv64)) {
+ ALOGD("cs[%d].name:%s is ignored on %s", i, name.c_str(), describeArch());
+ continue;
+ }
+
+ if (specified(selinux_context)) {
+ ALOGI("prog %s selinux_context [%-32s] -> %d -> '%s' (%s)", name.c_str(),
+ cs[i].prog_def->selinux_context, static_cast<int>(selinux_context),
+ lookupSelinuxContext(selinux_context), lookupPinSubdir(selinux_context));
+ }
+
+ if (specified(pin_subdir)) {
+ ALOGI("prog %s pin_subdir [%-32s] -> %d -> '%s'", name.c_str(),
+ cs[i].prog_def->pin_subdir, static_cast<int>(pin_subdir),
+ lookupPinSubdir(pin_subdir));
+ }
+
+ // strip any potential $foo suffix
+ // this can be used to provide duplicate programs
+ // conditionally loaded based on running kernel version
+ name = name.substr(0, name.find_last_of('$'));
+
+ bool reuse = false;
+ // Format of pin location is
+ // /sys/fs/bpf/<prefix>prog_<objName>_<progName>
+ string progPinLoc = string(BPF_FS_PATH) + lookupPinSubdir(pin_subdir, prefix) + "prog_" +
+ objName + '_' + string(name);
+ if (access(progPinLoc.c_str(), F_OK) == 0) {
+ fd.reset(retrieveProgram(progPinLoc.c_str()));
+ ALOGD("New bpf prog load reusing prog %s, ret: %d (%s)", progPinLoc.c_str(), fd.get(),
+ (!fd.ok() ? std::strerror(errno) : "no error"));
+ reuse = true;
+ } else {
+ vector<char> log_buf(BPF_LOAD_LOG_SZ, 0);
+
+ union bpf_attr req = {
+ .prog_type = cs[i].type,
+ .kern_version = kvers,
+ .license = ptr_to_u64(license.c_str()),
+ .insns = ptr_to_u64(cs[i].data.data()),
+ .insn_cnt = static_cast<__u32>(cs[i].data.size() / sizeof(struct bpf_insn)),
+ .log_level = 1,
+ .log_buf = ptr_to_u64(log_buf.data()),
+ .log_size = static_cast<__u32>(log_buf.size()),
+ .expected_attach_type = cs[i].attach_type,
+ };
+ if (isAtLeastKernelVersion(4, 15, 0))
+ strlcpy(req.prog_name, cs[i].name.c_str(), sizeof(req.prog_name));
+ fd.reset(bpf(BPF_PROG_LOAD, req));
+
+ ALOGD("BPF_PROG_LOAD call for %s (%s) returned fd: %d (%s)", elfPath,
+ cs[i].name.c_str(), fd.get(), (!fd.ok() ? std::strerror(errno) : "no error"));
+
+ if (!fd.ok()) {
+ vector<string> lines = android::base::Split(log_buf.data(), "\n");
+
+ ALOGW("BPF_PROG_LOAD - BEGIN log_buf contents:");
+ for (const auto& line : lines) ALOGW("%s", line.c_str());
+ ALOGW("BPF_PROG_LOAD - END log_buf contents.");
+
+ if (cs[i].prog_def->optional) {
+ ALOGW("failed program is marked optional - continuing...");
+ continue;
+ }
+ ALOGE("non-optional program failed to load.");
+ }
+ }
+
+ if (!fd.ok()) return fd.get();
+
+ if (!reuse) {
+ if (specified(selinux_context)) {
+ string createLoc = string(BPF_FS_PATH) + lookupPinSubdir(selinux_context) +
+ "tmp_prog_" + objName + '_' + string(name);
+ ret = bpfFdPin(fd, createLoc.c_str());
+ if (ret) {
+ int err = errno;
+ ALOGE("create %s -> %d [%d:%s]", createLoc.c_str(), ret, err, strerror(err));
+ return -err;
+ }
+ ret = renameat2(AT_FDCWD, createLoc.c_str(),
+ AT_FDCWD, progPinLoc.c_str(), RENAME_NOREPLACE);
+ if (ret) {
+ int err = errno;
+ ALOGE("rename %s %s -> %d [%d:%s]", createLoc.c_str(), progPinLoc.c_str(), ret,
+ err, strerror(err));
+ return -err;
+ }
+ } else {
+ ret = bpfFdPin(fd, progPinLoc.c_str());
+ if (ret) {
+ int err = errno;
+ ALOGE("create %s -> %d [%d:%s]", progPinLoc.c_str(), ret, err, strerror(err));
+ return -err;
+ }
+ }
+ if (chmod(progPinLoc.c_str(), 0440)) {
+ int err = errno;
+ ALOGE("chmod %s 0440 -> [%d:%s]", progPinLoc.c_str(), err, strerror(err));
+ return -err;
+ }
+ if (chown(progPinLoc.c_str(), (uid_t)cs[i].prog_def->uid,
+ (gid_t)cs[i].prog_def->gid)) {
+ int err = errno;
+ ALOGE("chown %s %d %d -> [%d:%s]", progPinLoc.c_str(), cs[i].prog_def->uid,
+ cs[i].prog_def->gid, err, strerror(err));
+ return -err;
+ }
+ }
+
+ int progId = bpfGetFdProgId(fd);
+ if (progId == -1) {
+ ALOGE("bpfGetFdProgId failed, ret: %d [%d]", progId, errno);
+ } else {
+ ALOGI("prog %s id %d", progPinLoc.c_str(), progId);
+ }
+ }
+
+ return 0;
+}
+
+int loadProg(const char* const elfPath, bool* const isCritical, const unsigned int bpfloader_ver,
+ const Location& location) {
+ vector<char> license;
+ vector<char> critical;
+ vector<codeSection> cs;
+ vector<unique_fd> mapFds;
+ int ret;
+
+ if (!isCritical) return -1;
+ *isCritical = false;
+
+ ifstream elfFile(elfPath, ios::in | ios::binary);
+ if (!elfFile.is_open()) return -1;
+
+ ret = readSectionByName("critical", elfFile, critical);
+ *isCritical = !ret;
+
+ ret = readSectionByName("license", elfFile, license);
+ if (ret) {
+ ALOGE("Couldn't find license in %s", elfPath);
+ return ret;
+ } else {
+ ALOGD("Loading %s%s ELF object %s with license %s",
+ *isCritical ? "critical for " : "optional", *isCritical ? (char*)critical.data() : "",
+ elfPath, (char*)license.data());
+ }
+
+ // the following default values are for bpfloader V0.0 format which does not include them
+ unsigned int bpfLoaderMinVer =
+ readSectionUint("bpfloader_min_ver", elfFile, DEFAULT_BPFLOADER_MIN_VER);
+ unsigned int bpfLoaderMaxVer =
+ readSectionUint("bpfloader_max_ver", elfFile, DEFAULT_BPFLOADER_MAX_VER);
+ unsigned int bpfLoaderMinRequiredVer =
+ readSectionUint("bpfloader_min_required_ver", elfFile, 0);
+ size_t sizeOfBpfMapDef =
+ readSectionUint("size_of_bpf_map_def", elfFile, DEFAULT_SIZEOF_BPF_MAP_DEF);
+ size_t sizeOfBpfProgDef =
+ readSectionUint("size_of_bpf_prog_def", elfFile, DEFAULT_SIZEOF_BPF_PROG_DEF);
+
+ // inclusive lower bound check
+ if (bpfloader_ver < bpfLoaderMinVer) {
+ ALOGI("BpfLoader version 0x%05x ignoring ELF object %s with min ver 0x%05x",
+ bpfloader_ver, elfPath, bpfLoaderMinVer);
+ return 0;
+ }
+
+ // exclusive upper bound check
+ if (bpfloader_ver >= bpfLoaderMaxVer) {
+ ALOGI("BpfLoader version 0x%05x ignoring ELF object %s with max ver 0x%05x",
+ bpfloader_ver, elfPath, bpfLoaderMaxVer);
+ return 0;
+ }
+
+ if (bpfloader_ver < bpfLoaderMinRequiredVer) {
+ ALOGI("BpfLoader version 0x%05x failing due to ELF object %s with required min ver 0x%05x",
+ bpfloader_ver, elfPath, bpfLoaderMinRequiredVer);
+ return -1;
+ }
+
+ ALOGI("BpfLoader version 0x%05x processing ELF object %s with ver [0x%05x,0x%05x)",
+ bpfloader_ver, elfPath, bpfLoaderMinVer, bpfLoaderMaxVer);
+
+ if (sizeOfBpfMapDef < DEFAULT_SIZEOF_BPF_MAP_DEF) {
+ ALOGE("sizeof(bpf_map_def) of %zu is too small (< %d)", sizeOfBpfMapDef,
+ DEFAULT_SIZEOF_BPF_MAP_DEF);
+ return -1;
+ }
+
+ if (sizeOfBpfProgDef < DEFAULT_SIZEOF_BPF_PROG_DEF) {
+ ALOGE("sizeof(bpf_prog_def) of %zu is too small (< %d)", sizeOfBpfProgDef,
+ DEFAULT_SIZEOF_BPF_PROG_DEF);
+ return -1;
+ }
+
+ ret = readCodeSections(elfFile, cs, sizeOfBpfProgDef);
+ if (ret) {
+ ALOGE("Couldn't read all code sections in %s", elfPath);
+ return ret;
+ }
+
+ /* Just for future debugging */
+ if (0) dumpAllCs(cs);
+
+ ret = createMaps(elfPath, elfFile, mapFds, location.prefix, sizeOfBpfMapDef, bpfloader_ver);
+ if (ret) {
+ ALOGE("Failed to create maps: (ret=%d) in %s", ret, elfPath);
+ return ret;
+ }
+
+ for (int i = 0; i < (int)mapFds.size(); i++)
+ ALOGV("map_fd found at %d is %d in %s", i, mapFds[i].get(), elfPath);
+
+ applyMapRelo(elfFile, mapFds, cs);
+
+ ret = loadCodeSections(elfPath, cs, string(license.data()), location.prefix, bpfloader_ver);
+ if (ret) ALOGE("Failed to load programs, loadCodeSections ret=%d", ret);
+
+ return ret;
+}
static bool exists(const char* const path) {
int v = access(path, F_OK);
diff --git a/netbpfload/loader.cpp b/netbpfload/loader.cpp
deleted file mode 100644
index 5141095..0000000
--- a/netbpfload/loader.cpp
+++ /dev/null
@@ -1,1194 +0,0 @@
-/*
- * Copyright (C) 2018-2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#define LOG_TAG "NetBpfLoad"
-
-#include <errno.h>
-#include <fcntl.h>
-#include <linux/bpf.h>
-#include <linux/elf.h>
-#include <log/log.h>
-#include <stdint.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sysexits.h>
-#include <sys/stat.h>
-#include <sys/utsname.h>
-#include <sys/wait.h>
-#include <unistd.h>
-
-#include "BpfSyscallWrappers.h"
-#include "bpf/BpfUtils.h"
-#include "bpf/bpf_map_def.h"
-#include "loader.h"
-
-#include <cstdlib>
-#include <fstream>
-#include <iostream>
-#include <optional>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include <android-base/cmsg.h>
-#include <android-base/file.h>
-#include <android-base/properties.h>
-#include <android-base/strings.h>
-#include <android-base/unique_fd.h>
-
-#define BPF_FS_PATH "/sys/fs/bpf/"
-
-// Size of the BPF log buffer for verifier logging
-#define BPF_LOAD_LOG_SZ 0xfffff
-
-// Unspecified attach type is 0 which is BPF_CGROUP_INET_INGRESS.
-#define BPF_ATTACH_TYPE_UNSPEC BPF_CGROUP_INET_INGRESS
-
-using android::base::StartsWith;
-using android::base::unique_fd;
-using std::ifstream;
-using std::ios;
-using std::optional;
-using std::string;
-using std::vector;
-
-namespace android {
-namespace bpf {
-
-// Returns the "ro.build.type" system property, requested with "unknown" as the default.
-// The property is looked up on the first call only; later calls return the cached string.
-const std::string& getBuildType() {
-    static std::string cachedBuildType = android::base::GetProperty("ro.build.type", "unknown");
-    return cachedBuildType;
-}
-
-static unsigned int page_size = static_cast<unsigned int>(getpagesize());
-
-// Maps a domain to the name of its selinux context.
-// domain::unspecified yields the caller supplied 'unspecified' string (empty by default),
-// any value that is not handled below yields "(unrecognized)".
-constexpr const char* lookupSelinuxContext(const domain d, const char* const unspecified = "") {
-    if (d == domain::unspecified) return unspecified;
-    if (d == domain::tethering) return "fs_bpf_tethering";
-    if (d == domain::net_private) return "fs_bpf_net_private";
-    if (d == domain::net_shared) return "fs_bpf_net_shared";
-    if (d == domain::netd_readonly) return "fs_bpf_netd_readonly";
-    if (d == domain::netd_shared) return "fs_bpf_netd_shared";
-    return "(unrecognized)";
-}
-
-// Finds the domain whose selinux context name equals the (fixed size) char array 's'.
-// Unknown names are logged and mapped to domain::unspecified, see the rationale below.
-domain getDomainFromSelinuxContext(const char s[BPF_SELINUX_CONTEXT_CHAR_ARRAY_SIZE]) {
-    for (domain candidate : AllDomains) {
-        const char* const contextName = lookupSelinuxContext(candidate);
-        // Not sure how to enforce this at compile time, so abort() bpfloader at boot instead
-        if (strlen(contextName) >= BPF_SELINUX_CONTEXT_CHAR_ARRAY_SIZE) abort();
-        if (strncmp(s, contextName, BPF_SELINUX_CONTEXT_CHAR_ARRAY_SIZE) == 0) return candidate;
-    }
-
-    ALOGW("ignoring unrecognized selinux_context '%-32s'", s);
-    // Strictly speaking 'unrecognized' would be the right answer here.  Returning unspecified
-    // instead makes the system fall back to the default context, which leaves room for
-    // adding more restrictive selinux types in the future:
-    // an older bpfloader will simply ignore them and use the less restrictive default.
-    // The flip side is that a type *less* restrictive than the default can NOT be added later.
-    //
-    // Note: we cannot just abort() here as this might be a mainline module shipped optional update
-    return domain::unspecified;
-}
-
-// Maps a domain to its pin sub-directory (relative to the bpf filesystem root, with a
-// trailing slash).  domain::unspecified yields the caller supplied 'unspecified' string
-// (empty by default), any value that is not handled below yields "(unrecognized)".
-constexpr const char* lookupPinSubdir(const domain d, const char* const unspecified = "") {
-    if (d == domain::unspecified) return unspecified;
-    if (d == domain::tethering) return "tethering/";
-    if (d == domain::net_private) return "net_private/";
-    if (d == domain::net_shared) return "net_shared/";
-    if (d == domain::netd_readonly) return "netd_readonly/";
-    if (d == domain::netd_shared) return "netd_shared/";
-    return "(unrecognized)";
-}
-
-// Finds the domain whose pin sub-directory equals the (fixed size) char array 's'.
-// Unknown values are logged and reported as domain::unrecognized.
-domain getDomainFromPinSubdir(const char s[BPF_PIN_SUBDIR_CHAR_ARRAY_SIZE]) {
-    for (domain candidate : AllDomains) {
-        const char* const subdir = lookupPinSubdir(candidate);
-        // Not sure how to enforce this at compile time, so abort() bpfloader at boot instead
-        if (strlen(subdir) >= BPF_PIN_SUBDIR_CHAR_ARRAY_SIZE) abort();
-        if (strncmp(s, subdir, BPF_PIN_SUBDIR_CHAR_ARRAY_SIZE) == 0) return candidate;
-    }
-
-    ALOGE("unrecognized pin_subdir '%-32s'", s);
-    // The pin_subdir is part of the object's full pathname: silently falling back to the
-    // default would move the object and break our code's ability to find it,
-    // so this is worth treating as a true error condition.
-    //
-    // Note: we cannot just abort() here as this might be a mainline module shipped optional update
-    // However, our callers will treat this as an error, and stop loading the specific .o,
-    // which will fail bpfloader if the .o is marked critical.
-    return domain::unrecognized;
-}
-
-// Derives the object name from the path of a .o file:
-//   '/some/dir/foo@1.o' -> 'foo'      '/some/dir/bar.o' -> 'bar'
-static string pathToObjName(const string& path) {
-    // Last path component, ie. the filename 'foo@1.o' or 'bar.o'
-    const string baseName = android::base::Split(path, "/").back();
-    // Cut at the final period (drops the '.o' suffix), ie. 'foo@1' or 'bar'
-    const string stem = baseName.substr(0, baseName.rfind('.'));
-    // Cut at the final '@', if any, which leaves just 'foo' or 'bar'.
-    // The @ suffix allows shipping duplicate programs (mux based on the bpfloader version)
-    return stem.substr(0, stem.rfind('@'));
-}
-
-// One row of the sectionNameTypes[] table: associates a section name prefix with the
-// bpf program type and expected attach type of the programs placed in such sections.
-typedef struct {
- // Section name prefix; compared against section names with StartsWith().
- const char* name;
- // Program type reported by getSectionType() for a matching section.
- enum bpf_prog_type type;
- // Attach type reported by getExpectedAttachType() for a matching section.
- enum bpf_attach_type expected_attach_type;
-} sectionType;
-
-/*
- * Map section name prefixes to program types, the section name will be:
- * SECTION(<prefix>/<name-of-program>)
- * For example:
- * SECTION("tracepoint/sched_switch_func") where sched_switch_func
- * is the name of the program, and tracepoint is the type.
- *
- * However, be aware that you should not be directly using the SECTION() macro.
- * Instead use the DEFINE_(BPF|XDP)_(PROG|MAP)... & LICENSE/CRITICAL macros.
- *
- * Programs shipped inside the tethering apex should be limited to networking stuff,
- * as KPROBE, PERF_EVENT, TRACEPOINT are dangerous to use from mainline updatable code,
- * since they are less stable abi/api and may conflict with platform uses of bpf.
- *
- * The table is scanned front to back by getSectionType() / getExpectedAttachType(),
- * and the first entry whose name is a prefix of the section name wins.
- *
- * NOTE(review): "sysctl" is the only prefix without a trailing '/', so it also matches any
- * section name that merely starts with "sysctl" - confirm this is intentional.
- */
-sectionType sectionNameTypes[] = {
- {"bind4/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_INET4_BIND},
- {"bind6/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_INET6_BIND},
- {"cgroupskb/", BPF_PROG_TYPE_CGROUP_SKB, BPF_ATTACH_TYPE_UNSPEC},
- {"cgroupsock/", BPF_PROG_TYPE_CGROUP_SOCK, BPF_ATTACH_TYPE_UNSPEC},
- {"cgroupsockcreate/", BPF_PROG_TYPE_CGROUP_SOCK, BPF_CGROUP_INET_SOCK_CREATE},
- {"cgroupsockrelease/", BPF_PROG_TYPE_CGROUP_SOCK, BPF_CGROUP_INET_SOCK_RELEASE},
- {"connect4/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_INET4_CONNECT},
- {"connect6/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_INET6_CONNECT},
- {"egress/", BPF_PROG_TYPE_CGROUP_SKB, BPF_CGROUP_INET_EGRESS},
- {"getsockopt/", BPF_PROG_TYPE_CGROUP_SOCKOPT, BPF_CGROUP_GETSOCKOPT},
- {"ingress/", BPF_PROG_TYPE_CGROUP_SKB, BPF_CGROUP_INET_INGRESS},
- {"lwt_in/", BPF_PROG_TYPE_LWT_IN, BPF_ATTACH_TYPE_UNSPEC},
- {"lwt_out/", BPF_PROG_TYPE_LWT_OUT, BPF_ATTACH_TYPE_UNSPEC},
- {"lwt_seg6local/", BPF_PROG_TYPE_LWT_SEG6LOCAL, BPF_ATTACH_TYPE_UNSPEC},
- {"lwt_xmit/", BPF_PROG_TYPE_LWT_XMIT, BPF_ATTACH_TYPE_UNSPEC},
- {"postbind4/", BPF_PROG_TYPE_CGROUP_SOCK, BPF_CGROUP_INET4_POST_BIND},
- {"postbind6/", BPF_PROG_TYPE_CGROUP_SOCK, BPF_CGROUP_INET6_POST_BIND},
- {"recvmsg4/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_UDP4_RECVMSG},
- {"recvmsg6/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_UDP6_RECVMSG},
- {"schedact/", BPF_PROG_TYPE_SCHED_ACT, BPF_ATTACH_TYPE_UNSPEC},
- {"schedcls/", BPF_PROG_TYPE_SCHED_CLS, BPF_ATTACH_TYPE_UNSPEC},
- {"sendmsg4/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_UDP4_SENDMSG},
- {"sendmsg6/", BPF_PROG_TYPE_CGROUP_SOCK_ADDR, BPF_CGROUP_UDP6_SENDMSG},
- {"setsockopt/", BPF_PROG_TYPE_CGROUP_SOCKOPT, BPF_CGROUP_SETSOCKOPT},
- {"skfilter/", BPF_PROG_TYPE_SOCKET_FILTER, BPF_ATTACH_TYPE_UNSPEC},
- {"sockops/", BPF_PROG_TYPE_SOCK_OPS, BPF_CGROUP_SOCK_OPS},
- {"sysctl", BPF_PROG_TYPE_CGROUP_SYSCTL, BPF_CGROUP_SYSCTL},
- {"xdp/", BPF_PROG_TYPE_XDP, BPF_ATTACH_TYPE_UNSPEC},
-};
-
-// One program section read from the ELF object (filled in by readCodeSections()).
-typedef struct {
- // Program type derived from the section name prefix.
- enum bpf_prog_type type;
- // Expected attach type derived from the section name prefix.
- enum bpf_attach_type expected_attach_type;
- // Section name with every '/' replaced by '_'.
- string name;
- // Raw contents of the program section.
- vector<char> data;
- // Raw contents of the ".rel<section>" section directly following it; empty if there is none.
- vector<char> rel_data;
- // Entry of the "progs" section named "<first function symbol>_def", if one was found.
- optional<struct bpf_prog_def> prog_def;
-
- unique_fd prog_fd; /* fd after loading */
-} codeSection;
-
-/* Reads the ELF header found at the very start of elfFile into *eh: 0 on success, else -1 */
-static int readElfHeader(ifstream& elfFile, Elf64_Ehdr* eh) {
-    elfFile.seekg(0);
-    if (elfFile.fail()) return -1;
-
-    elfFile.read(reinterpret_cast<char*>(eh), sizeof(*eh));
-    return elfFile.fail() ? -1 : 0;
-}
-
-/*
- * Reads all section header table entries into an Shdr array.
- *
- * Returns 0 on success, a negative value if the ELF header or the table cannot be read,
- * or if the file declares a section header entry size other than sizeof(Elf64_Shdr).
- */
-static int readSectionHeadersAll(ifstream& elfFile, vector<Elf64_Shdr>& shTable) {
-    Elf64_Ehdr eh;
-    int ret = 0;
-
-    ret = readElfHeader(elfFile, &eh);
-    if (ret) return ret;
-
-    // shTable stores e_shnum entries of sizeof(Elf64_Shdr) bytes each.  A file declaring a
-    // larger entry size would make the read below run past the end of that buffer, and a
-    // smaller one would be misparsed, so refuse anything but the exact size.
-    if (eh.e_shnum && eh.e_shentsize != sizeof(Elf64_Shdr)) return -EINVAL;
-
-    elfFile.seekg(eh.e_shoff);
-    if (elfFile.fail()) return -1;
-
-    /* Read shdr table entries */
-    shTable.resize(eh.e_shnum);
-
-    if (!elfFile.read((char*)shTable.data(), shTable.size() * sizeof(Elf64_Shdr)))
-        return -ENOMEM;
-
-    return 0;
-}
-
-/*
- * Read a section by its index - for ex to get sec hdr strtab blob.
- *
- * Returns 0 on success, a negative value if the section headers cannot be read, if 'id'
- * is not a valid section index, or if the section contents cannot be read.
- */
-static int readSectionByIdx(ifstream& elfFile, int id, vector<char>& sec) {
-    vector<Elf64_Shdr> shTable;
-    int ret = readSectionHeadersAll(elfFile, shTable);
-    if (ret) return ret;
-
-    // The index comes straight from the ELF file (eg. e_shstrndx): do not trust it.
-    if (id < 0 || id >= (int)shTable.size()) return -1;
-    const Elf64_Shdr& hdr = shTable[id];
-
-    elfFile.seekg(hdr.sh_offset);
-    if (elfFile.fail()) return -1;
-
-    sec.resize(hdr.sh_size);
-    if (!elfFile.read(sec.data(), hdr.sh_size)) return -1;
-
-    return 0;
-}
-
-/* Loads the complete section header string table (section e_shstrndx) into strtab */
-static int readSectionHeaderStrtab(ifstream& elfFile, vector<char>& strtab) {
-    Elf64_Ehdr header;
-    if (int err = readElfHeader(elfFile, &header)) return err;
-
-    return readSectionByIdx(elfFile, header.e_shstrndx, strtab);
-}
-
-/*
- * Get name from offset in strtab (the section header string table).
- *
- * Returns 0 and stores the string in 'name' on success, a negative value if the table
- * cannot be read or if nameOff lies outside of it.
- */
-static int getSymName(ifstream& elfFile, int nameOff, string& name) {
-    int ret;
-    vector<char> secStrTab;
-
-    ret = readSectionHeaderStrtab(elfFile, secStrTab);
-    if (ret) return ret;
-
-    // Callers pass unsigned 32-bit ELF fields, which show up as negative ints once too large.
-    if (nameOff < 0 || nameOff >= (int)secStrTab.size()) return -1;
-
-    // Never read past the end of the table, even if the final string lacks its terminator.
-    const char* const start = secStrTab.data() + nameOff;
-    name = string(start, strnlen(start, secStrTab.size() - nameOff));
-    return 0;
-}
-
-/*
- * Reads a full section by name - example to get the GPL license.
- *
- * Returns 0 and stores the contents of the first section called 'name' in 'data',
- * -2 if there is no such section, another negative value on read failures.
- * 'data' is left untouched unless 0 is returned.
- */
-static int readSectionByName(const char* name, ifstream& elfFile, vector<char>& data) {
-    vector<char> secStrTab;
-    vector<Elf64_Shdr> shTable;
-    int ret;
-
-    ret = readSectionHeadersAll(elfFile, shTable);
-    if (ret) return ret;
-
-    ret = readSectionHeaderStrtab(elfFile, secStrTab);
-    if (ret) return ret;
-
-    const size_t wantedLen = strlen(name);
-
-    for (const Elf64_Shdr& hdr : shTable) {
-        // A name offset outside of the string table cannot name anything: skip the section.
-        if (hdr.sh_name >= secStrTab.size()) continue;
-
-        // Bounded comparison, so that a table lacking its final terminator is not overrun.
-        const char* const secname = secStrTab.data() + hdr.sh_name;
-        const size_t secnameLen = strnlen(secname, secStrTab.size() - hdr.sh_name);
-        if (secnameLen != wantedLen || memcmp(secname, name, wantedLen)) continue;
-
-        vector<char> dataTmp;
-        dataTmp.resize(hdr.sh_size);
-
-        elfFile.seekg(hdr.sh_offset);
-        if (elfFile.fail()) return -1;
-
-        if (!elfFile.read((char*)dataTmp.data(), hdr.sh_size)) return -1;
-
-        data = std::move(dataTmp);
-        return 0;
-    }
-    return -2;
-}
-
-// Returns the little endian 32-bit unsigned integer stored at the start of section 'name',
-// or defVal if that section cannot be read or holds less than sizeof(unsigned int) bytes.
-unsigned int readSectionUint(const char* name, ifstream& elfFile, unsigned int defVal) {
-    vector<char> raw;
-
-    if (readSectionByName(name, elfFile, raw)) {
-        ALOGD("Couldn't find section %s (defaulting to %u [0x%x]).", name, defVal, defVal);
-        return defVal;
-    }
-
-    if (raw.size() < sizeof(unsigned int)) {
-        ALOGE("Section %s too short (defaulting to %u [0x%x]).", name, defVal, defVal);
-        return defVal;
-    }
-
-    // decode first 4 bytes as LE32 uint, there will likely be more bytes due to alignment.
-    unsigned int value = 0;
-    for (int idx = 3; idx >= 0; idx--) {
-        value = (value << 8) + static_cast<unsigned char>(raw[idx]);
-    }
-    ALOGI("Section %s value is %u [0x%x]", name, value, value);
-    return value;
-}
-
-/*
- * Reads the contents of the first section whose sh_type equals 'type' into 'data'.
- * Returns 0 on success, -2 if no section has that type, another negative value on failure.
- */
-static int readSectionByType(ifstream& elfFile, int type, vector<char>& data) {
-    vector<Elf64_Shdr> headers;
-    if (int err = readSectionHeadersAll(elfFile, headers)) return err;
-
-    for (const Elf64_Shdr& hdr : headers) {
-        if ((int)hdr.sh_type != type) continue;
-
-        // Read into a scratch buffer, so that 'data' only changes on success
-        vector<char> contents(hdr.sh_size);
-
-        elfFile.seekg(hdr.sh_offset);
-        if (elfFile.fail()) return -1;
-
-        if (!elfFile.read(contents.data(), hdr.sh_size)) return -1;
-
-        data = contents;
-        return 0;
-    }
-    return -2;
-}
-
-// Ordering predicate for std::sort(): true if symbol 'a' has a smaller st_value than 'b'.
-static bool symCompare(Elf64_Sym a, Elf64_Sym b) {
-    return b.st_value > a.st_value;
-}
-
-/*
- * Loads the symbol table (first SHT_SYMTAB section) into 'data'.
- * A non-zero 'sort' additionally orders the symbols by ascending st_value.
- */
-static int readSymTab(ifstream& elfFile, int sort, vector<Elf64_Sym>& data) {
-    vector<char> rawSymtab;
-    const int err = readSectionByType(elfFile, SHT_SYMTAB, rawSymtab);
-    if (err) return err;
-
-    // Trailing bytes that do not make up a whole Elf64_Sym are ignored
-    const Elf64_Sym* const first = (const Elf64_Sym*)rawSymtab.data();
-    const int symCount = rawSymtab.size() / sizeof(Elf64_Sym);
-    data.assign(first, first + symCount);
-
-    if (sort) {
-        std::sort(data.begin(), data.end(), symCompare);
-    }
-    return 0;
-}
-
-// Returns the program type of the first sectionNameTypes[] entry whose prefix starts
-// 'name', or BPF_PROG_TYPE_UNSPEC if no entry matches.
-static enum bpf_prog_type getSectionType(string& name) {
-    for (const sectionType& entry : sectionNameTypes) {
-        if (StartsWith(name, entry.name)) {
-            return entry.type;
-        }
-    }
-    return BPF_PROG_TYPE_UNSPEC;
-}
-
-// Returns the expected attach type of the first sectionNameTypes[] entry whose prefix
-// starts 'name', or BPF_ATTACH_TYPE_UNSPEC if no entry matches.
-static enum bpf_attach_type getExpectedAttachType(string& name) {
-    for (const sectionType& entry : sectionNameTypes) {
-        if (StartsWith(name, entry.name)) {
-            return entry.expected_attach_type;
-        }
-    }
-    return BPF_ATTACH_TYPE_UNSPEC;
-}
-
-/*
-static string getSectionName(enum bpf_prog_type type)
-{
- for (auto& snt : sectionNameTypes)
- if (snt.type == type)
- return string(snt.name);
-
- return "UNKNOWN SECTION NAME " + std::to_string(type);
-}
-*/
-
-/*
- * Decodes the "progs" section into one struct bpf_prog_def per program.
- *
- * Each record in the file is sizeOfBpfProgDef bytes long, which may differ from this
- * loader's sizeof(struct bpf_prog_def): only the common prefix is copied, and fields
- * missing from the file keep their defaults.
- */
-static int readProgDefs(ifstream& elfFile, vector<struct bpf_prog_def>& pd,
-                        size_t sizeOfBpfProgDef) {
-    vector<char> rawDefs;
-    if (int err = readSectionByName("progs", elfFile, rawDefs)) return err;
-
-    if (rawDefs.size() % sizeOfBpfProgDef) {
-        ALOGE("readProgDefs failed due to improper sized progs section, %zu %% %zu != 0",
-              rawDefs.size(), sizeOfBpfProgDef);
-        return -1;
-    }
-
-    const int progCount = rawDefs.size() / sizeOfBpfProgDef;
-    const size_t prefixLen = std::min(sizeOfBpfProgDef, sizeof(struct bpf_prog_def));
-    pd.resize(progCount);
-
-    for (int idx = 0; idx < progCount; idx++) {
-        struct bpf_prog_def& def = pd[idx];
-        // Start out all zeroes...
-        memset(&def, 0, sizeof(def));
-        // ...apply the non-zero defaults...
-        def.bpfloader_max_ver = DEFAULT_BPFLOADER_MAX_VER;  // v1.0
-        // ...and finally overlay the structure prefix stored in the ELF file.
-        memcpy(&def, rawDefs.data() + idx * sizeOfBpfProgDef, prefixLen);
-    }
-    return 0;
-}
-
-/*
- * Appends to 'names' the name of every symbol that belongs to section 'sectionName',
- * in ascending st_value order.  When symbolType is given, only symbols of that ELF
- * symbol type are considered.
- *
- * Returns 0 on success, a negative value on read failure or if there is no such section.
- */
-static int getSectionSymNames(ifstream& elfFile, const string& sectionName, vector<string>& names,
-                              optional<unsigned> symbolType = std::nullopt) {
-    vector<Elf64_Sym> symtab;
-    int ret = readSymTab(elfFile, 1 /* sort */, symtab);
-    if (ret) return ret;
-
-    /* Get index of section */
-    vector<Elf64_Shdr> shTable;
-    ret = readSectionHeadersAll(elfFile, shTable);
-    if (ret) return ret;
-
-    int sec_idx = -1;
-    for (int i = 0; i < (int)shTable.size() && sec_idx == -1; i++) {
-        string candidate;
-        ret = getSymName(elfFile, shTable[i].sh_name, candidate);
-        if (ret) return ret;
-
-        if (candidate == sectionName) sec_idx = i;
-    }
-
-    /* No section found with matching name*/
-    if (sec_idx == -1) {
-        ALOGW("No %s section could be found in elf object", sectionName.c_str());
-        return -1;
-    }
-
-    for (const Elf64_Sym& sym : symtab) {
-        if (symbolType.has_value() && ELF_ST_TYPE(sym.st_info) != symbolType) continue;
-        if (sym.st_shndx != sec_idx) continue;
-
-        string symName;
-        ret = getSymName(elfFile, sym.st_name, symName);
-        if (ret) return ret;
-        names.push_back(symName);
-    }
-
-    return 0;
-}
-
-/*
- * Collects every program section of the ELF object into 'cs'.
- *
- * A section is a program section if its name starts with one of the sectionNameTypes[]
- * prefixes.  For each of them this records the program/attach type, the name (with '/'
- * turned into '_'), the section contents, the matching "progs" definition (the one named
- * "<first function symbol>_def") and, if the next section is ".rel<section name>",
- * the relocation data.  Sections without contents are not added to 'cs'.
- *
- * Returns 0 on success, a negative value on failure.
- */
-static int readCodeSections(ifstream& elfFile, vector<codeSection>& cs, size_t sizeOfBpfProgDef) {
-    vector<Elf64_Shdr> shTable;
-    int entries, ret = 0;
-
-    ret = readSectionHeadersAll(elfFile, shTable);
-    if (ret) return ret;
-    entries = shTable.size();
-
-    vector<struct bpf_prog_def> pd;
-    ret = readProgDefs(elfFile, pd, sizeOfBpfProgDef);
-    if (ret) return ret;
-    vector<string> progDefNames;
-    ret = getSectionSymNames(elfFile, "progs", progDefNames);
-    if (!pd.empty() && ret) return ret;
-
-    // pd[] and progDefNames[] are walked in lockstep below: never index past either of them.
-    const size_t progDefCount = std::min(pd.size(), progDefNames.size());
-
-    for (int i = 0; i < entries; i++) {
-        string name;
-        codeSection cs_temp;
-        cs_temp.type = BPF_PROG_TYPE_UNSPEC;
-
-        ret = getSymName(elfFile, shTable[i].sh_name, name);
-        if (ret) return ret;
-
-        enum bpf_prog_type ptype = getSectionType(name);
-
-        if (ptype == BPF_PROG_TYPE_UNSPEC) continue;
-
-        // This must be done before '/' is replaced with '_'.
-        cs_temp.expected_attach_type = getExpectedAttachType(name);
-
-        string oldName = name;
-
-        // convert all slashes to underscores
-        std::replace(name.begin(), name.end(), '/', '_');
-
-        cs_temp.type = ptype;
-        cs_temp.name = name;
-
-        ret = readSectionByIdx(elfFile, i, cs_temp.data);
-        if (ret) return ret;
-        ALOGV("Loaded code section %d (%s)", i, name.c_str());
-
-        vector<string> csSymNames;
-        ret = getSectionSymNames(elfFile, oldName, csSymNames, STT_FUNC);
-        // NOTE(review): a program section without any function symbol ends the whole scan
-        // with ret == 0 (success); behavior kept as is - confirm this is intended.
-        if (ret || !csSymNames.size()) return ret;
-        for (size_t defIdx = 0; defIdx < progDefCount; ++defIdx) {
-            if (!progDefNames[defIdx].compare(csSymNames[0] + "_def")) {
-                cs_temp.prog_def = pd[defIdx];
-                break;
-            }
-        }
-
-        /* Check for rel section, which can only exist if this is not the last section */
-        if (cs_temp.data.size() > 0 && i + 1 < entries) {
-            ret = getSymName(elfFile, shTable[i + 1].sh_name, name);
-            if (ret) return ret;
-
-            if (name == (".rel" + oldName)) {
-                ret = readSectionByIdx(elfFile, i + 1, cs_temp.rel_data);
-                if (ret) return ret;
-                ALOGV("Loaded relo section %d (%s)", i + 1, name.c_str());
-            }
-        }
-
-        if (cs_temp.data.size() > 0) {
-            cs.push_back(std::move(cs_temp));
-            ALOGV("Adding section %d to cs list", i);
-        }
-    }
-    return 0;
-}
-
-/*
- * Looks up the name of the symbol at position 'index' of the (unsorted) symbol table.
- * Returns 0 on success, a negative value on read failure or if index is out of range.
- */
-static int getSymNameByIdx(ifstream& elfFile, int index, string& name) {
-    vector<Elf64_Sym> symtab;
-    int ret = 0;
-
-    ret = readSymTab(elfFile, 0 /* !sort */, symtab);
-    if (ret) return ret;
-
-    // The index is taken from relocation entries in the file, so validate both ends.
-    if (index < 0 || index >= (int)symtab.size()) return -1;
-
-    return getSymName(elfFile, symtab[index].st_name, name);
-}
-
-/*
- * Checks that the map behind 'fd' has the type, key/value sizes, number of entries and
- * flags that 'mapDef' (with 'type' being the map type actually requested) asks for.
- * Logs the differences and returns false on mismatch.
- */
-static bool mapMatchesExpectations(const unique_fd& fd, const string& mapName,
-                                   const struct bpf_map_def& mapDef, const enum bpf_map_type type) {
-    // bpfGetFd... family of functions require at minimum a 4.14 kernel,
-    // so on 4.9-T kernels just pretend the map matches our expectations.
-    // Additionally we'll get almost equivalent test coverage on newer devices/kernels.
-    // This is because the primary failure mode we're trying to detect here
-    // is either a source code misconfiguration (which is likely kernel independent)
-    // or a newly introduced kernel feature/bug (which is unlikely to get backported to 4.9).
-    if (!isAtLeastKernelVersion(4, 14, 0)) return true;
-
-    // Assuming fd is a valid Bpf Map file descriptor then
-    // all the following should always succeed on a 4.14+ kernel.
-    // If they somehow do fail, they'll return -1 (and set errno),
-    // which should then cause (among others) a key_size mismatch.
-    int fd_type = bpfGetFdMapType(fd);
-    int fd_key_size = bpfGetFdKeySize(fd);
-    int fd_value_size = bpfGetFdValueSize(fd);
-    int fd_max_entries = bpfGetFdMaxEntries(fd);
-    int fd_map_flags = bpfGetFdMapFlags(fd);
-
-    // DEVMAPs are readonly from the bpf program side's point of view, as such
-    // the kernel in kernel/bpf/devmap.c dev_map_init_map() will set the flag
-    int desired_map_flags = (int)mapDef.map_flags;
-    if (type == BPF_MAP_TYPE_DEVMAP || type == BPF_MAP_TYPE_DEVMAP_HASH)
-        desired_map_flags |= BPF_F_RDONLY_PROG;
-
-    // The .h file enforces that this is a power of two, and page size will
-    // also always be a power of two, so this logic is actually enough to
-    // force it to be a multiple of the page size, as required by the kernel.
-    unsigned int desired_max_entries = mapDef.max_entries;
-    if (type == BPF_MAP_TYPE_RINGBUF) {
-        if (desired_max_entries < page_size) desired_max_entries = page_size;
-    }
-
-    // The following checks should *never* trigger, if one of them somehow does,
-    // it probably means a bpf .o file has been changed/replaced at runtime
-    // and bpfloader was manually rerun (normally it should only run *once*
-    // early during the boot process).
-    // Another possibility is that something is misconfigured in the code:
-    // most likely a shared map is declared twice differently.
-    // But such a change should never be checked into the source tree...
-    if ((fd_type == type) &&
-        (fd_key_size == (int)mapDef.key_size) &&
-        (fd_value_size == (int)mapDef.value_size) &&
-        (fd_max_entries == (int)desired_max_entries) &&
-        (fd_map_flags == desired_map_flags)) {
-        return true;
-    }
-
-    // Report the values that were actually compared above: the (possibly page size
-    // adjusted) number of entries and the (possibly RDONLY_PROG augmented) flags.
-    ALOGE("bpf map name %s mismatch: desired/found: "
-          "type:%d/%d key:%u/%d value:%u/%d entries:%u/%d flags:%d/%d",
-          mapName.c_str(), type, fd_type, mapDef.key_size, fd_key_size, mapDef.value_size,
-          fd_value_size, desired_max_entries, fd_max_entries, desired_map_flags, fd_map_flags);
-    return false;
-}
-
-/*
- * Creates - or reuses, if a pin already exists - every map defined in the "maps" section
- * of the ELF object, and pins, chmods and chowns the newly created ones.
- *
- * One entry is appended to mapFds per map symbol, in symbol order.  Maps that are skipped
- * (bpfloader version, kernel version, build type or architecture constraints) get a
- * default constructed unique_fd, so that indices into mapFds keep matching the symbols.
- *
- * The maps section is made of records of sizeOfBpfMapDef bytes, which may differ from this
- * loader's sizeof(struct bpf_map_def): only the common prefix is copied, and fields
- * missing from the file keep their defaults.
- *
- * Returns 0 on success (including when there is no maps section), negative on failure.
- */
-static int createMaps(const char* elfPath, ifstream& elfFile, vector<unique_fd>& mapFds,
-                      const char* prefix, const size_t sizeOfBpfMapDef,
-                      const unsigned int bpfloader_ver) {
-    int ret;
-    vector<char> mdData;
-    vector<struct bpf_map_def> md;
-    vector<string> mapNames;
-    string objName = pathToObjName(string(elfPath));
-
-    ret = readSectionByName("maps", elfFile, mdData);
-    if (ret == -2) return 0;  // no maps to read
-    if (ret) return ret;
-
-    if (mdData.size() % sizeOfBpfMapDef) {
-        ALOGE("createMaps failed due to improper sized maps section, %zu %% %zu != 0",
-              mdData.size(), sizeOfBpfMapDef);
-        return -1;
-    }
-
-    int mapCount = mdData.size() / sizeOfBpfMapDef;
-    md.resize(mapCount);
-    size_t trimmedSize = std::min(sizeOfBpfMapDef, sizeof(struct bpf_map_def));
-
-    const char* dataPtr = mdData.data();
-    for (auto& m : md) {
-        // First we zero initialize
-        memset(&m, 0, sizeof(m));
-        // Then we set non-zero defaults
-        m.bpfloader_max_ver = DEFAULT_BPFLOADER_MAX_VER;  // v1.0
-        m.max_kver = 0xFFFFFFFFu;                         // matches KVER_INF from bpf_helpers.h
-        // Then we copy over the structure prefix from the ELF file.
-        memcpy(&m, dataPtr, trimmedSize);
-        // Move to next struct in the ELF file
-        dataPtr += sizeOfBpfMapDef;
-    }
-
-    ret = getSectionSymNames(elfFile, "maps", mapNames);
-    if (ret) return ret;
-
-    // The loop below indexes md[] with the position of each map symbol, so there must not
-    // be more symbols in the maps section than there are map definitions.
-    if (mapNames.size() > md.size()) {
-        ALOGE("createMaps failed due to %zu symbols in maps section with only %zu definitions",
-              mapNames.size(), md.size());
-        return -EINVAL;
-    }
-
-    unsigned kvers = kernelVersion();
-
-    for (int i = 0; i < (int)mapNames.size(); i++) {
-        if (md[i].zero != 0) abort();
-
-        if (bpfloader_ver < md[i].bpfloader_min_ver) {
-            ALOGI("skipping map %s which requires bpfloader min ver 0x%05x", mapNames[i].c_str(),
-                  md[i].bpfloader_min_ver);
-            mapFds.push_back(unique_fd());
-            continue;
-        }
-
-        if (bpfloader_ver >= md[i].bpfloader_max_ver) {
-            ALOGI("skipping map %s which requires bpfloader max ver 0x%05x", mapNames[i].c_str(),
-                  md[i].bpfloader_max_ver);
-            mapFds.push_back(unique_fd());
-            continue;
-        }
-
-        if (kvers < md[i].min_kver) {
-            ALOGI("skipping map %s which requires kernel version 0x%x >= 0x%x",
-                  mapNames[i].c_str(), kvers, md[i].min_kver);
-            mapFds.push_back(unique_fd());
-            continue;
-        }
-
-        if (kvers >= md[i].max_kver) {
-            ALOGI("skipping map %s which requires kernel version 0x%x < 0x%x",
-                  mapNames[i].c_str(), kvers, md[i].max_kver);
-            mapFds.push_back(unique_fd());
-            continue;
-        }
-
-        if ((md[i].ignore_on_eng && isEng()) || (md[i].ignore_on_user && isUser()) ||
-            (md[i].ignore_on_userdebug && isUserdebug())) {
-            ALOGI("skipping map %s which is ignored on %s builds", mapNames[i].c_str(),
-                  getBuildType().c_str());
-            mapFds.push_back(unique_fd());
-            continue;
-        }
-
-        if ((isArm() && isKernel32Bit() && md[i].ignore_on_arm32) ||
-            (isArm() && isKernel64Bit() && md[i].ignore_on_aarch64) ||
-            (isX86() && isKernel32Bit() && md[i].ignore_on_x86_32) ||
-            (isX86() && isKernel64Bit() && md[i].ignore_on_x86_64) ||
-            (isRiscV() && md[i].ignore_on_riscv64)) {
-            ALOGI("skipping map %s which is ignored on %s", mapNames[i].c_str(),
-                  describeArch());
-            mapFds.push_back(unique_fd());
-            continue;
-        }
-
-        enum bpf_map_type type = md[i].type;
-        if (type == BPF_MAP_TYPE_DEVMAP && !isAtLeastKernelVersion(4, 14, 0)) {
-            // On Linux Kernels older than 4.14 this map type doesn't exist, but it can kind
-            // of be approximated: ARRAY has the same userspace api, though it is not usable
-            // by the same ebpf programs.  However, that's okay because the bpf_redirect_map()
-            // helper doesn't exist on 4.9-T anyway (so the bpf program would fail to load,
-            // and thus needs to be tagged as 4.14+ either way), so there's nothing useful you
-            // could do with a DEVMAP anyway (that isn't already provided by an ARRAY)...
-            // Hence using an ARRAY instead of a DEVMAP simply makes life easier for userspace.
-            type = BPF_MAP_TYPE_ARRAY;
-        }
-        if (type == BPF_MAP_TYPE_DEVMAP_HASH && !isAtLeastKernelVersion(5, 4, 0)) {
-            // On Linux Kernels older than 5.4 this map type doesn't exist, but it can kind
-            // of be approximated: HASH has the same userspace visible api.
-            // However it cannot be used by ebpf programs in the same way.
-            // Since bpf_redirect_map() only requires 4.14, a program using a DEVMAP_HASH map
-            // would fail to load (due to trying to redirect to a HASH instead of DEVMAP_HASH).
-            // One must thus tag any BPF_MAP_TYPE_DEVMAP_HASH + bpf_redirect_map() using
-            // programs as being 5.4+...
-            type = BPF_MAP_TYPE_HASH;
-        }
-
-        // The .h file enforces that this is a power of two, and page size will
-        // also always be a power of two, so this logic is actually enough to
-        // force it to be a multiple of the page size, as required by the kernel.
-        unsigned int max_entries = md[i].max_entries;
-        if (type == BPF_MAP_TYPE_RINGBUF) {
-            if (max_entries < page_size) max_entries = page_size;
-        }
-
-        domain selinux_context = getDomainFromSelinuxContext(md[i].selinux_context);
-        if (specified(selinux_context)) {
-            ALOGI("map %s selinux_context [%-32s] -> %d -> '%s' (%s)", mapNames[i].c_str(),
-                  md[i].selinux_context, static_cast<int>(selinux_context),
-                  lookupSelinuxContext(selinux_context), lookupPinSubdir(selinux_context));
-        }
-
-        domain pin_subdir = getDomainFromPinSubdir(md[i].pin_subdir);
-        if (unrecognized(pin_subdir)) return -ENOTDIR;
-        if (specified(pin_subdir)) {
-            ALOGI("map %s pin_subdir [%-32s] -> %d -> '%s'", mapNames[i].c_str(), md[i].pin_subdir,
-                  static_cast<int>(pin_subdir), lookupPinSubdir(pin_subdir));
-        }
-
-        // Format of pin location is /sys/fs/bpf/<pin_subdir|prefix>map_<objName>_<mapName>
-        // except that maps shared across .o's have empty <objName>
-        // Note: <objName> refers to the extension-less basename of the .o file (without @ suffix).
-        string mapPinLoc = string(BPF_FS_PATH) + lookupPinSubdir(pin_subdir, prefix) + "map_" +
-                           (md[i].shared ? "" : objName) + "_" + mapNames[i];
-        bool reuse = false;
-        unique_fd fd;
-        int saved_errno;
-
-        if (access(mapPinLoc.c_str(), F_OK) == 0) {
-            fd.reset(mapRetrieveRO(mapPinLoc.c_str()));
-            saved_errno = errno;
-            ALOGD("bpf_create_map reusing map %s, ret: %d", mapNames[i].c_str(), fd.get());
-            reuse = true;
-        } else {
-            union bpf_attr req = {
-              .map_type = type,
-              .key_size = md[i].key_size,
-              .value_size = md[i].value_size,
-              .max_entries = max_entries,
-              .map_flags = md[i].map_flags,
-            };
-            if (isAtLeastKernelVersion(4, 15, 0))
-                strlcpy(req.map_name, mapNames[i].c_str(), sizeof(req.map_name));
-            fd.reset(bpf(BPF_MAP_CREATE, req));
-            saved_errno = errno;
-            ALOGD("bpf_create_map name %s, ret: %d", mapNames[i].c_str(), fd.get());
-        }
-
-        if (!fd.ok()) return -saved_errno;
-
-        // When reusing a pinned map, we need to check the map type/sizes/etc match, but for
-        // safety (since reuse code path is rare) run these checks even if we just created it.
-        // We assume failure is due to pinned map mismatch, hence the 'NOT UNIQUE' return code.
-        if (!mapMatchesExpectations(fd, mapNames[i], md[i], type)) return -ENOTUNIQ;
-
-        if (!reuse) {
-            if (specified(selinux_context)) {
-                // Pin under a temporary name in the selinux context's directory first,
-                // then rename into the final location (without replacing an existing pin).
-                string createLoc = string(BPF_FS_PATH) + lookupPinSubdir(selinux_context) +
-                                   "tmp_map_" + objName + "_" + mapNames[i];
-                ret = bpfFdPin(fd, createLoc.c_str());
-                if (ret) {
-                    int err = errno;
-                    ALOGE("create %s -> %d [%d:%s]", createLoc.c_str(), ret, err, strerror(err));
-                    return -err;
-                }
-                ret = renameat2(AT_FDCWD, createLoc.c_str(),
-                                AT_FDCWD, mapPinLoc.c_str(), RENAME_NOREPLACE);
-                if (ret) {
-                    int err = errno;
-                    ALOGE("rename %s %s -> %d [%d:%s]", createLoc.c_str(), mapPinLoc.c_str(), ret,
-                          err, strerror(err));
-                    return -err;
-                }
-            } else {
-                ret = bpfFdPin(fd, mapPinLoc.c_str());
-                if (ret) {
-                    int err = errno;
-                    ALOGE("pin %s -> %d [%d:%s]", mapPinLoc.c_str(), ret, err, strerror(err));
-                    return -err;
-                }
-            }
-            ret = chmod(mapPinLoc.c_str(), md[i].mode);
-            if (ret) {
-                int err = errno;
-                ALOGE("chmod(%s, 0%o) = %d [%d:%s]", mapPinLoc.c_str(), md[i].mode, ret, err,
-                      strerror(err));
-                return -err;
-            }
-            ret = chown(mapPinLoc.c_str(), (uid_t)md[i].uid, (gid_t)md[i].gid);
-            if (ret) {
-                int err = errno;
-                ALOGE("chown(%s, %u, %u) = %d [%d:%s]", mapPinLoc.c_str(), md[i].uid, md[i].gid,
-                      ret, err, strerror(err));
-                return -err;
-            }
-        }
-
-        // Failing to fetch the map id is only logged, it does not fail the load.
-        int mapId = bpfGetFdMapId(fd);
-        if (mapId == -1) {
-            ALOGE("bpfGetFdMapId failed, ret: %d [%d]", mapId, errno);
-        } else {
-            ALOGI("map %s id %d", mapPinLoc.c_str(), mapId);
-        }
-
-        mapFds.push_back(std::move(fd));
-    }
-
-    return ret;
-}
-
-/*
- * For debugging, dump all instructions.
- * Logs 'size' / 8 rows, each showing the 8 bytes of one instruction as hex values.
- */
-static void dumpIns(char* ins, int size) {
-    for (int row = 0; row < size / 8; row++) {
-        // Every ALOGE() call produces a log entry of its own, so assemble the row first.
-        char line[8 * 4 + 1];
-        int used = 0;
-        for (int j = 0; j < 8; j++) {
-            // Go through unsigned char: a plain (possibly signed) char would get sign
-            // extended, and bytes >= 0x80 would be printed as ffffffXX.
-            used += snprintf(line + used, sizeof(line) - used, "%3x ",
-                             static_cast<unsigned char>(ins[(row * 8) + j]));
-        }
-        ALOGE("%d: %s", row, line);
-    }
-}
-
-/* Debugging aid: logs the name and the instructions of every code section in the cs list */
-static void dumpAllCs(vector<codeSection>& cs) {
-    int idx = 0;
-    for (codeSection& section : cs) {
-        ALOGE("Dumping cs %d, name %s", idx, section.name.c_str());
-        dumpIns(section.data.data(), section.data.size());
-        ALOGE("-----------");
-        idx++;
-    }
-}
-
-/*
- * Patches the instruction found 'offset' bytes into the buffer at insnsPtr so that it
- * refers to the map with file descriptor 'fd'.
- * Only BPF_LD | BPF_IMM | BPF_DW instructions are patched, anything else is logged
- * (along with a dump of the instructions) and left untouched.
- */
-static void applyRelo(void* insnsPtr, Elf64_Addr offset, int fd) {
-    struct bpf_insn* const insns = (struct bpf_insn*)(insnsPtr);
-    const int insnIndex = offset / sizeof(struct bpf_insn);
-    struct bpf_insn* const target = &insns[insnIndex];
-
-    // Occasionally might be useful for relocation debugging, but pretty spammy
-    if (0) {
-        ALOGV("applying relo to instruction at byte offset: %llu, "
-              "insn offset %d, insn %llx",
-              (unsigned long long)offset, insnIndex, *(unsigned long long*)target);
-    }
-
-    if (target->code == (BPF_LD | BPF_IMM | BPF_DW)) {
-        target->imm = fd;
-        target->src_reg = BPF_PSEUDO_MAP_FD;
-        return;
-    }
-
-    ALOGE("Dumping all instructions till ins %d", insnIndex);
-    ALOGE("invalid relo for insn %d: code 0x%x", insnIndex, target->code);
-    dumpIns((char*)insnsPtr, (insnIndex + 3) * 8);
-}
-
-/*
- * Applies the map relocations of every code section: each relocation whose symbol is
- * the name of a map gets the instruction it points at patched with that map's fd
- * (mapFds is indexed like the symbols of the "maps" section).
- * Gives up silently if the map names or a relocation's symbol name cannot be read.
- */
-static void applyMapRelo(ifstream& elfFile, vector<unique_fd> &mapFds, vector<codeSection>& cs) {
-    vector<string> mapNames;
-    if (getSectionSymNames(elfFile, "maps", mapNames)) return;
-
-    for (codeSection& section : cs) {
-        Elf64_Rel* const relocs = (Elf64_Rel*)(section.rel_data.data());
-        const int relocCount = section.rel_data.size() / sizeof(*relocs);
-
-        for (int r = 0; r < relocCount; r++) {
-            const int symIndex = ELF64_R_SYM(relocs[r].r_info);
-            string symName;
-            if (getSymNameByIdx(elfFile, symIndex, symName)) return;
-
-            /* Find the map fd and apply relo */
-            for (int m = 0; m < (int)mapNames.size(); m++) {
-                if (mapNames[m] != symName) continue;
-                applyRelo(section.data.data(), relocs[r].r_offset, mapFds[m]);
-                break;
-            }
-        }
-    }
-}
-
-static int loadCodeSections(const char* elfPath, vector<codeSection>& cs, const string& license,
- const char* prefix, const unsigned int bpfloader_ver) {
- unsigned kvers = kernelVersion();
-
- if (!kvers) {
- ALOGE("unable to get kernel version");
- return -EINVAL;
- }
-
- string objName = pathToObjName(string(elfPath));
-
- for (int i = 0; i < (int)cs.size(); i++) {
- unique_fd& fd = cs[i].prog_fd;
- int ret;
- string name = cs[i].name;
-
- if (!cs[i].prog_def.has_value()) {
- ALOGE("[%d] '%s' missing program definition! bad bpf.o build?", i, name.c_str());
- return -EINVAL;
- }
-
- unsigned min_kver = cs[i].prog_def->min_kver;
- unsigned max_kver = cs[i].prog_def->max_kver;
- ALOGD("cs[%d].name:%s min_kver:%x .max_kver:%x (kvers:%x)", i, name.c_str(), min_kver,
- max_kver, kvers);
- if (kvers < min_kver) continue;
- if (kvers >= max_kver) continue;
-
- unsigned bpfMinVer = cs[i].prog_def->bpfloader_min_ver;
- unsigned bpfMaxVer = cs[i].prog_def->bpfloader_max_ver;
- domain selinux_context = getDomainFromSelinuxContext(cs[i].prog_def->selinux_context);
- domain pin_subdir = getDomainFromPinSubdir(cs[i].prog_def->pin_subdir);
- // Note: make sure to only check for unrecognized *after* verifying bpfloader
- // version limits include this bpfloader's version.
-
- ALOGD("cs[%d].name:%s requires bpfloader version [0x%05x,0x%05x)", i, name.c_str(),
- bpfMinVer, bpfMaxVer);
- if (bpfloader_ver < bpfMinVer) continue;
- if (bpfloader_ver >= bpfMaxVer) continue;
-
- if ((cs[i].prog_def->ignore_on_eng && isEng()) ||
- (cs[i].prog_def->ignore_on_user && isUser()) ||
- (cs[i].prog_def->ignore_on_userdebug && isUserdebug())) {
- ALOGD("cs[%d].name:%s is ignored on %s builds", i, name.c_str(),
- getBuildType().c_str());
- continue;
- }
-
- if ((isArm() && isKernel32Bit() && cs[i].prog_def->ignore_on_arm32) ||
- (isArm() && isKernel64Bit() && cs[i].prog_def->ignore_on_aarch64) ||
- (isX86() && isKernel32Bit() && cs[i].prog_def->ignore_on_x86_32) ||
- (isX86() && isKernel64Bit() && cs[i].prog_def->ignore_on_x86_64) ||
- (isRiscV() && cs[i].prog_def->ignore_on_riscv64)) {
- ALOGD("cs[%d].name:%s is ignored on %s", i, name.c_str(), describeArch());
- continue;
- }
-
- if (unrecognized(pin_subdir)) return -ENOTDIR;
-
- if (specified(selinux_context)) {
- ALOGI("prog %s selinux_context [%-32s] -> %d -> '%s' (%s)", name.c_str(),
- cs[i].prog_def->selinux_context, static_cast<int>(selinux_context),
- lookupSelinuxContext(selinux_context), lookupPinSubdir(selinux_context));
- }
-
- if (specified(pin_subdir)) {
- ALOGI("prog %s pin_subdir [%-32s] -> %d -> '%s'", name.c_str(),
- cs[i].prog_def->pin_subdir, static_cast<int>(pin_subdir),
- lookupPinSubdir(pin_subdir));
- }
-
- // strip any potential $foo suffix
- // this can be used to provide duplicate programs
- // conditionally loaded based on running kernel version
- name = name.substr(0, name.find_last_of('$'));
-
- bool reuse = false;
- // Format of pin location is
- // /sys/fs/bpf/<prefix>prog_<objName>_<progName>
- string progPinLoc = string(BPF_FS_PATH) + lookupPinSubdir(pin_subdir, prefix) + "prog_" +
- objName + '_' + string(name);
- if (access(progPinLoc.c_str(), F_OK) == 0) {
- fd.reset(retrieveProgram(progPinLoc.c_str()));
- ALOGD("New bpf prog load reusing prog %s, ret: %d (%s)", progPinLoc.c_str(), fd.get(),
- (!fd.ok() ? std::strerror(errno) : "no error"));
- reuse = true;
- } else {
- vector<char> log_buf(BPF_LOAD_LOG_SZ, 0);
-
- union bpf_attr req = {
- .prog_type = cs[i].type,
- .kern_version = kvers,
- .license = ptr_to_u64(license.c_str()),
- .insns = ptr_to_u64(cs[i].data.data()),
- .insn_cnt = static_cast<__u32>(cs[i].data.size() / sizeof(struct bpf_insn)),
- .log_level = 1,
- .log_buf = ptr_to_u64(log_buf.data()),
- .log_size = static_cast<__u32>(log_buf.size()),
- .expected_attach_type = cs[i].expected_attach_type,
- };
- if (isAtLeastKernelVersion(4, 15, 0))
- strlcpy(req.prog_name, cs[i].name.c_str(), sizeof(req.prog_name));
- fd.reset(bpf(BPF_PROG_LOAD, req));
-
- ALOGD("BPF_PROG_LOAD call for %s (%s) returned fd: %d (%s)", elfPath,
- cs[i].name.c_str(), fd.get(), (!fd.ok() ? std::strerror(errno) : "no error"));
-
- if (!fd.ok()) {
- vector<string> lines = android::base::Split(log_buf.data(), "\n");
-
- ALOGW("BPF_PROG_LOAD - BEGIN log_buf contents:");
- for (const auto& line : lines) ALOGW("%s", line.c_str());
- ALOGW("BPF_PROG_LOAD - END log_buf contents.");
-
- if (cs[i].prog_def->optional) {
- ALOGW("failed program is marked optional - continuing...");
- continue;
- }
- ALOGE("non-optional program failed to load.");
- }
- }
-
- if (!fd.ok()) return fd.get();
-
- if (!reuse) {
- if (specified(selinux_context)) {
- string createLoc = string(BPF_FS_PATH) + lookupPinSubdir(selinux_context) +
- "tmp_prog_" + objName + '_' + string(name);
- ret = bpfFdPin(fd, createLoc.c_str());
- if (ret) {
- int err = errno;
- ALOGE("create %s -> %d [%d:%s]", createLoc.c_str(), ret, err, strerror(err));
- return -err;
- }
- ret = renameat2(AT_FDCWD, createLoc.c_str(),
- AT_FDCWD, progPinLoc.c_str(), RENAME_NOREPLACE);
- if (ret) {
- int err = errno;
- ALOGE("rename %s %s -> %d [%d:%s]", createLoc.c_str(), progPinLoc.c_str(), ret,
- err, strerror(err));
- return -err;
- }
- } else {
- ret = bpfFdPin(fd, progPinLoc.c_str());
- if (ret) {
- int err = errno;
- ALOGE("create %s -> %d [%d:%s]", progPinLoc.c_str(), ret, err, strerror(err));
- return -err;
- }
- }
- if (chmod(progPinLoc.c_str(), 0440)) {
- int err = errno;
- ALOGE("chmod %s 0440 -> [%d:%s]", progPinLoc.c_str(), err, strerror(err));
- return -err;
- }
- if (chown(progPinLoc.c_str(), (uid_t)cs[i].prog_def->uid,
- (gid_t)cs[i].prog_def->gid)) {
- int err = errno;
- ALOGE("chown %s %d %d -> [%d:%s]", progPinLoc.c_str(), cs[i].prog_def->uid,
- cs[i].prog_def->gid, err, strerror(err));
- return -err;
- }
- }
-
- int progId = bpfGetFdProgId(fd);
- if (progId == -1) {
- ALOGE("bpfGetFdProgId failed, ret: %d [%d]", progId, errno);
- } else {
- ALOGI("prog %s id %d", progPinLoc.c_str(), progId);
- }
- }
-
- return 0;
-}
-
-int loadProg(const char* const elfPath, bool* const isCritical, const unsigned int bpfloader_ver,
- const Location& location) {
- vector<char> license;
- vector<char> critical;
- vector<codeSection> cs;
- vector<unique_fd> mapFds;
- int ret;
-
- if (!isCritical) return -1;
- *isCritical = false;
-
- ifstream elfFile(elfPath, ios::in | ios::binary);
- if (!elfFile.is_open()) return -1;
-
- ret = readSectionByName("critical", elfFile, critical);
- *isCritical = !ret;
-
- ret = readSectionByName("license", elfFile, license);
- if (ret) {
- ALOGE("Couldn't find license in %s", elfPath);
- return ret;
- } else {
- ALOGD("Loading %s%s ELF object %s with license %s",
- *isCritical ? "critical for " : "optional", *isCritical ? (char*)critical.data() : "",
- elfPath, (char*)license.data());
- }
-
- // the following default values are for bpfloader V0.0 format which does not include them
- unsigned int bpfLoaderMinVer =
- readSectionUint("bpfloader_min_ver", elfFile, DEFAULT_BPFLOADER_MIN_VER);
- unsigned int bpfLoaderMaxVer =
- readSectionUint("bpfloader_max_ver", elfFile, DEFAULT_BPFLOADER_MAX_VER);
- unsigned int bpfLoaderMinRequiredVer =
- readSectionUint("bpfloader_min_required_ver", elfFile, 0);
- size_t sizeOfBpfMapDef =
- readSectionUint("size_of_bpf_map_def", elfFile, DEFAULT_SIZEOF_BPF_MAP_DEF);
- size_t sizeOfBpfProgDef =
- readSectionUint("size_of_bpf_prog_def", elfFile, DEFAULT_SIZEOF_BPF_PROG_DEF);
-
- // inclusive lower bound check
- if (bpfloader_ver < bpfLoaderMinVer) {
- ALOGI("BpfLoader version 0x%05x ignoring ELF object %s with min ver 0x%05x",
- bpfloader_ver, elfPath, bpfLoaderMinVer);
- return 0;
- }
-
- // exclusive upper bound check
- if (bpfloader_ver >= bpfLoaderMaxVer) {
- ALOGI("BpfLoader version 0x%05x ignoring ELF object %s with max ver 0x%05x",
- bpfloader_ver, elfPath, bpfLoaderMaxVer);
- return 0;
- }
-
- if (bpfloader_ver < bpfLoaderMinRequiredVer) {
- ALOGI("BpfLoader version 0x%05x failing due to ELF object %s with required min ver 0x%05x",
- bpfloader_ver, elfPath, bpfLoaderMinRequiredVer);
- return -1;
- }
-
- ALOGI("BpfLoader version 0x%05x processing ELF object %s with ver [0x%05x,0x%05x)",
- bpfloader_ver, elfPath, bpfLoaderMinVer, bpfLoaderMaxVer);
-
- if (sizeOfBpfMapDef < DEFAULT_SIZEOF_BPF_MAP_DEF) {
- ALOGE("sizeof(bpf_map_def) of %zu is too small (< %d)", sizeOfBpfMapDef,
- DEFAULT_SIZEOF_BPF_MAP_DEF);
- return -1;
- }
-
- if (sizeOfBpfProgDef < DEFAULT_SIZEOF_BPF_PROG_DEF) {
- ALOGE("sizeof(bpf_prog_def) of %zu is too small (< %d)", sizeOfBpfProgDef,
- DEFAULT_SIZEOF_BPF_PROG_DEF);
- return -1;
- }
-
- ret = readCodeSections(elfFile, cs, sizeOfBpfProgDef);
- if (ret) {
- ALOGE("Couldn't read all code sections in %s", elfPath);
- return ret;
- }
-
- /* Just for future debugging */
- if (0) dumpAllCs(cs);
-
- ret = createMaps(elfPath, elfFile, mapFds, location.prefix, sizeOfBpfMapDef, bpfloader_ver);
- if (ret) {
- ALOGE("Failed to create maps: (ret=%d) in %s", ret, elfPath);
- return ret;
- }
-
- for (int i = 0; i < (int)mapFds.size(); i++)
- ALOGV("map_fd found at %d is %d in %s", i, mapFds[i].get(), elfPath);
-
- applyMapRelo(elfFile, mapFds, cs);
-
- ret = loadCodeSections(elfPath, cs, string(license.data()), location.prefix, bpfloader_ver);
- if (ret) ALOGE("Failed to load programs, loadCodeSections ret=%d", ret);
-
- return ret;
-}
-
-} // namespace bpf
-} // namespace android
diff --git a/netbpfload/loader.h b/netbpfload/loader.h
deleted file mode 100644
index 4da6830..0000000
--- a/netbpfload/loader.h
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Copyright (C) 2018-2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <linux/bpf.h>
-
-#include <fstream>
-
-namespace android {
-namespace bpf {
-
-// Bpf programs may specify per-program & per-map selinux_context and pin_subdir.
-//
-// The BpfLoader needs to convert these bpf.o specified strings into an enum
-// for internal use (to check that valid values were specified for the specific
-// location of the bpf.o file).
-//
-// It also needs to map selinux_context's into pin_subdir's.
-// This is because of how selinux_context is actually implemented via pin+rename.
-//
-// Thus 'domain' enumerates all selinux_context's/pin_subdir's that the BpfLoader
-// is aware of. Thus there currently needs to be a 1:1 mapping between the two.
-//
-enum class domain : int {
- unrecognized = -1, // invalid for this version of the bpfloader
- unspecified = 0, // means just use the default for that specific pin location
- tethering, // (S+) fs_bpf_tethering /sys/fs/bpf/tethering
- net_private, // (T+) fs_bpf_net_private /sys/fs/bpf/net_private
- net_shared, // (T+) fs_bpf_net_shared /sys/fs/bpf/net_shared
- netd_readonly, // (T+) fs_bpf_netd_readonly /sys/fs/bpf/netd_readonly
- netd_shared, // (T+) fs_bpf_netd_shared /sys/fs/bpf/netd_shared
-};
-
-// Note: this does not include domain::unrecognized, but does include domain::unspecified
-static constexpr domain AllDomains[] = {
- domain::unspecified,
- domain::tethering,
- domain::net_private,
- domain::net_shared,
- domain::netd_readonly,
- domain::netd_shared,
-};
-
-static constexpr bool unrecognized(domain d) {
- return d == domain::unrecognized;
-}
-
-// Note: this doesn't handle unrecognized, handle it first.
-static constexpr bool specified(domain d) {
- return d != domain::unspecified;
-}
-
-struct Location {
- const char* const dir = "";
- const char* const prefix = "";
-};
-
-// BPF loader implementation. Loads an eBPF ELF object
-int loadProg(const char* elfPath, bool* isCritical, const unsigned int bpfloader_ver,
- const Location &location = {});
-
-// Exposed for testing
-unsigned int readSectionUint(const char* name, std::ifstream& elfFile, unsigned int defVal);
-
-// Returns the build type string (from ro.build.type).
-const std::string& getBuildType();
-
-// The following functions classify the 3 Android build types.
-inline bool isEng() {
- return getBuildType() == "eng";
-}
-inline bool isUser() {
- return getBuildType() == "user";
-}
-inline bool isUserdebug() {
- return getBuildType() == "userdebug";
-}
-
-} // namespace bpf
-} // namespace android
diff --git a/netd/BpfBaseTest.cpp b/netd/BpfBaseTest.cpp
index c979a7b..34dfbb4 100644
--- a/netd/BpfBaseTest.cpp
+++ b/netd/BpfBaseTest.cpp
@@ -56,7 +56,7 @@
TEST_F(BpfBasicTest, TestCgroupMounted) {
std::string cg2_path;
- ASSERT_EQ(true, CgroupGetControllerPath(CGROUPV2_CONTROLLER_NAME, &cg2_path));
+ ASSERT_EQ(true, CgroupGetControllerPath(CGROUPV2_HIERARCHY_NAME, &cg2_path));
ASSERT_EQ(0, access(cg2_path.c_str(), R_OK));
ASSERT_EQ(0, access((cg2_path + "/cgroup.controllers").c_str(), R_OK));
}
diff --git a/networksecurity/OWNERS b/networksecurity/OWNERS
new file mode 100644
index 0000000..1a4130a
--- /dev/null
+++ b/networksecurity/OWNERS
@@ -0,0 +1,4 @@
+# Bug component: 1479456
+
+sandrom@google.com
+tweek@google.com
diff --git a/networksecurity/framework/Android.bp b/networksecurity/framework/Android.bp
new file mode 100644
index 0000000..2b77926
--- /dev/null
+++ b/networksecurity/framework/Android.bp
@@ -0,0 +1,31 @@
+//
+// Copyright (C) 2024 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+package {
+ default_team: "trendy_team_platform_security",
+ default_applicable_licenses: ["Android-Apache-2.0"],
+}
+
+filegroup {
+ name: "framework-networksecurity-sources",
+ srcs: [
+ "src/**/*.java",
+ "src/**/*.aidl",
+ ],
+ path: "src",
+ visibility: [
+ "//packages/modules/Connectivity:__subpackages__",
+ ],
+}
diff --git a/networksecurity/framework/src/android/net/ct/CertificateTransparencyManager.java b/networksecurity/framework/src/android/net/ct/CertificateTransparencyManager.java
new file mode 100644
index 0000000..94521ae
--- /dev/null
+++ b/networksecurity/framework/src/android/net/ct/CertificateTransparencyManager.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.ct;
+
+import android.annotation.FlaggedApi;
+import android.annotation.SystemService;
+
+import com.android.net.ct.flags.Flags;
+
+/**
+ * Provides the primary API for the Certificate Transparency Manager.
+ *
+ * @hide
+ */
+@FlaggedApi(Flags.FLAG_CERTIFICATE_TRANSPARENCY_SERVICE)
+@SystemService(CertificateTransparencyManager.SERVICE_NAME)
+public final class CertificateTransparencyManager {
+
+ public static final String SERVICE_NAME = "certificate_transparency";
+
+ /**
+ * Creates a new CertificateTransparencyManager instance.
+ *
+ * @hide
+ */
+ public CertificateTransparencyManager() {}
+}
diff --git a/networksecurity/framework/src/android/net/ct/ICertificateTransparencyManager.aidl b/networksecurity/framework/src/android/net/ct/ICertificateTransparencyManager.aidl
new file mode 100644
index 0000000..b5bce7f
--- /dev/null
+++ b/networksecurity/framework/src/android/net/ct/ICertificateTransparencyManager.aidl
@@ -0,0 +1,22 @@
+/**
+ * Copyright (c) 2024, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.ct;
+
+/**
+* Interface for communicating with CertificateTransparencyService.
+* @hide
+*/
+interface ICertificateTransparencyManager {}
diff --git a/networksecurity/service/Android.bp b/networksecurity/service/Android.bp
new file mode 100644
index 0000000..66d201a
--- /dev/null
+++ b/networksecurity/service/Android.bp
@@ -0,0 +1,40 @@
+// Copyright (C) 2024 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package {
+ default_team: "trendy_team_platform_security",
+ default_applicable_licenses: ["Android-Apache-2.0"],
+}
+
+// Main lib for Certificate Transparency services.
+java_library {
+ name: "service-networksecurity-pre-jarjar",
+ defaults: ["framework-system-server-module-defaults"],
+ visibility: ["//packages/modules/Connectivity:__subpackages__"],
+
+ srcs: [
+ "src/**/*.java",
+ ],
+
+ libs: [
+ "framework-connectivity-pre-jarjar",
+ "service-connectivity-pre-jarjar",
+ ],
+
+ // This is included in service-connectivity which is 30+
+ // TODO (b/293613362): allow APEXes to have service jars with higher min_sdk than the APEX
+ // (service-connectivity is only used on 31+) and use 31 here
+ min_sdk_version: "30",
+ sdk_version: "system_server_current",
+ apex_available: ["com.android.tethering"],
+}
diff --git a/networksecurity/service/src/com/android/server/net/ct/CertificateTransparencyService.java b/networksecurity/service/src/com/android/server/net/ct/CertificateTransparencyService.java
new file mode 100644
index 0000000..8c53bf7
--- /dev/null
+++ b/networksecurity/service/src/com/android/server/net/ct/CertificateTransparencyService.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.server.net.ct;
+
+import android.annotation.RequiresApi;
+import android.content.Context;
+import android.net.ct.ICertificateTransparencyManager;
+import android.os.Build;
+import android.util.Log;
+
+import com.android.net.ct.flags.Flags;
+import com.android.net.module.util.DeviceConfigUtils;
+
+/** Implementation of the Certificate Transparency service. */
+@RequiresApi(Build.VERSION_CODES.VANILLA_ICE_CREAM)
+public class CertificateTransparencyService extends ICertificateTransparencyManager.Stub {
+
+ private static final String TAG = "CertificateTransparency";
+ private static final String CERTIFICATE_TRANSPARENCY_ENABLED =
+ "certificate_transparency_service_enabled";
+
+ /**
+ * @return true if the CertificateTransparency service is enabled.
+ */
+ public static boolean enabled(Context context) {
+ // TODO: replace isTetheringFeatureEnabled with CT namespace flag.
+ return DeviceConfigUtils.isTetheringFeatureEnabled(
+ context, CERTIFICATE_TRANSPARENCY_ENABLED)
+ && Flags.certificateTransparencyService();
+ }
+
+ /** Creates a new {@link CertificateTransparencyService} object. */
+ public CertificateTransparencyService(Context context) {}
+
+ /**
+ * Called by {@link com.android.server.ConnectivityServiceInitializer}.
+ *
+ * @see com.android.server.SystemService#onBootPhase
+ */
+ public void onBootPhase(int phase) {
+ Log.d(TAG, "CertificateTransparencyService#onBootPhase " + phase);
+ }
+}
diff --git a/service-t/Android.bp b/service-t/Android.bp
index 779f354..e00b7cf 100644
--- a/service-t/Android.bp
+++ b/service-t/Android.bp
@@ -59,6 +59,7 @@
"framework-wifi",
"service-connectivity-pre-jarjar",
"service-nearby-pre-jarjar",
+ "service-networksecurity-pre-jarjar",
"service-thread-pre-jarjar",
service_remoteauth_pre_jarjar_lib,
"ServiceConnectivityResources",
diff --git a/service-t/src/com/android/server/ConnectivityServiceInitializer.java b/service-t/src/com/android/server/ConnectivityServiceInitializer.java
index 1ac2f6e..5d23fdc 100644
--- a/service-t/src/com/android/server/ConnectivityServiceInitializer.java
+++ b/service-t/src/com/android/server/ConnectivityServiceInitializer.java
@@ -28,6 +28,7 @@
import com.android.server.ethernet.EthernetService;
import com.android.server.ethernet.EthernetServiceImpl;
import com.android.server.nearby.NearbyService;
+import com.android.server.net.ct.CertificateTransparencyService;
import com.android.server.thread.ThreadNetworkService;
/**
@@ -43,6 +44,7 @@
private final NearbyService mNearbyService;
private final EthernetServiceImpl mEthernetServiceImpl;
private final ThreadNetworkService mThreadNetworkService;
+ private final CertificateTransparencyService mCertificateTransparencyService;
public ConnectivityServiceInitializer(Context context) {
super(context);
@@ -55,6 +57,7 @@
mNsdService = createNsdService(context);
mNearbyService = createNearbyService(context);
mThreadNetworkService = createThreadNetworkService(context);
+ mCertificateTransparencyService = createCertificateTransparencyService(context);
}
@Override
@@ -111,6 +114,10 @@
if (mThreadNetworkService != null) {
mThreadNetworkService.onBootPhase(phase);
}
+
+ if (SdkLevel.isAtLeastV() && mCertificateTransparencyService != null) {
+ mCertificateTransparencyService.onBootPhase(phase);
+ }
}
/**
@@ -186,4 +193,13 @@
}
return new ThreadNetworkService(context);
}
+
+ /** Return a CertificateTransparencyService instance if enabled, otherwise null. */
+ @Nullable
+ private CertificateTransparencyService createCertificateTransparencyService(
+ final Context context) {
+ return SdkLevel.isAtLeastV() && CertificateTransparencyService.enabled(context)
+ ? new CertificateTransparencyService(context)
+ : null;
+ }
}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index ebd95c9..c3cb776 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -228,12 +228,12 @@
* Create a ServiceRegistration with only update the subType.
*/
ServiceRegistration withSubtypes(@NonNull Set<String> newSubtypes,
- boolean avoidEmptyTxtRecords) {
+ @NonNull MdnsFeatureFlags featureFlags) {
NsdServiceInfo newServiceInfo = new NsdServiceInfo(serviceInfo);
newServiceInfo.setSubtypes(newSubtypes);
return new ServiceRegistration(srvRecord.record.getServiceHost(), newServiceInfo,
repliedServiceCount, sentPacketCount, exiting, isProbing, ttl,
- avoidEmptyTxtRecords);
+ featureFlags);
}
/**
@@ -241,7 +241,7 @@
*/
ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
int repliedServiceCount, int sentPacketCount, boolean exiting, boolean isProbing,
- @Nullable Duration ttl, boolean avoidEmptyTxtRecords) {
+ @Nullable Duration ttl, @NonNull MdnsFeatureFlags featureFlags) {
this.serviceInfo = serviceInfo;
final long nonNameRecordsTtlMillis;
@@ -313,7 +313,7 @@
true /* cacheFlush */,
nonNameRecordsTtlMillis,
attrsToTextEntries(
- serviceInfo.getAttributes(), avoidEmptyTxtRecords)),
+ serviceInfo.getAttributes(), featureFlags)),
false /* sharedName */);
allRecords.addAll(ptrRecords);
@@ -397,9 +397,9 @@
*/
ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
int repliedServiceCount, int sentPacketCount, @Nullable Duration ttl,
- boolean avoidEmptyTxtRecords) {
+ @NonNull MdnsFeatureFlags featureFlags) {
this(deviceHostname, serviceInfo,repliedServiceCount, sentPacketCount,
- false /* exiting */, true /* isProbing */, ttl, avoidEmptyTxtRecords);
+ false /* exiting */, true /* isProbing */, ttl, featureFlags);
}
void setProbing(boolean probing) {
@@ -450,7 +450,7 @@
"Service ID must already exist for an update request: " + serviceId);
}
final ServiceRegistration updatedRegistration = existingRegistration.withSubtypes(
- subtypes, mMdnsFeatureFlags.avoidAdvertisingEmptyTxtRecords());
+ subtypes, mMdnsFeatureFlags);
mServices.put(serviceId, updatedRegistration);
}
@@ -482,7 +482,7 @@
final ServiceRegistration registration = new ServiceRegistration(
mDeviceHostname, serviceInfo, NO_PACKET /* repliedServiceCount */,
NO_PACKET /* sentPacketCount */, ttl,
- mMdnsFeatureFlags.avoidAdvertisingEmptyTxtRecords());
+ mMdnsFeatureFlags);
mServices.put(serviceId, registration);
// Remove existing exiting service
@@ -553,14 +553,14 @@
return new MdnsProber.ProbingInfo(serviceId, probingRecords);
}
- private static List<MdnsServiceInfo.TextEntry> attrsToTextEntries(Map<String, byte[]> attrs,
- boolean avoidEmptyTxtRecords) {
+ private static List<MdnsServiceInfo.TextEntry> attrsToTextEntries(
+ @NonNull Map<String, byte[]> attrs, @NonNull MdnsFeatureFlags featureFlags) {
final List<MdnsServiceInfo.TextEntry> out = new ArrayList<>(
attrs.size() == 0 ? 1 : attrs.size());
- if (avoidEmptyTxtRecords && attrs.size() == 0) {
+ if (featureFlags.avoidAdvertisingEmptyTxtRecords() && attrs.size() == 0) {
// As per RFC6763 6.1, empty TXT records are not allowed, but records containing a
// single empty String must be treated as equivalent.
- out.add(new MdnsServiceInfo.TextEntry("", (byte[]) null));
+ out.add(new MdnsServiceInfo.TextEntry("", MdnsServiceInfo.TextEntry.VALUE_NONE));
return out;
}
@@ -1418,7 +1418,7 @@
final ServiceRegistration newService = new ServiceRegistration(mDeviceHostname, newInfo,
existing.repliedServiceCount, existing.sentPacketCount, existing.ttl,
- mMdnsFeatureFlags.avoidAdvertisingEmptyTxtRecords());
+ mMdnsFeatureFlags);
mServices.put(serviceId, newService);
return makeProbingInfo(serviceId, newService);
}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
index 7eea93a..a8a4ef1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
@@ -233,6 +233,21 @@
}
/**
+ * Remove services which match the given type and socket.
+ *
+ * @param cacheKey the target CacheKey.
+ */
+ public void removeServices(@NonNull CacheKey cacheKey) {
+ ensureRunningOnHandlerThread(mHandler);
+ // Remove all services
+ if (mCachedServices.remove(cacheKey) == null) {
+ return;
+ }
+ // Update the next expiration check time if services are removed.
+ mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime());
+ }
+
+ /**
* Register a callback to listen to service expiration.
*
* <p> Registering the same callback instance twice is a no-op, since MdnsServiceTypeClient
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceInfo.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceInfo.java
index 3fb92bb..a16fcf7 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceInfo.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceInfo.java
@@ -16,8 +16,6 @@
package com.android.server.connectivity.mdns;
-import static com.android.server.connectivity.mdns.MdnsSocket.INTERFACE_INDEX_UNSPECIFIED;
-
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.net.Network;
@@ -33,7 +31,6 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 114cf2e..11343d2 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -74,12 +74,12 @@
import static com.android.net.module.util.DeviceConfigUtils.getDeviceConfigPropertyInt;
import static com.android.net.module.util.NetworkCapabilitiesUtils.getDisplayTransport;
import static com.android.net.module.util.NetworkStatsUtils.LIMIT_GLOBAL_ALERT;
-import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_PERIODIC;
import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_DUMPSYS;
import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_FORCE_UPDATE;
import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_GLOBAL_ALERT;
import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_NETWORK_STATUS_CHANGED;
import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_OPEN_SESSION;
+import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_PERIODIC;
import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_RAT_CHANGED;
import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_REG_CALLBACK;
import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_REMOVE_UIDS;
@@ -242,13 +242,11 @@
// A message for broadcasting ACTION_NETWORK_STATS_UPDATED in handler thread to prevent
// deadlock.
private static final int MSG_BROADCAST_NETWORK_STATS_UPDATED = 4;
-
/** Flags to control detail level of poll event. */
private static final int FLAG_PERSIST_NETWORK = 0x1;
private static final int FLAG_PERSIST_UID = 0x2;
private static final int FLAG_PERSIST_ALL = FLAG_PERSIST_NETWORK | FLAG_PERSIST_UID;
private static final int FLAG_PERSIST_FORCE = 0x100;
-
/**
* When global alert quota is high, wait for this delay before processing each polling,
* and do not schedule further polls once there is already one queued.
@@ -313,6 +311,12 @@
static final String TRAFFIC_STATS_CACHE_MAX_ENTRIES_NAME = "trafficstats_cache_max_entries";
static final int DEFAULT_TRAFFIC_STATS_CACHE_EXPIRY_DURATION_MS = 1000;
static final int DEFAULT_TRAFFIC_STATS_CACHE_MAX_ENTRIES = 400;
+ /**
+ * The delay time between two network stats update intents.
+ * Added to fix intent spam (b/3115462).
+ */
+ @VisibleForTesting(visibility = PRIVATE)
+ static final int BROADCAST_NETWORK_STATS_UPDATED_DELAY_MS = 1000;
private final Context mContext;
private final NetworkStatsFactory mStatsFactory;
@@ -385,6 +389,7 @@
long getXtPersistBytes(long def);
long getUidPersistBytes(long def);
long getUidTagPersistBytes(long def);
+ long getBroadcastNetworkStatsUpdateDelayMs();
}
private final Object mStatsLock = new Object();
@@ -469,15 +474,36 @@
private long mLastStatsSessionPoll;
+ /**
+ * The timestamp of the most recent network stats broadcast.
+ *
+ * Note that this time could be in the past for completed broadcasts,
+ * or in the future for scheduled broadcasts.
+ *
+ * It is initialized to {@code Long.MIN_VALUE} to ensure that the first broadcast request
+ * is fulfilled immediately, regardless of the delay time.
+ *
+ * This value is used to enforce rate limiting on intents, preventing intent spam.
+ */
+ @GuardedBy("mStatsLock")
+ private long mLatestNetworkStatsUpdatedBroadcastScheduledTime = Long.MIN_VALUE;
+
+
private final TrafficStatsRateLimitCache mTrafficStatsTotalCache;
private final TrafficStatsRateLimitCache mTrafficStatsIfaceCache;
private final TrafficStatsRateLimitCache mTrafficStatsUidCache;
static final String TRAFFICSTATS_RATE_LIMIT_CACHE_ENABLED_FLAG =
"trafficstats_rate_limit_cache_enabled_flag";
+ static final String BROADCAST_NETWORK_STATS_UPDATED_RATE_LIMIT_ENABLED_FLAG =
+ "broadcast_network_stats_updated_rate_limit_enabled_flag";
private final boolean mAlwaysUseTrafficStatsRateLimitCache;
private final int mTrafficStatsRateLimitCacheExpiryDuration;
private final int mTrafficStatsRateLimitCacheMaxEntries;
+ private final boolean mBroadcastNetworkStatsUpdatedRateLimitEnabled;
+
+
+
private final Object mOpenSessionCallsLock = new Object();
/**
@@ -669,6 +695,8 @@
mAlwaysUseTrafficStatsRateLimitCache =
mDeps.alwaysUseTrafficStatsRateLimitCache(mContext);
+ mBroadcastNetworkStatsUpdatedRateLimitEnabled =
+ mDeps.enabledBroadcastNetworkStatsUpdatedRateLimiting(mContext);
mTrafficStatsRateLimitCacheExpiryDuration =
mDeps.getTrafficStatsRateLimitCacheExpiryDuration();
mTrafficStatsRateLimitCacheMaxEntries =
@@ -696,6 +724,15 @@
@VisibleForTesting
public static class Dependencies {
/**
+ * Get broadcast network stats updated delay time in ms
+ * @return the delay between two broadcasts, in milliseconds
+ */
+ @NonNull
+ public long getBroadcastNetworkStatsUpdateDelayMs() {
+ return BROADCAST_NETWORK_STATS_UPDATED_DELAY_MS;
+ }
+
+ /**
* Get legacy platform stats directory.
*/
@NonNull
@@ -927,6 +964,17 @@
}
/**
+ * Get whether broadcast network stats update rate limiting is enabled.
+ *
+ * This method should only be called once in the constructor,
+ * to ensure that the code does not need to deal with flag values changing at runtime.
+ */
+ public boolean enabledBroadcastNetworkStatsUpdatedRateLimiting(Context ctx) {
+ return DeviceConfigUtils.isTetheringFeatureNotChickenedOut(
+ ctx, BROADCAST_NETWORK_STATS_UPDATED_RATE_LIMIT_ENABLED_FLAG);
+ }
+
+ /**
* Get whether TrafficStats rate-limit cache is always applied.
*
* This method should only be called once in the constructor,
@@ -2645,8 +2693,22 @@
performSampleLocked();
}
- // finally, dispatch updated event to any listeners
- mHandler.sendMessage(mHandler.obtainMessage(MSG_BROADCAST_NETWORK_STATS_UPDATED));
+ // Dispatch updated event to listeners, preventing intent spamming
+ // (b/343844995) possibly from abnormal modem RAT changes or misbehaving
+ // app calls (see NetworkStatsEventLogger#POLL_REASON_* for possible reasons).
+ // If no broadcasts are scheduled, use the time of the last broadcast
+ // to schedule the next one ASAP.
+ if (!mBroadcastNetworkStatsUpdatedRateLimitEnabled) {
+ mHandler.sendMessage(mHandler.obtainMessage(MSG_BROADCAST_NETWORK_STATS_UPDATED));
+ } else if (mLatestNetworkStatsUpdatedBroadcastScheduledTime < SystemClock.uptimeMillis()) {
+ mLatestNetworkStatsUpdatedBroadcastScheduledTime = Math.max(
+ mLatestNetworkStatsUpdatedBroadcastScheduledTime
+ + mSettings.getBroadcastNetworkStatsUpdateDelayMs(),
+ SystemClock.uptimeMillis()
+ );
+ mHandler.sendMessageAtTime(mHandler.obtainMessage(MSG_BROADCAST_NETWORK_STATS_UPDATED),
+ mLatestNetworkStatsUpdatedBroadcastScheduledTime);
+ }
Trace.traceEnd(TRACE_TAG_NETWORK);
}
@@ -3605,6 +3667,11 @@
public long getUidTagPersistBytes(long def) {
return def;
}
+
+ @Override
+ public long getBroadcastNetworkStatsUpdateDelayMs() {
+ return BROADCAST_NETWORK_STATS_UPDATED_DELAY_MS;
+ }
}
// TODO: Read stats by using BpfNetMapsReader.
diff --git a/service/Android.bp b/service/Android.bp
index 1a0e045..c68f0b8 100644
--- a/service/Android.bp
+++ b/service/Android.bp
@@ -206,6 +206,7 @@
},
visibility: [
"//packages/modules/Connectivity/service-t",
+ "//packages/modules/Connectivity/networksecurity:__subpackages__",
"//packages/modules/Connectivity/tests:__subpackages__",
"//packages/modules/Connectivity/thread/service:__subpackages__",
"//packages/modules/Connectivity/thread/tests:__subpackages__",
@@ -247,6 +248,7 @@
"service-connectivity-pre-jarjar",
"service-connectivity-tiramisu-pre-jarjar",
"service-nearby-pre-jarjar",
+ "service-networksecurity-pre-jarjar",
service_remoteauth_pre_jarjar_lib,
"service-thread-pre-jarjar",
],
@@ -315,6 +317,7 @@
":framework-connectivity-jarjar-rules",
":service-connectivity-jarjar-gen",
":service-nearby-jarjar-gen",
+ ":service-networksecurity-jarjar-gen",
":service-remoteauth-jarjar-gen",
":service-thread-jarjar-gen",
],
@@ -404,6 +407,24 @@
visibility: ["//visibility:private"],
}
+java_genrule {
+ name: "service-networksecurity-jarjar-gen",
+ tool_files: [
+ ":service-networksecurity-pre-jarjar{.jar}",
+ "jarjar-excludes.txt",
+ ],
+ tools: [
+ "jarjar-rules-generator",
+ ],
+ out: ["service_ct_jarjar_rules.txt"],
+ cmd: "$(location jarjar-rules-generator) " +
+ "$(location :service-networksecurity-pre-jarjar{.jar}) " +
+ "--prefix com.android.server.net.ct " +
+ "--excludes $(location jarjar-excludes.txt) " +
+ "--output $(out)",
+ visibility: ["//visibility:private"],
+}
+
genrule {
name: "statslog-connectivity-java-gen",
tools: ["stats-log-api-gen"],
diff --git a/service/src/com/android/server/CallbackQueue.java b/service/src/com/android/server/CallbackQueue.java
index 060a984..4e068ea 100644
--- a/service/src/com/android/server/CallbackQueue.java
+++ b/service/src/com/android/server/CallbackQueue.java
@@ -34,8 +34,7 @@
* queue.forEach(netId, callbackId -> { [...] });
* queue.addCallback(netId, callbackId);
* [...]
- * queue.shrinkToLength();
- * storedCallbacks = queue.getBackingArray();
+ * storedCallbacks = queue.getMinimizedBackingArray();
* </pre>
*
* <p>This class is not thread-safe.
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 953fd76..0fe24a2 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -56,6 +56,7 @@
import static android.net.ConnectivityManager.FIREWALL_RULE_ALLOW;
import static android.net.ConnectivityManager.FIREWALL_RULE_DEFAULT;
import static android.net.ConnectivityManager.FIREWALL_RULE_DENY;
+import static android.net.ConnectivityManager.NETID_UNSET;
import static android.net.ConnectivityManager.NetworkCallback.DECLARED_METHODS_ALL;
import static android.net.ConnectivityManager.NetworkCallback.DECLARED_METHODS_NONE;
import static android.net.ConnectivityManager.TYPE_BLUETOOTH;
@@ -2147,7 +2148,7 @@
}
@VisibleForTesting
- void updateMobileDataPreferredUids() {
+ public void updateMobileDataPreferredUids() {
mHandler.sendEmptyMessage(EVENT_MOBILE_DATA_PREFERRED_UIDS_CHANGED);
}
@@ -3403,7 +3404,7 @@
}
@VisibleForTesting
- void handleBlockedReasonsChanged(List<Pair<Integer, Integer>> reasonsList) {
+ public void handleBlockedReasonsChanged(List<Pair<Integer, Integer>> reasonsList) {
for (Pair<Integer, Integer> reasons: reasonsList) {
final int uid = reasons.first;
final int blockedReasons = reasons.second;
@@ -3472,7 +3473,7 @@
private void handleFrozenUids(int[] uids, int[] frozenStates) {
ensureRunningOnConnectivityServiceThread();
handleDestroyFrozenSockets(uids, frozenStates);
- // TODO: handle freezing NetworkCallbacks
+ handleFreezeNetworkCallbacks(uids, frozenStates);
}
private void handleDestroyFrozenSockets(int[] uids, int[] frozenStates) {
@@ -3490,6 +3491,73 @@
}
}
+ private void handleFreezeNetworkCallbacks(int[] uids, int[] frozenStates) {
+ if (!mQueueCallbacksForFrozenApps) {
+ return;
+ }
+ for (int i = 0; i < uids.length; i++) {
+ final int uid = uids[i];
+ // These counters may be modified on different threads, but using them here is fine
+ // because this is only an optimization where wrong behavior would only happen if they
+ // are zero even though there is a request registered. This is not possible as they are
+ // always incremented before posting messages to register, and decremented on the
+ // handler thread when unregistering.
+ if (mSystemNetworkRequestCounter.get(uid) == 0
+ && mNetworkRequestCounter.get(uid) == 0) {
+ // Avoid iterating requests if there isn't any. The counters only track app requests
+ // and not internal requests (for example always-on requests which do not have a
+ // mMessenger), so it does not completely match the content of mRequests. This is OK
+ // as only app requests need to be frozen.
+ continue;
+ }
+
+ if (frozenStates[i] == UID_FROZEN_STATE_FROZEN) {
+ freezeNetworkCallbacksForUid(uid);
+ } else {
+ unfreezeNetworkCallbacksForUid(uid);
+ }
+ }
+ }
+
+ /**
+ * Suspend callbacks for a UID that was just frozen.
+ *
+ * <p>Note that it is not possible for a process to be frozen during a blocking binder call
+ * (see CachedAppOptimizer.freezeBinder), and IConnectivityManager callback registrations are
+ * blocking binder calls, so no callback can be registered while the UID is frozen. This means
+ * it is not necessary to check frozen state on new callback registrations, and calling this
+ * method when a UID is newly frozen is sufficient.
+ *
+ * <p>If it ever becomes possible for a process to be frozen during a blocking binder call,
+ * ConnectivityService will need to handle freezing callbacks that reach ConnectivityService
+ * after the app was frozen when being registered.
+ */
+ private void freezeNetworkCallbacksForUid(int uid) {
+ if (DDBG) Log.d(TAG, "Freezing callbacks for UID " + uid);
+ for (NetworkRequestInfo nri : mNetworkRequests.values()) {
+ if (nri.mUid != uid) continue;
+ // mNetworkRequests can have duplicate values for multilayer requests, but calling
+ // onFrozen multiple times is fine.
+ // If freezeNetworkCallbacksForUid was called multiple times in a row for a frozen UID
+ // (which would be incorrect), this would also handle it gracefully.
+ nri.onFrozen();
+ }
+ }
+
+ private void unfreezeNetworkCallbacksForUid(int uid) {
+ // This sends all callbacks for one NetworkRequest at a time, which may not be the
+ // same order they were queued in, but different network requests use different
+ // binder objects, so the relative order of their callbacks is not guaranteed.
+ // If callbacks are not queued, callbacks from different binder objects may be
+ // posted on different threads when the process is unfrozen, so even if they were
+ // called a long time apart while the process was frozen, they may still appear in
+ // different order when unfreezing it.
+ for (NetworkRequestInfo nri : mNetworkRequests.values()) {
+ if (nri.mUid != uid) continue;
+ nri.sendQueuedCallbacks();
+ }
+ }
+
private void handleUpdateFirewallDestroySocketReasons(
List<Pair<Integer, Integer>> reasonsList) {
if (!shouldTrackFirewallDestroySocketReasons()) {
@@ -7544,6 +7612,29 @@
// single NetworkRequest in mRequests.
final List<NetworkRequest> mRequests;
+ /**
+ * List of callbacks that are queued for sending later when the requesting app is unfrozen.
+ *
+ * <p>There may typically be hundreds of NetworkRequestInfo, so a memory-efficient structure
+ * (just an int[]) is used to keep queued callbacks. This reduces the number of object
+ * references.
+ *
+ * <p>This is intended to be used with {@link CallbackQueue} which defines the internal
+ * format.
+ */
+ @NonNull
+ private int[] mQueuedCallbacks = new int[0];
+
+ private static final int MATCHED_NETID_NOT_FROZEN = -1;
+
+ /**
+ * If this request was already satisfied by a network when the requesting UID was frozen,
+ * the netId that was matched at that time. Otherwise, NETID_UNSET if no network was
+ * satisfying this request when frozen (including if this is a listen and not a request),
+ * and MATCHED_NETID_NOT_FROZEN if not frozen.
+ */
+ private int mMatchedNetIdWhenFrozen = MATCHED_NETID_NOT_FROZEN;
+
// mSatisfier and mActiveRequest rely on one another therefore set them together.
void setSatisfier(
@Nullable final NetworkAgentInfo satisfier,
@@ -7715,6 +7806,8 @@
}
setSatisfier(satisfier, activeRequest);
}
+ mMatchedNetIdWhenFrozen = nri.mMatchedNetIdWhenFrozen;
+ mQueuedCallbacks = nri.mQueuedCallbacks;
mMessenger = nri.mMessenger;
mBinder = nri.mBinder;
mPid = nri.mPid;
@@ -7779,11 +7872,190 @@
}
}
+ /**
+ * Called when this NRI is being frozen.
+ *
+ * <p>Calling this method multiple times when the NRI is frozen is fine. This may happen
+ * if iterating through the NetworkRequest -> NRI map since there are duplicates in the
+ * NRI values for multilayer requests. It may also happen if an app is frozen, killed,
+ * restarted and refrozen since there is no callback sent when processes are killed, but in
+ * that case the callbacks to the killed app do not matter.
+ */
+ void onFrozen() {
+ if (mMatchedNetIdWhenFrozen != MATCHED_NETID_NOT_FROZEN) {
+ // Already frozen
+ return;
+ }
+ if (mSatisfier != null) {
+ mMatchedNetIdWhenFrozen = mSatisfier.network.netId;
+ } else {
+ mMatchedNetIdWhenFrozen = NETID_UNSET;
+ }
+ }
+
+ boolean maybeQueueCallback(@NonNull NetworkAgentInfo nai, int callbackId) {
+ if (mMatchedNetIdWhenFrozen == MATCHED_NETID_NOT_FROZEN) {
+ return false;
+ }
+
+ boolean ignoreThisCallback = false;
+ final int netId = nai.network.netId;
+ final CallbackQueue queue = new CallbackQueue(mQueuedCallbacks);
+ // Based on the new callback, clear previous callbacks that are no longer necessary.
+ // For example, if the network is lost, there is no need to send intermediate callbacks.
+ switch (callbackId) {
+ // PRECHECK is not an API and not very meaningful, do not deliver it for frozen apps
+ // Networks are likely to already be lost when the app is unfrozen, also skip LOSING
+ case CALLBACK_PRECHECK:
+ case CALLBACK_LOSING:
+ ignoreThisCallback = true;
+ break;
+ case CALLBACK_LOST:
+ // All callbacks for this netId before onLost are unnecessary. And onLost itself
+ // is also unnecessary if onAvailable was previously queued for this netId: the
+ // Network just appeared and disappeared while the app was frozen.
+ ignoreThisCallback = queue.hasCallback(netId, CALLBACK_AVAILABLE);
+ queue.removeCallbacksForNetId(netId);
+ break;
+ case CALLBACK_AVAILABLE:
+ if (mSatisfier != null) {
+ // For requests that are satisfied by individual networks (not LISTEN), when
+ // AVAILABLE is received, the request is matching a new Network, so previous
+ // callbacks (for other Networks) are unnecessary.
+ queue.clear();
+ }
+ break;
+ case CALLBACK_SUSPENDED:
+ case CALLBACK_RESUMED:
+ if (queue.hasCallback(netId, CALLBACK_AVAILABLE)) {
+ // AVAILABLE will already send the latest suspended status
+ ignoreThisCallback = true;
+ break;
+ }
+ // If SUSPENDED was queued, just remove it from the queue instead of sending
+ // RESUMED; and vice-versa.
+ final int otherCb = callbackId == CALLBACK_SUSPENDED
+ ? CALLBACK_RESUMED
+ : CALLBACK_SUSPENDED;
+ ignoreThisCallback = queue.removeCallbacks(netId, otherCb);
+ break;
+ case CALLBACK_CAP_CHANGED:
+ case CALLBACK_IP_CHANGED:
+ case CALLBACK_LOCAL_NETWORK_INFO_CHANGED:
+ case CALLBACK_BLK_CHANGED:
+ ignoreThisCallback = queue.hasCallback(netId, CALLBACK_AVAILABLE);
+ break;
+ default:
+ Log.wtf(TAG, "Unexpected callback type: "
+ + ConnectivityManager.getCallbackName(callbackId));
+ return false;
+ }
+
+ if (!ignoreThisCallback) {
+ // For non-listen (matching) callbacks, AVAILABLE can appear in the queue twice in a
+ // row for the same network if the new AVAILABLE suppressed intermediate AVAILABLEs
+ // for other networks. Example:
+ // A is matched, app is frozen, B is matched, A is matched again (removes callbacks
+ // for B), app is unfrozen.
+ // In that case call AVAILABLE sub-callbacks to update state, but not AVAILABLE
+ // itself.
+ if (callbackId == CALLBACK_AVAILABLE && netId == mMatchedNetIdWhenFrozen) {
+ // The queue should have been cleared here, since this is AVAILABLE on a
+ // non-listen callback (mMatchedNetIdWhenFrozen is set).
+ addAvailableSubCallbacks(nai, queue);
+ } else {
+ // When unfreezing, no need to send a callback multiple times for the same netId
+ queue.removeCallbacks(netId, callbackId);
+ // TODO: this code always adds the callback for simplicity. It would save
+ // some CPU/memory if the code instead only added to the queue callbacks where
+ // isCallbackOverridden=true, or which need to be in the queue because they
+ // affect other callbacks that are overridden.
+ queue.addCallback(netId, callbackId);
+ }
+ }
+ // Instead of shrinking the queue, possibly reallocating, the NRI could keep the array
+ // and length in memory for future adds, but this saves memory by avoiding the cost
+ // of an extra member and of unused array length (there are often hundreds of NRIs).
+ mQueuedCallbacks = queue.getMinimizedBackingArray();
+ return true;
+ }
+
+ /**
+ * Called when this NRI is being unfrozen to stop queueing, and send queued callbacks.
+ *
+ * <p>Calling this method multiple times when the NRI is unfrozen (for example iterating
+ * through the NetworkRequest -> NRI map where there are duplicate values for multilayer
+ * requests) is fine.
+ */
+ void sendQueuedCallbacks() {
+ mMatchedNetIdWhenFrozen = MATCHED_NETID_NOT_FROZEN;
+ if (mQueuedCallbacks.length == 0) {
+ return;
+ }
+ new CallbackQueue(mQueuedCallbacks).forEach((netId, callbackId) -> {
+ // For CALLBACK_LOST only, there will not be a NAI for the netId. Build and send the
+ // callback directly.
+ if (callbackId == CALLBACK_LOST) {
+ if (isCallbackOverridden(CALLBACK_LOST)) {
+ final Bundle cbBundle = makeCommonBundleForCallback(this,
+ new Network(netId));
+ callCallbackForRequest(this, CALLBACK_LOST, cbBundle, 0 /* arg1 */);
+ }
+ return; // Next item in forEach
+ }
+
+ // Other callbacks should always have a NAI, because if a Network disconnects
+ // LOST will be called, unless the request is no longer satisfied by that Network in
+ // which case AVAILABLE will have been called for another Network. In both cases
+ // previous callbacks are cleared.
+ final NetworkAgentInfo nai = getNetworkAgentInfoForNetId(netId);
+ if (nai == null) {
+ Log.wtf(TAG, "Missing NetworkAgentInfo for net " + netId
+ + " for callback " + callbackId);
+ return; // Next item in forEach
+ }
+
+ final int arg1 =
+ callbackId == CALLBACK_AVAILABLE || callbackId == CALLBACK_BLK_CHANGED
+ ? getBlockedState(nai, mAsUid)
+ : 0;
+ callCallbackForRequest(this, nai, callbackId, arg1);
+ });
+ mQueuedCallbacks = new int[0];
+ }
+
boolean isCallbackOverridden(int callbackId) {
return !mUseDeclaredMethodsForCallbacksEnabled
|| (mDeclaredMethodsFlags & (1 << callbackId)) != 0;
}
+ /**
+ * Queue all callbacks that are called by AVAILABLE, except onAvailable.
+ *
+ * <p>AVAILABLE may call SUSPENDED, CAP_CHANGED, IP_CHANGED, LOCAL_NETWORK_INFO_CHANGED,
+ * and BLK_CHANGED, in this order.
+ */
+ private void addAvailableSubCallbacks(
+ @NonNull NetworkAgentInfo nai, @NonNull CallbackQueue queue) {
+ final boolean callSuspended =
+ !nai.networkCapabilities.hasCapability(NET_CAPABILITY_NOT_SUSPENDED);
+ final boolean callLocalInfoChanged = nai.isLocalNetwork();
+
+ final int cbCount = 3 + (callSuspended ? 1 : 0) + (callLocalInfoChanged ? 1 : 0);
+ // Avoid unnecessary re-allocations by reserving enough space for all callbacks to add.
+ queue.ensureHasCapacity(cbCount);
+ final int netId = nai.network.netId;
+ if (callSuspended) {
+ queue.addCallback(netId, CALLBACK_SUSPENDED);
+ }
+ queue.addCallback(netId, CALLBACK_CAP_CHANGED);
+ queue.addCallback(netId, CALLBACK_IP_CHANGED);
+ if (callLocalInfoChanged) {
+ queue.addCallback(netId, CALLBACK_LOCAL_NETWORK_INFO_CHANGED);
+ }
+ queue.addCallback(netId, CALLBACK_BLK_CHANGED);
+ }
+
boolean hasHigherOrderThan(@NonNull final NetworkRequestInfo target) {
// Compare two preference orders.
return mPreferenceOrder < target.mPreferenceOrder;
@@ -9400,8 +9672,10 @@
* interfaces.
* Ingress discard rule is added to the address iff
* 1. The address is not a link local address
- * 2. The address is used by a single VPN interface and not used by any other
- * interfaces even non-VPN ones
+ * 2. The address is used by a single interface of VPN whose VPN type is not TYPE_VPN_LEGACY
+ * or TYPE_VPN_OEM and the address is not used by any other interfaces even non-VPN ones
+ * Ingress discard rule is not added to TYPE_VPN_LEGACY or TYPE_VPN_OEM VPN since these VPNs
+ * might need to receive packets to the VPN address via a non-VPN interface.
* This method can be called during network disconnects, when nai has already been removed from
* mNetworkAgentInfos.
*
@@ -9436,7 +9710,10 @@
// for different network.
final Set<Pair<InetAddress, String>> ingressDiscardRules = new ArraySet<>();
for (final NetworkAgentInfo agent : nais) {
- if (!agent.isVPN() || agent.isDestroyed()) {
+ final int vpnType = getVpnType(agent);
+ if (!agent.isVPN() || agent.isDestroyed()
+ || vpnType == VpnManager.TYPE_VPN_LEGACY
+ || vpnType == VpnManager.TYPE_VPN_OEM) {
continue;
}
final LinkProperties agentLp = (nai == agent) ? lp : agent.linkProperties;
@@ -10277,6 +10554,11 @@
// are Type.LISTEN, but should not have NetworkCallbacks invoked.
return;
}
+ // Even if a callback ends up not being sent, it may affect other callbacks in the queue, so
+ // queue callbacks before checking the declared methods flags.
+ if (networkAgent != null && nri.maybeQueueCallback(networkAgent, notificationType)) {
+ return;
+ }
if (!nri.isCallbackOverridden(notificationType)) {
// No need to send the notification as the recipient method is not overridden
return;
diff --git a/staticlibs/Android.bp b/staticlibs/Android.bp
index de7f2d4..94a2411 100644
--- a/staticlibs/Android.bp
+++ b/staticlibs/Android.bp
@@ -296,6 +296,9 @@
"framework-connectivity-t.stubs.module_lib",
"framework-location.stubs.module_lib",
],
+ static_libs: [
+ "modules-utils-build",
+ ],
jarjar_rules: "jarjar-rules-shared.txt",
visibility: [
"//cts/tests/tests/net",
@@ -450,6 +453,7 @@
visibility: ["//packages/modules/Connectivity/service-t"],
}
+// net-utils-framework-connectivity is only for framework-connectivity.
java_library {
name: "net-utils-framework-connectivity",
srcs: [
@@ -462,8 +466,7 @@
"//apex_available:platform",
],
visibility: [
- "//packages/modules/Connectivity:__subpackages__",
- "//packages/modules/NetworkStack:__subpackages__",
+ "//packages/modules/Connectivity/framework",
],
libs: [
"androidx.annotation_annotation",
@@ -502,9 +505,6 @@
"com.android.tethering",
"//apex_available:platform",
],
- visibility: [
- "//packages/modules/Connectivity:__subpackages__",
- ],
defaults_visibility: [
"//visibility:private",
],
@@ -514,6 +514,7 @@
},
}
+// net-utils-service-connectivity is only for service-connectivity.
java_library {
name: "net-utils-service-connectivity",
srcs: [
@@ -527,16 +528,24 @@
],
defaults: ["net-utils-non-bootclasspath-defaults"],
jarjar_rules: "jarjar-rules-shared.txt",
+ visibility: [
+ "//packages/modules/Connectivity/service",
+ "//packages/modules/Connectivity/staticlibs/tests/unit",
+ ],
}
java_library {
- name: "net-utils-tethering",
+ name: "net-utils-connectivity-apks",
srcs: [
":net-utils-all-srcs",
":framework-connectivity-shared-srcs",
],
defaults: ["net-utils-non-bootclasspath-defaults"],
jarjar_rules: "jarjar-rules-shared.txt",
+ visibility: [
+ "//packages/modules/CaptivePortalLogin:__subpackages__",
+ "//packages/modules/Connectivity/Tethering",
+ ],
}
aidl_interface {
@@ -552,6 +561,7 @@
apex_available: [
"com.android.tethering",
"com.android.wifi",
+ "//apex_available:platform",
],
},
cpp: {
diff --git a/staticlibs/device/com/android/net/module/util/GrowingIntArray.java b/staticlibs/device/com/android/net/module/util/GrowingIntArray.java
index 4a81c10..d47738b 100644
--- a/staticlibs/device/com/android/net/module/util/GrowingIntArray.java
+++ b/staticlibs/device/com/android/net/module/util/GrowingIntArray.java
@@ -168,7 +168,7 @@
* stop using this instance of {@link GrowingIntArray} if they use the array returned by this
* method.
*/
- public int[] getShrinkedBackingArray() {
+ public int[] getMinimizedBackingArray() {
shrinkToLength();
return mValues;
}
diff --git a/staticlibs/framework/com/android/net/module/util/NetworkStatsUtils.java b/staticlibs/framework/com/android/net/module/util/NetworkStatsUtils.java
index 41a9428..04b6174 100644
--- a/staticlibs/framework/com/android/net/module/util/NetworkStatsUtils.java
+++ b/staticlibs/framework/com/android/net/module/util/NetworkStatsUtils.java
@@ -19,6 +19,9 @@
import android.app.usage.NetworkStats;
import com.android.internal.annotations.VisibleForTesting;
+import com.android.modules.utils.build.SdkLevel;
+
+import java.util.ArrayList;
/**
* Various utilities used for NetworkStats related code.
@@ -105,12 +108,21 @@
*/
public static android.net.NetworkStats fromPublicNetworkStats(
NetworkStats publiceNetworkStats) {
- android.net.NetworkStats stats = new android.net.NetworkStats(0L, 0);
+ final ArrayList<android.net.NetworkStats.Entry> entries = new ArrayList<>();
while (publiceNetworkStats.hasNextBucket()) {
NetworkStats.Bucket bucket = new NetworkStats.Bucket();
publiceNetworkStats.getNextBucket(bucket);
- final android.net.NetworkStats.Entry entry = fromBucket(bucket);
- stats = stats.addEntry(entry);
+ entries.add(fromBucket(bucket));
+ }
+ android.net.NetworkStats stats = new android.net.NetworkStats(0L, 0);
+ // The new API is only supported on devices running the mainline version of `NetworkStats`.
+ // It should always be used when available for memory efficiency.
+ if (SdkLevel.isAtLeastT()) {
+ stats = stats.addEntries(entries);
+ } else {
+ for (android.net.NetworkStats.Entry entry : entries) {
+ stats = stats.addEntry(entry);
+ }
}
return stats;
}
diff --git a/staticlibs/framework/com/android/net/module/util/PerUidCounter.java b/staticlibs/framework/com/android/net/module/util/PerUidCounter.java
index 463b0c4..98d91a5 100644
--- a/staticlibs/framework/com/android/net/module/util/PerUidCounter.java
+++ b/staticlibs/framework/com/android/net/module/util/PerUidCounter.java
@@ -87,7 +87,9 @@
}
}
- @VisibleForTesting
+ /**
+ * Get the current counter value for the given uid.
+ */
public synchronized int get(int uid) {
return mUidToCount.get(uid, 0);
}
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/GrowingIntArrayTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/GrowingIntArrayTest.kt
index bdcb8c0..4b740e3 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/GrowingIntArrayTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/GrowingIntArrayTest.kt
@@ -109,12 +109,12 @@
}
@Test
- fun testGetShrinkedBackingArray() {
+ fun testGetMinimizedBackingArray() {
val array = GrowingIntArray(10)
array.add(-1)
array.add(2)
- assertContentEquals(intArrayOf(-1, 2), array.shrinkedBackingArray)
+ assertContentEquals(intArrayOf(-1, 2), array.minimizedBackingArray)
}
@Test
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/NetworkStatsUtilsTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/NetworkStatsUtilsTest.kt
index 2785ea9..97d3c19 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/NetworkStatsUtilsTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/NetworkStatsUtilsTest.kt
@@ -17,15 +17,27 @@
package com.android.net.module.util
import android.net.NetworkStats
-import android.text.TextUtils
+import android.net.NetworkStats.DEFAULT_NETWORK_NO
+import android.net.NetworkStats.DEFAULT_NETWORK_YES
+import android.net.NetworkStats.Entry
+import android.net.NetworkStats.METERED_NO
+import android.net.NetworkStats.ROAMING_NO
+import android.net.NetworkStats.ROAMING_YES
+import android.net.NetworkStats.SET_DEFAULT
+import android.net.NetworkStats.TAG_NONE
import androidx.test.filters.SmallTest
import androidx.test.runner.AndroidJUnit4
-import org.junit.Test
-import org.junit.runner.RunWith
+import com.android.testutils.assertEntryEquals
+import com.android.testutils.assertNetworkStatsEquals
+import com.android.testutils.makePublicStatsFromAndroidNetStats
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
+import org.junit.Test
+import org.junit.runner.RunWith
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
+import org.mockito.Mockito.`when`
+
@RunWith(AndroidJUnit4::class)
@SmallTest
@@ -90,11 +102,40 @@
NetworkStats.ROAMING_NO, NetworkStats.DEFAULT_NETWORK_ALL, 1024, 8, 2048, 12,
0 /* operations */)
- // TODO: Use assertEquals once all downstreams accept null iface in
- // NetworkStats.Entry#equals.
assertEntryEquals(expectedEntry, entry)
}
+ @Test
+ fun testPublicStatsToAndroidNetStats() {
+ val uid1 = 10001
+ val uid2 = 10002
+ val testIface = "wlan0"
+ val testAndroidNetStats = NetworkStats(0L, 3)
+ .addEntry(Entry(testIface, uid1, SET_DEFAULT, TAG_NONE,
+ METERED_NO, ROAMING_NO, DEFAULT_NETWORK_YES, 20, 3, 57, 40, 3))
+ .addEntry(Entry(
+ testIface, uid2, SET_DEFAULT, TAG_NONE,
+ METERED_NO, ROAMING_YES, DEFAULT_NETWORK_NO, 2, 7, 2, 5, 7))
+ .addEntry(Entry(testIface, uid2, SET_DEFAULT, TAG_NONE,
+ METERED_NO, ROAMING_YES, DEFAULT_NETWORK_NO, 4, 5, 3, 1, 8))
+ val publicStats: android.app.usage.NetworkStats =
+ makePublicStatsFromAndroidNetStats(testAndroidNetStats)
+ val androidNetStats: NetworkStats =
+ NetworkStatsUtils.fromPublicNetworkStats(publicStats)
+
+ // 1. The public `NetworkStats` class does not include interface information.
+ // Interface details must be removed and items with duplicated
+ // keys need to be merged before making any comparisons.
+ // 2. The public `NetworkStats` class lacks an operations field.
+ // Thus, the information will not be preserved during the conversion.
+ val expectedStats = NetworkStats(0L, 2)
+ .addEntry(Entry(null, uid1, SET_DEFAULT, TAG_NONE,
+ METERED_NO, ROAMING_NO, DEFAULT_NETWORK_YES, 20, 3, 57, 40, 0))
+ .addEntry(Entry(null, uid2, SET_DEFAULT, TAG_NONE,
+ METERED_NO, ROAMING_YES, DEFAULT_NETWORK_NO, 6, 12, 5, 6, 0))
+ assertNetworkStatsEquals(expectedStats, androidNetStats)
+ }
+
private fun makeMockBucket(
uid: Int,
tag: Int,
@@ -121,22 +162,4 @@
doReturn(txPackets).`when`(ret).getTxPackets()
return ret
}
-
- /**
- * Assert that the two {@link NetworkStats.Entry} are equals.
- */
- private fun assertEntryEquals(left: NetworkStats.Entry, right: NetworkStats.Entry) {
- TextUtils.equals(left.iface, right.iface)
- assertEquals(left.uid, right.uid)
- assertEquals(left.set, right.set)
- assertEquals(left.tag, right.tag)
- assertEquals(left.metered, right.metered)
- assertEquals(left.roaming, right.roaming)
- assertEquals(left.defaultNetwork, right.defaultNetwork)
- assertEquals(left.rxBytes, right.rxBytes)
- assertEquals(left.rxPackets, right.rxPackets)
- assertEquals(left.txBytes, right.txBytes)
- assertEquals(left.txPackets, right.txPackets)
- assertEquals(left.operations, right.operations)
- }
}
\ No newline at end of file
diff --git a/staticlibs/tests/unit/src/com/android/testutils/NetworkStatsUtilsTest.kt b/staticlibs/tests/unit/src/com/android/testutils/NetworkStatsUtilsTest.kt
new file mode 100644
index 0000000..57920fc
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/testutils/NetworkStatsUtilsTest.kt
@@ -0,0 +1,69 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.NetworkStats
+import android.net.NetworkStats.DEFAULT_NETWORK_NO
+import android.net.NetworkStats.METERED_NO
+import android.net.NetworkStats.ROAMING_NO
+import android.net.NetworkStats.SET_DEFAULT
+import android.net.NetworkStats.TAG_NONE
+import android.os.Build
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+private const val TEST_IFACE = "test0"
+private val TEST_IFACE2: String? = null
+private const val TEST_START = 1194220800000L
+
+@RunWith(JUnit4::class)
+class NetworkStatsUtilsTest {
+ // This is a unit test for a test utility that uses R APIs
+ @Rule @JvmField
+ val ignoreRule = DevSdkIgnoreRule(ignoreClassUpTo = Build.VERSION_CODES.Q)
+
+ @Test
+ fun testOrderInsensitiveEquals() {
+ val testEntry = arrayOf(
+ NetworkStats.Entry(TEST_IFACE, 100, SET_DEFAULT, TAG_NONE, METERED_NO, ROAMING_NO,
+ DEFAULT_NETWORK_NO, 128L, 8L, 0L, 2L, 20L),
+ NetworkStats.Entry(TEST_IFACE2, 100, SET_DEFAULT, TAG_NONE, METERED_NO, ROAMING_NO,
+ DEFAULT_NETWORK_NO, 512L, 32L, 0L, 0L, 0L)
+ )
+
+ // Verify equals of empty stats regardless of initial capacity.
+ val red = NetworkStats(TEST_START, 0)
+ val blue = NetworkStats(TEST_START, 1)
+ assertTrue(orderInsensitiveEquals(red, blue))
+ assertTrue(orderInsensitiveEquals(blue, red))
+
+ // Verify not equal.
+ red.combineValues(testEntry[1])
+ blue.combineValues(testEntry[0]).combineValues(testEntry[1])
+ assertFalse(orderInsensitiveEquals(red, blue))
+ assertFalse(orderInsensitiveEquals(blue, red))
+
+ // Verify equality even if the order of entries is not the same.
+ red.combineValues(testEntry[0])
+ assertTrue(orderInsensitiveEquals(red, blue))
+ assertTrue(orderInsensitiveEquals(blue, red))
+ }
+}
diff --git a/staticlibs/testutils/Android.bp b/staticlibs/testutils/Android.bp
index 4749e75..8c71a91 100644
--- a/staticlibs/testutils/Android.bp
+++ b/staticlibs/testutils/Android.bp
@@ -42,7 +42,6 @@
"net-utils-device-common-struct",
"net-utils-device-common-struct-base",
"net-utils-device-common-wear",
- "net-utils-framework-connectivity",
"modules-utils-build_system",
],
lint: {
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/NetworkStatsUtils.kt b/staticlibs/testutils/devicetests/com/android/testutils/NetworkStatsUtils.kt
index 8324b25..e8c7297 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/NetworkStatsUtils.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/NetworkStatsUtils.kt
@@ -16,8 +16,20 @@
package com.android.testutils
+import android.app.usage.NetworkStatsManager
+import android.content.Context
+import android.net.INetworkStatsService
+import android.net.INetworkStatsSession
import android.net.NetworkStats
+import android.net.NetworkTemplate
+import android.text.TextUtils
+import com.android.modules.utils.build.SdkLevel
import kotlin.test.assertTrue
+import org.mockito.ArgumentMatchers.any
+import org.mockito.ArgumentMatchers.anyBoolean
+import org.mockito.ArgumentMatchers.anyInt
+import org.mockito.ArgumentMatchers.anyLong
+import org.mockito.Mockito
@JvmOverloads
fun orderInsensitiveEquals(
@@ -26,7 +38,7 @@
compareTime: Boolean = false
): Boolean {
if (leftStats == rightStats) return true
- if (compareTime && leftStats.getElapsedRealtime() != rightStats.getElapsedRealtime()) {
+ if (compareTime && leftStats.elapsedRealtime != rightStats.elapsedRealtime) {
return false
}
@@ -47,12 +59,41 @@
left.metered, left.roaming, left.defaultNetwork, i)
if (j == -1) return false
rightTrimmedEmpty.getValues(j, right)
- if (left != right) return false
+ if (SdkLevel.isAtLeastT()) {
+ if (left != right) return false
+ } else {
+ if (!checkEntryEquals(left, right)) return false
+ }
}
return true
}
/**
+ * Assert that the two {@link NetworkStats.Entry} are equal.
+ */
+fun assertEntryEquals(left: NetworkStats.Entry, right: NetworkStats.Entry) {
+ assertTrue(checkEntryEquals(left, right))
+}
+
+// TODO: Make all callers use NetworkStats.Entry#equals once S- downstreams
+// are no longer supported. This helper is needed because NetworkStats is only
+// mainlined on T+, and NetworkStats.Entry#equals in S- does not support a null iface.
+fun checkEntryEquals(left: NetworkStats.Entry, right: NetworkStats.Entry): Boolean {
+ return TextUtils.equals(left.iface, right.iface) &&
+ left.uid == right.uid &&
+ left.set == right.set &&
+ left.tag == right.tag &&
+ left.metered == right.metered &&
+ left.roaming == right.roaming &&
+ left.defaultNetwork == right.defaultNetwork &&
+ left.rxBytes == right.rxBytes &&
+ left.rxPackets == right.rxPackets &&
+ left.txBytes == right.txBytes &&
+ left.txPackets == right.txPackets &&
+ left.operations == right.operations
+}
+
+/**
* Assert that two {@link NetworkStats} are equals, assuming the order of the records are not
* necessarily the same.
*
@@ -66,7 +107,7 @@
compareTime: Boolean = false
) {
assertTrue(orderInsensitiveEquals(expected, actual, compareTime),
- "expected: " + expected + " but was: " + actual)
+ "expected: $expected but was: $actual")
}
/**
@@ -74,5 +115,34 @@
* object.
*/
fun assertParcelingIsLossless(stats: NetworkStats) {
- assertParcelingIsLossless(stats, { a, b -> orderInsensitiveEquals(a, b) })
+ assertParcelingIsLossless(stats) { a, b -> orderInsensitiveEquals(a, b) }
+}
+
+/**
+ * Make a {@link android.app.usage.NetworkStats} instance from
+ * a {@link android.net.NetworkStats} instance.
+ */
+// It's not possible to directly create a mocked `NetworkStats` instance
+// because of limitations with `NetworkStats#getNextBucket`.
+// As a workaround for testing, create a mock by controlling the return values
+// from the mocked service that provides the `NetworkStats` data.
+// Notes:
+// 1. The order of records in the final `NetworkStats` object might change or
+// some records might be merged if there are items with duplicate keys.
+// 2. The interface and operations fields will be empty since there are
+// no such fields in {@link android.app.usage.NetworkStats}.
+fun makePublicStatsFromAndroidNetStats(androidNetStats: NetworkStats):
+ android.app.usage.NetworkStats {
+ val mockService = Mockito.mock(INetworkStatsService::class.java)
+ val manager = NetworkStatsManager(Mockito.mock(Context::class.java), mockService)
+ val mockStatsSession = Mockito.mock(INetworkStatsSession::class.java)
+
+ Mockito.doReturn(mockStatsSession).`when`(mockService)
+ .openSessionForUsageStats(anyInt(), any())
+ Mockito.doReturn(androidNetStats).`when`(mockStatsSession).getSummaryForAllUid(
+ any(NetworkTemplate::class.java), anyLong(), anyLong(), anyBoolean())
+ return manager.querySummary(
+ Mockito.mock(NetworkTemplate::class.java),
+ Long.MIN_VALUE, Long.MAX_VALUE
+ )
}
diff --git a/tests/common/java/android/net/netstats/NetworkStatsApiTest.kt b/tests/common/java/android/net/netstats/NetworkStatsApiTest.kt
index 8cef6aa..17f5e96 100644
--- a/tests/common/java/android/net/netstats/NetworkStatsApiTest.kt
+++ b/tests/common/java/android/net/netstats/NetworkStatsApiTest.kt
@@ -28,18 +28,26 @@
import android.net.NetworkStats.SET_DEFAULT
import android.net.NetworkStats.SET_FOREGROUND
import android.net.NetworkStats.TAG_NONE
+import android.os.Build
import androidx.test.filters.SmallTest
+import com.android.testutils.ConnectivityModuleTest
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
+import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.assertNetworkStatsEquals
import com.android.testutils.assertParcelingIsLossless
import kotlin.test.assertEquals
import org.junit.Before
+import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
-import org.junit.runners.JUnit4
-@RunWith(JUnit4::class)
+@DevSdkIgnoreRunner.MonitorThreadLeak
+@RunWith(DevSdkIgnoreRunner::class)
@SmallTest
class NetworkStatsApiTest {
+ @get:Rule
+ val ignoreRule = DevSdkIgnoreRule()
private val testStatsEmpty = NetworkStats(0L, 0)
// Note that these variables need to be initialized outside of constructor, initialize
@@ -49,6 +57,7 @@
// be merged if performing add on these 2 stats.
private lateinit var testStats1: NetworkStats
private lateinit var testStats2: NetworkStats
+ private lateinit var expectedEntriesInStats2: List<Entry>
// This is a result of adding stats1 and stats2, while the merging of common key items is
// subject to test later, this should not be initialized with for a loop to add stats1
@@ -84,19 +93,23 @@
METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 3, 1, 6, 2, 0))
assertEquals(8, testStats1.size())
- testStats2 = NetworkStats(0L, 0)
- // Entries which are common for set1 and set2.
- .addEntry(Entry(TEST_IFACE, TEST_UID1, SET_DEFAULT, 0x80,
- METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 3, 15, 2, 31, 1))
- .addEntry(Entry(TEST_IFACE, TEST_UID1, SET_FOREGROUND, TAG_NONE,
- METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 13, 61, 10, 1, 45))
- .addEntry(Entry(TEST_IFACE, TEST_UID2, SET_DEFAULT, TAG_NONE,
- METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 11, 2, 3, 4, 7))
- .addEntry(Entry(IFACE_VT, TEST_UID1, SET_DEFAULT, TAG_NONE,
- METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 4, 3, 2, 1, 0))
- // Entry which only appears in set2.
- .addEntry(Entry(IFACE_VT, TEST_UID2, SET_DEFAULT, TAG_NONE,
- METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 2, 3, 7, 8, 0))
+ expectedEntriesInStats2 = listOf(
+ // Entries which are common for set1 and set2.
+ Entry(TEST_IFACE, TEST_UID1, SET_DEFAULT, 0x80,
+ METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 3, 15, 2, 31, 1),
+ Entry(TEST_IFACE, TEST_UID1, SET_FOREGROUND, TAG_NONE,
+ METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 13, 61, 10, 1, 45),
+ Entry(TEST_IFACE, TEST_UID2, SET_DEFAULT, TAG_NONE,
+ METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 11, 2, 3, 4, 7),
+ Entry(IFACE_VT, TEST_UID1, SET_DEFAULT, TAG_NONE,
+ METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 4, 3, 2, 1, 0),
+ // Entry which only appears in set2.
+ Entry(IFACE_VT, TEST_UID2, SET_DEFAULT, TAG_NONE,
+ METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 2, 3, 7, 8, 0))
+ testStats2 = NetworkStats(0L, 5)
+ for (entry in expectedEntriesInStats2) {
+ testStats2 = testStats2.addEntry(entry)
+ }
assertEquals(5, testStats2.size())
testStats3 = NetworkStats(0L, 9)
@@ -125,18 +138,6 @@
@Test
fun testAddEntry() {
- val expectedEntriesInStats2 = arrayOf(
- Entry(TEST_IFACE, TEST_UID1, SET_DEFAULT, 0x80,
- METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 3, 15, 2, 31, 1),
- Entry(TEST_IFACE, TEST_UID1, SET_FOREGROUND, TAG_NONE,
- METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 13, 61, 10, 1, 45),
- Entry(TEST_IFACE, TEST_UID2, SET_DEFAULT, TAG_NONE,
- METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 11, 2, 3, 4, 7),
- Entry(IFACE_VT, TEST_UID1, SET_DEFAULT, TAG_NONE,
- METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 4, 3, 2, 1, 0),
- Entry(IFACE_VT, TEST_UID2, SET_DEFAULT, TAG_NONE,
- METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 2, 3, 7, 8, 0))
-
// While testStats* are already initialized with addEntry, verify content added
// matches expectation.
for (i in expectedEntriesInStats2.indices) {
@@ -150,6 +151,27 @@
assertEquals(Entry(IFACE_VT, TEST_UID1, SET_DEFAULT, TAG_NONE,
METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 16, -2, 9, 1, 9),
stats.getValues(3, null))
+
+ // Verify the original stats object is not altered.
+ for (i in expectedEntriesInStats2.indices) {
+ val entry = testStats2.getValues(i, null)
+ assertEquals(expectedEntriesInStats2[i], entry)
+ }
+ }
+
+ @ConnectivityModuleTest
+ @IgnoreUpTo(Build.VERSION_CODES.S_V2) // Mainlined NetworkStats only runs on T+
+ @Test
+ fun testAddEntries() {
+ val baseStats = NetworkStats(0L, 1)
+ .addEntry(Entry(IFACE_VT, TEST_UID1, SET_DEFAULT, TAG_NONE,
+ METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 12, -5, 7, 0, 9))
+ val statsUnderTest = baseStats.addEntries(expectedEntriesInStats2)
+ // Assume the correctness of addEntry is verified in other tests.
+ val expectedStats = testStats2
+ .addEntry(Entry(IFACE_VT, TEST_UID1, SET_DEFAULT, TAG_NONE,
+ METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 12, -5, 7, 0, 9))
+ assertNetworkStatsEquals(expectedStats, statsUnderTest)
}
@Test
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index 9458460..88c2d5a 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -206,6 +206,7 @@
import com.android.networkstack.apishim.common.ConnectivityManagerShim;
import com.android.testutils.AutoReleaseNetworkCallbackRule;
import com.android.testutils.CompatUtil;
+import com.android.testutils.ConnectUtil;
import com.android.testutils.ConnectivityModuleTest;
import com.android.testutils.DevSdkIgnoreRule;
import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter;
@@ -2942,20 +2943,10 @@
// This may also apply to wifi in principle, but in practice methods that mock validation
// URL all disconnect wifi forcefully anyway, so don't wait for wifi to validate.
if (mPackageManager.hasSystemFeature(FEATURE_TELEPHONY)) {
- ensureValidatedNetwork(makeCellNetworkRequest());
+ new ConnectUtil(mContext).ensureCellularValidated();
}
}
- private void ensureValidatedNetwork(NetworkRequest request) {
- final TestableNetworkCallback cb = new TestableNetworkCallback();
- mCm.registerNetworkCallback(request, cb);
- cb.eventuallyExpect(CallbackEntry.NETWORK_CAPS_UPDATED,
- NETWORK_CALLBACK_TIMEOUT_MS,
- entry -> ((CallbackEntry.CapabilitiesChanged) entry).getCaps()
- .hasCapability(NET_CAPABILITY_VALIDATED));
- mCm.unregisterNetworkCallback(cb);
- }
-
@AppModeFull(reason = "WRITE_DEVICE_CONFIG permission can't be granted to instant apps")
@Test
public void testAcceptPartialConnectivity_validatedNetwork() throws Exception {
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 5c1099d..c71d925 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -130,6 +130,7 @@
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
+import kotlin.test.assertNotEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.fail
@@ -142,10 +143,10 @@
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
-import kotlin.test.assertNotEquals
private const val TAG = "NsdManagerTest"
private const val TIMEOUT_MS = 2000L
+
// Registration may take a long time if there are devices with the same hostname on the network,
// as the device needs to try another name and probe again. This is especially true since when using
// mdnsresponder the usual hostname is "Android", and on conflict "Android-2", "Android-3", ... are
@@ -187,11 +188,12 @@
private val customHostname = "NsdTestHost%09d".format(Random().nextInt(1_000_000_000))
private val customHostname2 = "NsdTestHost%09d".format(Random().nextInt(1_000_000_000))
private val publicKey = hexStringToByteArray(
- "0201030dc141d0637960b98cbc12cfca"
- + "221d2879dac26ee5b460e9007c992e19"
- + "02d897c391b03764d448f7d0c772fdb0"
- + "3b1d9d6d52ff8886769e8e2362513565"
- + "270962d3")
+ "0201030dc141d0637960b98cbc12cfca" +
+ "221d2879dac26ee5b460e9007c992e19" +
+ "02d897c391b03764d448f7d0c772fdb0" +
+ "3b1d9d6d52ff8886769e8e2362513565" +
+ "270962d3"
+ )
private val handlerThread = HandlerThread(NsdManagerTest::class.java.simpleName)
private val ctsNetUtils by lazy{ CtsNetUtils(context) }
@@ -243,11 +245,14 @@
val iface = tnm.createTapInterface()
val cb = TestableNetworkCallback()
val testNetworkSpecifier = TestNetworkSpecifier(iface.interfaceName)
- cm.requestNetwork(NetworkRequest.Builder()
+ cm.requestNetwork(
+ NetworkRequest.Builder()
.removeCapability(NET_CAPABILITY_TRUSTED)
.addTransportType(TRANSPORT_TEST)
.setNetworkSpecifier(testNetworkSpecifier)
- .build(), cb)
+ .build(),
+ cb
+ )
val agent = registerTestNetworkAgent(iface.interfaceName)
val network = agent.network ?: fail("Registered agent should have a network")
@@ -267,12 +272,17 @@
val lp = LinkProperties().apply {
interfaceName = ifaceName
}
- val agent = TestableNetworkAgent(context, handlerThread.looper,
+ val agent = TestableNetworkAgent(
+ context,
+ handlerThread.looper,
NetworkCapabilities().apply {
removeCapability(NET_CAPABILITY_TRUSTED)
addTransportType(TRANSPORT_TEST)
setNetworkSpecifier(TestNetworkSpecifier(ifaceName))
- }, lp, NetworkAgentConfig.Builder().build())
+ },
+ lp,
+ NetworkAgentConfig.Builder().build()
+ )
val network = agent.register()
agent.markConnected()
agent.expectCallback<OnNetworkCreated>()
@@ -346,15 +356,19 @@
Triple(null, null, "null key"),
Triple("", null, "empty key"),
Triple(string256, null, "key with 256 characters"),
- Triple("key", string256.substring(3),
- "key+value combination with more than 255 characters"),
+ Triple(
+ "key",
+ string256.substring(3),
+ "key+value combination with more than 255 characters"
+ ),
Triple("key", string256.substring(4), "key+value combination with 255 characters"),
Triple("\u0019", null, "key with invalid character"),
Triple("=", null, "key with invalid character"),
Triple("\u007f", null, "key with invalid character")
).forEach {
assertFailsWith<IllegalArgumentException>(
- "Setting invalid ${it.third} unexpectedly succeeded") {
+ "Setting invalid ${it.third} unexpectedly succeeded"
+ ) {
si.setAttribute(it.first, it.second)
}
}
@@ -377,7 +391,8 @@
// Test registering without an Executor
nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, registrationRecord)
val registeredInfo = registrationRecord.expectCallback<ServiceRegistered>(
- REGISTRATION_TIMEOUT_MS).serviceInfo
+ REGISTRATION_TIMEOUT_MS
+ ).serviceInfo
// Only service name is included in ServiceRegistered callbacks
assertNull(registeredInfo.serviceType)
@@ -392,7 +407,9 @@
// Expect a service record to be discovered
val foundInfo = discoveryRecord.waitForServiceDiscovered(
- registeredInfo.serviceName, serviceType)
+ registeredInfo.serviceName,
+ serviceType
+ )
// Test resolving without an Executor
val resolveRecord = NsdResolveRecord()
@@ -407,8 +424,10 @@
assertNull(resolvedService.attributes["booleanAttr"])
assertEquals("value", resolvedService.attributes["keyValueAttr"].utf8ToString())
assertEquals("=", resolvedService.attributes["keyEqualsAttr"].utf8ToString())
- assertEquals(" value ",
- resolvedService.attributes[" whiteSpaceKeyValueAttr "].utf8ToString())
+ assertEquals(
+ " value ",
+ resolvedService.attributes[" whiteSpaceKeyValueAttr "].utf8ToString()
+ )
assertEquals(string256.substring(9), resolvedService.attributes["longkey"].utf8ToString())
assertArrayEquals(testByteArray, resolvedService.attributes["binaryDataAttr"])
assertTrue(resolvedService.attributes.containsKey("nullBinaryDataAttr"))
@@ -442,12 +461,15 @@
val registrationRecord2 = NsdRegistrationRecord()
nsdManager.registerService(si2, NsdManager.PROTOCOL_DNS_SD, registrationRecord2)
val registeredInfo2 = registrationRecord2.expectCallback<ServiceRegistered>(
- REGISTRATION_TIMEOUT_MS).serviceInfo
+ REGISTRATION_TIMEOUT_MS
+ ).serviceInfo
// Expect a service record to be discovered (and filter the ones
// that are unrelated to this test)
val foundInfo2 = discoveryRecord.waitForServiceDiscovered(
- registeredInfo2.serviceName, serviceType)
+ registeredInfo2.serviceName,
+ serviceType
+ )
// Resolve the service
val resolveRecord2 = NsdResolveRecord()
@@ -950,7 +972,8 @@
// when the compat change is disabled.
// Note that before T the compat constant had a different int value.
assertFalse(CompatChanges.isChangeEnabled(
- ConnectivityCompatChanges.RUN_NATIVE_NSD_ONLY_IF_LEGACY_APPS_T_AND_LATER))
+ ConnectivityCompatChanges.RUN_NATIVE_NSD_ONLY_IF_LEGACY_APPS_T_AND_LATER
+ ))
}
@Test
@@ -1023,7 +1046,9 @@
// onStopResolutionFailed on the record directly then verify it is received.
val resolveRecord = NsdResolveRecord()
resolveRecord.onStopResolutionFailed(
- NsdServiceInfo(), NsdManager.FAILURE_OPERATION_NOT_RUNNING)
+ NsdServiceInfo(),
+ NsdManager.FAILURE_OPERATION_NOT_RUNNING
+ )
val failedCb = resolveRecord.expectCallback<StopResolutionFailed>()
assertEquals(NsdManager.FAILURE_OPERATION_NOT_RUNNING, failedCb.errorCode)
}
@@ -1274,15 +1299,22 @@
val si = makeTestServiceInfo(testNetwork1.network)
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
// Register service on testNetwork1
val registrationRecord = NsdRegistrationRecord()
- nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() },
- registrationRecord)
+ nsdManager.registerService(
+ si,
+ NsdManager.PROTOCOL_DNS_SD,
+ { it.run() },
+ registrationRecord
+ )
tryTest {
assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
@@ -1313,15 +1345,22 @@
parseNumericAddress("2001:db8::3"))
}
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
// Register service on testNetwork1
val registrationRecord = NsdRegistrationRecord()
- nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() },
- registrationRecord)
+ nsdManager.registerService(
+ si,
+ NsdManager.PROTOCOL_DNS_SD,
+ { it.run() },
+ registrationRecord
+ )
tryTest {
assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
@@ -1352,15 +1391,22 @@
hostname = customHostname
}
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
// Register service on testNetwork1
val registrationRecord = NsdRegistrationRecord()
- nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() },
- registrationRecord)
+ nsdManager.registerService(
+ si,
+ NsdManager.PROTOCOL_DNS_SD,
+ { it.run() },
+ registrationRecord
+ )
tryTest {
assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
@@ -1392,8 +1438,11 @@
val registrationRecord = NsdRegistrationRecord()
val discoveryRecord = NsdDiscoveryRecord()
val registeredService = registerService(registrationRecord, si)
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
@@ -1471,7 +1520,9 @@
val registeredService = registerService(registrationRecord, si)
val packetReader = TapPacketReader(
Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
@@ -1501,11 +1552,11 @@
val newRegistration =
registrationRecord
.expectCallbackEventually<ServiceRegistered>(REGISTRATION_TIMEOUT_MS) {
- it.serviceInfo.serviceName == serviceName
- && it.serviceInfo.hostname.let { hostname ->
- hostname != null
- && hostname.startsWith(customHostname)
- && hostname != customHostname
+ it.serviceInfo.serviceName == serviceName &&
+ it.serviceInfo.hostname.let { hostname ->
+ hostname != null &&
+ hostname.startsWith(customHostname) &&
+ hostname != customHostname
}
}
@@ -1536,8 +1587,11 @@
val registrationRecord = NsdRegistrationRecord()
val discoveryRecord = NsdDiscoveryRecord()
val registeredService = registerService(registrationRecord, si)
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
@@ -1576,13 +1630,21 @@
fun testDiscoveryWithPtrOnlyResponse_ServiceIsFound() {
// Register service on testNetwork1
val discoveryRecord = NsdDiscoveryRecord()
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
- nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
- testNetwork1.network, { it.run() }, discoveryRecord)
+ nsdManager.discoverServices(
+ serviceType,
+ NsdManager.PROTOCOL_DNS_SD,
+ testNetwork1.network,
+ { it.run() },
+ discoveryRecord
+ )
tryTest {
discoveryRecord.expectCallback<DiscoveryStarted>()
@@ -1626,8 +1688,10 @@
fun testResolveWhenServerSendsNoAdditionalRecord() {
// Resolve service on testNetwork1
val resolveRecord = NsdResolveRecord()
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
)
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
@@ -1651,16 +1715,21 @@
rdata='testkey=testvalue')
))).hex()
*/
- val srvTxtResponsePayload = HexDump.hexStringToByteArray("000084000000000200000000104" +
+ val srvTxtResponsePayload = HexDump.hexStringToByteArray(
+ "000084000000000200000000104" +
"e7364546573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f6" +
"3616c0000218001000000780011000000007a020874657374686f7374c030c00c00100001000" +
- "00078001211746573746b65793d7465737476616c7565")
+ "00078001211746573746b65793d7465737476616c7565"
+ )
replaceServiceNameAndTypeWithTestSuffix(srvTxtResponsePayload)
packetReader.sendResponse(buildMdnsPacket(srvTxtResponsePayload))
val testHostname = "testhost.local"
- val addressQuery = packetReader.pollForQuery(testHostname,
- DnsResolver.TYPE_A, DnsResolver.TYPE_AAAA)
+ val addressQuery = packetReader.pollForQuery(
+ testHostname,
+ DnsResolver.TYPE_A,
+ DnsResolver.TYPE_AAAA
+ )
assertNotNull(addressQuery)
/*
@@ -1672,9 +1741,11 @@
rdata='2001:db8::123')
))).hex()
*/
- val addressPayload = HexDump.hexStringToByteArray("0000840000000002000000000874657374" +
+ val addressPayload = HexDump.hexStringToByteArray(
+ "0000840000000002000000000874657374" +
"686f7374056c6f63616c0000010001000000780004c000027bc00c001c000100000078001020" +
- "010db8000000000000000000000123")
+ "010db8000000000000000000000123"
+ )
packetReader.sendResponse(buildMdnsPacket(addressPayload))
val serviceResolved = resolveRecord.expectCallback<ServiceResolved>()
@@ -1688,7 +1759,8 @@
}
assertEquals(
setOf(parseNumericAddress("192.0.2.123"), parseNumericAddress("2001:db8::123")),
- serviceResolved.serviceInfo.hostAddresses.toSet())
+ serviceResolved.serviceInfo.hostAddresses.toSet()
+ )
}
@Test
@@ -1715,7 +1787,7 @@
scapy.raw(scapy.DNS(rd=0, qr=0, aa=0, qd =
scapy.DNSQR(qname='_nmt123456789._tcp.local', qtype='PTR', qclass=0x8001)
)).hex()
- */
+ */
val mdnsPayload = HexDump.hexStringToByteArray("0000000000010000000000000d5f6e6d74313" +
"233343536373839045f746370056c6f63616c00000c8001")
replaceServiceNameAndTypeWithTestSuffix(mdnsPayload)
@@ -1771,7 +1843,7 @@
an = scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=4500,
rdata='NsdTest123456789._nmt123456789._tcp.local')
)).hex()
- */
+ */
val query = HexDump.hexStringToByteArray("0000000000020001000000000d5f6e6d74313233343" +
"536373839045f746370056c6f63616c00000c8001104e7364546573743132333435363738390" +
"d5f6e6d74313233343536373839045f746370056c6f63616c00001080010d5f6e6d743132333" +
@@ -1806,7 +1878,7 @@
an = scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=2150,
rdata='NsdTest123456789._nmt123456789._tcp.local')
)).hex()
- */
+ */
val query2 = HexDump.hexStringToByteArray("0000000000020001000000000d5f6e6d7431323334" +
"3536373839045f746370056c6f63616c00000c8001104e736454657374313233343536373839" +
"0d5f6e6d74313233343536373839045f746370056c6f63616c00001080010d5f6e6d74313233" +
@@ -1858,7 +1930,7 @@
scapy.DNSQR(qname='NsdTest123456789._nmt123456789._tcp.local', qtype='TXT',
qclass=0x8001)
)).hex()
- */
+ */
val query = HexDump.hexStringToByteArray("0000020000020000000000000d5f6e6d74313233343" +
"536373839045f746370056c6f63616c00000c8001104e7364546573743132333435363738390" +
"d5f6e6d74313233343536373839045f746370056c6f63616c0000108001")
@@ -1870,7 +1942,7 @@
an = scapy.DNSRR(rrname='_test._tcp.local', type='PTR', ttl=4500,
rdata='NsdTest._test._tcp.local')
)).hex()
- */
+ */
val knownAnswer1 = HexDump.hexStringToByteArray("000002000000000100000000055f74657374" +
"045f746370056c6f63616c00000c000100001194001a074e736454657374055f74657374045f" +
"746370056c6f63616c00")
@@ -1882,7 +1954,7 @@
an = scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=4500,
rdata='NsdTest123456789._nmt123456789._tcp.local')
)).hex()
- */
+ */
val knownAnswer2 = HexDump.hexStringToByteArray("0000000000000001000000000d5f6e6d7431" +
"3233343536373839045f746370056c6f63616c00000c000100001194002b104e736454657374" +
"3132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00")
@@ -1919,13 +1991,21 @@
// Register service on testNetwork1
val discoveryRecord = NsdDiscoveryRecord()
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
- nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
- testNetwork1.network, { it.run() }, discoveryRecord)
+ nsdManager.discoverServices(
+ serviceType,
+ NsdManager.PROTOCOL_DNS_SD,
+ testNetwork1.network,
+ { it.run() },
+ discoveryRecord
+ )
tryTest {
discoveryRecord.expectCallback<DiscoveryStarted>()
@@ -1987,10 +2067,12 @@
val hostAddresses1 = listOf(
parseNumericAddress("192.0.2.23"),
parseNumericAddress("2001:db8::1"),
- parseNumericAddress("2001:db8::2"))
+ parseNumericAddress("2001:db8::2")
+ )
val hostAddresses2 = listOf(
parseNumericAddress("192.0.2.24"),
- parseNumericAddress("2001:db8::3"))
+ parseNumericAddress("2001:db8::3")
+ )
val si1 = NsdServiceInfo().also {
it.network = testNetwork1.network
it.serviceName = serviceName
@@ -2054,10 +2136,12 @@
val hostAddresses1 = listOf(
parseNumericAddress("192.0.2.23"),
parseNumericAddress("2001:db8::1"),
- parseNumericAddress("2001:db8::2"))
+ parseNumericAddress("2001:db8::2")
+ )
val hostAddresses2 = listOf(
parseNumericAddress("192.0.2.24"),
- parseNumericAddress("2001:db8::3"))
+ parseNumericAddress("2001:db8::3")
+ )
val si1 = NsdServiceInfo().also {
it.network = testNetwork1.network
it.hostname = customHostname
@@ -2105,7 +2189,8 @@
val hostAddresses = listOf(
parseNumericAddress("192.0.2.23"),
parseNumericAddress("2001:db8::1"),
- parseNumericAddress("2001:db8::2"))
+ parseNumericAddress("2001:db8::2")
+ )
val si1 = NsdServiceInfo().also {
it.network = testNetwork1.network
it.serviceType = serviceType
@@ -2270,8 +2355,11 @@
it.port = TEST_PORT
it.publicKey = publicKey
}
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
@@ -2322,8 +2410,11 @@
parseNumericAddress("2001:db8::2"))
it.publicKey = publicKey
}
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
@@ -2376,8 +2467,11 @@
it.hostAddresses = listOf()
it.publicKey = publicKey
}
- val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ val packetReader = TapPacketReader(
+ Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor,
+ 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
@@ -2430,10 +2524,20 @@
val discoveryRecord1 = NsdDiscoveryRecord()
val discoveryRecord2 = NsdDiscoveryRecord()
val discoveryRecord3 = NsdDiscoveryRecord()
- nsdManager.discoverServices("_test1._tcp", NsdManager.PROTOCOL_DNS_SD,
- testNetwork1.network, { it.run() }, discoveryRecord1)
- nsdManager.discoverServices("_test2._tcp", NsdManager.PROTOCOL_DNS_SD,
- testNetwork1.network, { it.run() }, discoveryRecord2)
+ nsdManager.discoverServices(
+ "_test1._tcp",
+ NsdManager.PROTOCOL_DNS_SD,
+ testNetwork1.network,
+ { it.run() },
+ discoveryRecord1
+ )
+ nsdManager.discoverServices(
+ "_test2._tcp",
+ NsdManager.PROTOCOL_DNS_SD,
+ testNetwork1.network,
+ { it.run() },
+ discoveryRecord2
+ )
nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryRecord3)
tryTest {
@@ -2527,11 +2631,16 @@
*/
private fun getServiceTypeClients(): List<String> {
return SystemUtil.runShellCommand(
- InstrumentationRegistry.getInstrumentation(), "dumpsys servicediscovery")
+ InstrumentationRegistry.getInstrumentation(),
+ "dumpsys servicediscovery"
+ )
.split("\n").mapNotNull { line ->
line.indexOf("ServiceTypeClient:").let { idx ->
- if (idx == -1) null
- else line.substring(idx)
+ if (idx == -1) {
+ null
+ } else {
+ line.substring(idx)
+ }
}
}
}
@@ -2544,9 +2653,11 @@
rclass=0x8001, port=31234, target='conflict.local', ttl=120)
)).hex()
*/
- val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000104e736454657" +
+ val mdnsPayload = HexDump.hexStringToByteArray(
+ "000084000000000100000000104e736454657" +
"3743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00002" +
- "18001000000780016000000007a0208636f6e666c696374056c6f63616c00")
+ "18001000000780016000000007a0208636f6e666c696374056c6f63616c00"
+ )
replaceServiceNameAndTypeWithTestSuffix(mdnsPayload)
return buildMdnsPacket(mdnsPayload)
@@ -2560,9 +2671,11 @@
rdata='2001:db8::321')
)).hex()
*/
- val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000144e7364" +
+ val mdnsPayload = HexDump.hexStringToByteArray(
+ "000084000000000100000000144e7364" +
"54657374486f7374313233343536373839056c6f63616c00001c000100000078001020010db80000" +
- "00000000000000000321")
+ "00000000000000000321"
+ )
replaceCustomHostnameWithTestSuffix(mdnsPayload)
return buildMdnsPacket(mdnsPayload)
@@ -2613,22 +2726,29 @@
mdnsPayload: ByteArray,
srcAddr: Inet6Address = testSrcAddr
): ByteBuffer {
- val packetBuffer = PacketBuilder.allocate(true /* hasEther */, IPPROTO_IPV6,
- IPPROTO_UDP, mdnsPayload.size)
+ val packetBuffer = PacketBuilder.allocate(
+ true /* hasEther */,
+ IPPROTO_IPV6,
+ IPPROTO_UDP,
+ mdnsPayload.size
+ )
val packetBuilder = PacketBuilder(packetBuffer)
// Multicast ethernet address for IPv6 to ff02::fb
val multicastEthAddr = MacAddress.fromBytes(
- byteArrayOf(0x33, 0x33, 0, 0, 0, 0xfb.toByte()))
+ byteArrayOf(0x33, 0x33, 0, 0, 0, 0xfb.toByte())
+ )
packetBuilder.writeL2Header(
MacAddress.fromBytes(byteArrayOf(1, 2, 3, 4, 5, 6)) /* srcMac */,
multicastEthAddr,
- ETH_P_IPV6.toShort())
+ ETH_P_IPV6.toShort()
+ )
packetBuilder.writeIpv6Header(
0x60000000, // version=6, traffic class=0x0, flowlabel=0x0
IPPROTO_UDP.toByte(),
64 /* hop limit */,
srcAddr,
- multicastIpv6Addr /* dstIp */)
+ multicastIpv6Addr /* dstIp */
+ )
packetBuilder.writeUdpHeader(MDNS_PORT /* srcPort */, MDNS_PORT /* dstPort */)
packetBuffer.put(mdnsPayload)
return packetBuilder.finalizePacket()
diff --git a/tests/unit/java/com/android/server/CallbackQueueTest.kt b/tests/unit/java/com/android/server/CallbackQueueTest.kt
index a6dd5c3..d8d35c1 100644
--- a/tests/unit/java/com/android/server/CallbackQueueTest.kt
+++ b/tests/unit/java/com/android/server/CallbackQueueTest.kt
@@ -116,7 +116,7 @@
addCallback(TEST_NETID_1, CALLBACK_AVAILABLE)
addCallback(TEST_NETID_2, CALLBACK_AVAILABLE)
}
- val queue2 = CallbackQueue(queue1.shrinkedBackingArray)
+ val queue2 = CallbackQueue(queue1.minimizedBackingArray)
assertQueueEquals(listOf(
TEST_NETID_1 to CALLBACK_AVAILABLE,
TEST_NETID_2 to CALLBACK_AVAILABLE
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 8c44abd..9674da3 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -123,18 +123,18 @@
}
private val TEST_PUBLIC_KEY = hexStringToByteArray(
- "0201030dc141d0637960b98cbc12cfca"
- + "221d2879dac26ee5b460e9007c992e19"
- + "02d897c391b03764d448f7d0c772fdb0"
- + "3b1d9d6d52ff8886769e8e2362513565"
- + "270962d3")
+ "0201030dc141d0637960b98cbc12cfca" +
+ "221d2879dac26ee5b460e9007c992e19" +
+ "02d897c391b03764d448f7d0c772fdb0" +
+ "3b1d9d6d52ff8886769e8e2362513565" +
+ "270962d3")
private val TEST_PUBLIC_KEY_2 = hexStringToByteArray(
- "0201030dc141d0637960b98cbc12cfca"
- + "221d2879dac26ee5b460e9007c992e19"
- + "02d897c391b03764d448f7d0c772fdb0"
- + "3b1d9d6d52ff8886769e8e2362513565"
- + "270962d4")
+ "0201030dc141d0637960b98cbc12cfca" +
+ "221d2879dac26ee5b460e9007c992e19" +
+ "02d897c391b03764d448f7d0c772fdb0" +
+ "3b1d9d6d52ff8886769e8e2362513565" +
+ "270962d4")
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
@@ -159,7 +159,7 @@
@Before
fun setUp() {
- deps.resetElapsedRealTime();
+ deps.resetElapsedRealTime()
thread.start()
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
index b040ab6..0a8f108 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
@@ -154,6 +154,11 @@
serviceCache.registerServiceExpiredCallback(cacheKey, callback)
}
+ private fun removeServices(
+ serviceCache: MdnsServiceCache,
+ cacheKey: CacheKey
+ ): Unit = runningOnHandlerAndReturn { serviceCache.removeServices(cacheKey) }
+
@Test
fun testAddAndRemoveService() {
val serviceCache = MdnsServiceCache(thread.looper, makeFlags(), clock)
@@ -291,6 +296,37 @@
assertEquals(response4, responses[3])
}
+    @Test
+    fun testRemoveServices() { // removeServices(key) must only drop the responses of that key
+        val serviceCache = MdnsServiceCache(thread.looper, makeFlags(), clock)
+        addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+        addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
+        addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_1, SERVICE_TYPE_2))
+        val responses1 = getServices(serviceCache, cacheKey1) // two services under cacheKey1
+        assertEquals(2, responses1.size)
+        assertTrue(responses1.any { response ->
+            response.serviceInstanceName == SERVICE_NAME_1
+        })
+        assertTrue(responses1.any { response ->
+            response.serviceInstanceName == SERVICE_NAME_2
+        })
+        val responses2 = getServices(serviceCache, cacheKey2) // one service under cacheKey2
+        assertEquals(1, responses2.size)
+        assertTrue(responses2.any { response ->
+            response.serviceInstanceName == SERVICE_NAME_1
+        })
+
+        removeServices(serviceCache, cacheKey1) // cacheKey2 entries must be left untouched
+        val responses3 = getServices(serviceCache, cacheKey1)
+        assertEquals(0, responses3.size)
+        val responses4 = getServices(serviceCache, cacheKey2)
+        assertEquals(1, responses4.size)
+
+        removeServices(serviceCache, cacheKey2) // now both keys are empty
+        val responses5 = getServices(serviceCache, cacheKey2)
+        assertEquals(0, responses5.size)
+    }
+
private fun createResponse(
serviceInstanceName: String,
serviceType: String,
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
index 93f6e81..77b06b2 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
@@ -26,7 +26,9 @@
import android.net.NetworkCapabilities.TRANSPORT_VPN
import android.net.NetworkCapabilities.TRANSPORT_WIFI
import android.net.NetworkRequest
+import android.net.VpnManager.TYPE_VPN_OEM
import android.net.VpnManager.TYPE_VPN_SERVICE
+import android.net.VpnManager.TYPE_VPN_LEGACY
import android.net.VpnTransportInfo
import android.os.Build
import androidx.test.filters.SmallTest
@@ -49,19 +51,19 @@
private const val TIMEOUT_MS = 1_000L
private const val LONG_TIMEOUT_MS = 5_000
-private fun vpnNc() = NetworkCapabilities.Builder()
- .addTransportType(TRANSPORT_VPN)
- .removeCapability(NET_CAPABILITY_NOT_VPN)
- .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
- .setTransportInfo(
- VpnTransportInfo(
- TYPE_VPN_SERVICE,
- "MySession12345",
- false /* bypassable */,
- false /* longLivedTcpConnectionsExpensive */
- )
- )
- .build()
+private fun vpnNc(vpnType: Int = TYPE_VPN_SERVICE) = NetworkCapabilities.Builder().apply {
+ addTransportType(TRANSPORT_VPN)
+ removeCapability(NET_CAPABILITY_NOT_VPN)
+ addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
+ setTransportInfo(
+ VpnTransportInfo(
+ vpnType,
+ "MySession12345",
+ false /* bypassable */,
+ false /* longLivedTcpConnectionsExpensive */
+ )
+ )
+}.build()
private fun wifiNc() = NetworkCapabilities.Builder()
.addTransportType(TRANSPORT_WIFI)
@@ -310,4 +312,38 @@
// IngressDiscardRule should not be added since feature is disabled
verify(bpfNetMaps, never()).setIngressDiscardRule(any(), any())
}
+
+ fun doTestVpnIngressDiscardRule_VpnType(vpnType: Int, expectAddRule: Boolean) {
+ val nr = nr(TRANSPORT_VPN)
+ val cb = TestableNetworkCallback()
+ cm.registerNetworkCallback(nr, cb)
+ val nc = vpnNc(vpnType)
+ val lp = lp(VPN_IFNAME, IPV6_LINK_ADDRESS, LOCAL_IPV6_LINK_ADDRRESS)
+ val agent = Agent(nc = nc, lp = lp)
+ agent.connect()
+ cb.expectAvailableCallbacks(agent.network, validated = false)
+
+ if (expectAddRule) {
+ verify(bpfNetMaps).setIngressDiscardRule(IPV6_ADDRESS, VPN_IFNAME)
+ } else {
+ verify(bpfNetMaps, never()).setIngressDiscardRule(any(), any())
+ }
+ }
+
+ @Test
+ fun testVpnIngressDiscardRule_ServiceVpn() {
+ doTestVpnIngressDiscardRule_VpnType(TYPE_VPN_SERVICE, expectAddRule = true)
+ }
+
+ @Test
+ fun testVpnIngressDiscardRule_LegacyVpn() {
+ // IngressDiscardRule should not be added to Legacy VPN
+ doTestVpnIngressDiscardRule_VpnType(TYPE_VPN_LEGACY, expectAddRule = false)
+ }
+
+ @Test
+ fun testVpnIngressDiscardRule_OemVpn() {
+ // IngressDiscardRule should not be added to OEM VPN
+ doTestVpnIngressDiscardRule_VpnType(TYPE_VPN_OEM, expectAddRule = false)
+ }
}
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSQueuedCallbacksTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSQueuedCallbacksTest.kt
new file mode 100644
index 0000000..fc2a06c
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivityservice/CSQueuedCallbacksTest.kt
@@ -0,0 +1,636 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivityservice
+
+import android.app.ActivityManager.UidFrozenStateChangedCallback
+import android.app.ActivityManager.UidFrozenStateChangedCallback.UID_FROZEN_STATE_FROZEN
+import android.app.ActivityManager.UidFrozenStateChangedCallback.UID_FROZEN_STATE_UNFROZEN
+import android.net.ConnectivityManager.BLOCKED_METERED_REASON_DATA_SAVER
+import android.net.ConnectivityManager.BLOCKED_REASON_NONE
+import android.net.ConnectivitySettingsManager
+import android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_DNS
+import android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_HTTPS
+import android.net.INetworkMonitor.NETWORK_VALIDATION_RESULT_VALID
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.LocalNetworkConfig
+import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.NET_CAPABILITY_FOREGROUND
+import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
+import android.net.NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_CONGESTED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED
+import android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED
+import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
+import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
+import android.net.NetworkCapabilities.TRANSPORT_ETHERNET
+import android.net.NetworkCapabilities.TRANSPORT_WIFI
+import android.net.NetworkPolicyManager.NetworkPolicyCallback
+import android.net.NetworkRequest
+import android.os.Build
+import android.os.Process
+import com.android.server.CALLING_UID_UNMOCKED
+import com.android.server.CSAgentWrapper
+import com.android.server.CSTest
+import com.android.server.FromS
+import com.android.server.HANDLER_TIMEOUT_MS
+import com.android.server.connectivity.ConnectivityFlags.QUEUE_CALLBACKS_FOR_FROZEN_APPS
+import com.android.server.defaultLp
+import com.android.server.defaultNc
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.LocalInfoChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.Lost
+import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
+import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
+import com.android.testutils.TestableNetworkCallback
+import com.android.testutils.visibleOnHandlerThread
+import com.android.testutils.waitForIdleSerialExecutor
+import java.util.Collections
+import kotlin.test.fail
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.mockito.ArgumentCaptor
+import org.mockito.Mockito.any
+import org.mockito.Mockito.verify
+
+private const val TEST_UID = 42
+
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+@DevSdkIgnoreRunner.MonitorThreadLeak
+@RunWith(DevSdkIgnoreRunner::class)
+class CSQueuedCallbacksTest(freezingBehavior: FreezingBehavior) : CSTest() {
+ companion object {
+ enum class FreezingBehavior {
+ UID_FROZEN,
+ UID_NOT_FROZEN,
+ UID_FROZEN_FEATURE_DISABLED
+ }
+
+ // Use a parameterized test with / without freezing to make it easy to compare and make sure
+ // freezing behavior (which callbacks are sent in which order) stays close to what happens
+ // without freezing.
+ @JvmStatic
+ @Parameterized.Parameters(name = "freezingBehavior={0}")
+ fun freezingBehavior() = listOf(
+ FreezingBehavior.UID_FROZEN,
+ FreezingBehavior.UID_NOT_FROZEN,
+ FreezingBehavior.UID_FROZEN_FEATURE_DISABLED
+ )
+
+ private val TAG = CSQueuedCallbacksTest::class.simpleName
+ ?: fail("Could not get test class name")
+ }
+
+ @get:Rule
+ val ignoreRule = DevSdkIgnoreRule()
+
+ private val mockedBlockedReasonsPerUid = Collections.synchronizedMap(mutableMapOf(
+ Process.myUid() to BLOCKED_REASON_NONE,
+ TEST_UID to BLOCKED_REASON_NONE
+ ))
+
+ private val freezeUids = freezingBehavior != FreezingBehavior.UID_NOT_FROZEN
+ private val expectAllCallbacks = freezingBehavior == FreezingBehavior.UID_NOT_FROZEN ||
+ freezingBehavior == FreezingBehavior.UID_FROZEN_FEATURE_DISABLED
+ init {
+ setFeatureEnabled(
+ QUEUE_CALLBACKS_FOR_FROZEN_APPS,
+ freezingBehavior != FreezingBehavior.UID_FROZEN_FEATURE_DISABLED
+ )
+ }
+
+ @Before
+ fun subclassSetUp() { // NOTE(review): request below is commented out, so this is a no-op
+ // Ensure cellular stays up. CS is recreated for each test so no cleanup is necessary.
+// cm.requestNetwork(
+// NetworkRequest.Builder().addTransportType(TRANSPORT_CELLULAR).build(),
+// TestableNetworkCallback()
+// )
+ }
+
+ @Test
+ fun testFrozenWhileNetworkConnects_UpdatesAreReceived() {
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+ }
+ val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+ val lpChangeOnConnect = agent.sendLpChange { setLinkAddresses("fe80:db8::123/64") }
+ val ncChangeOnConnect = agent.sendNcChange {
+ addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+ }
+
+ maybeSetUidsFrozen(true, TEST_UID)
+
+ val lpChange1WhileFrozen = agent.sendLpChange {
+ setLinkAddresses("fe80:db8::126/64")
+ }
+ val ncChange1WhileFrozen = agent.sendNcChange {
+ removeCapability(NET_CAPABILITY_NOT_ROAMING)
+ }
+ val ncChange2WhileFrozen = agent.sendNcChange {
+ addCapability(NET_CAPABILITY_NOT_ROAMING)
+ addCapability(NET_CAPABILITY_NOT_CONGESTED)
+ }
+ val lpChange2WhileFrozen = agent.sendLpChange {
+ setLinkAddresses("fe80:db8::125/64")
+ }
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ // Verify callbacks that are sent before freezing
+ cb.expectAvailableCallbacks(agent.network, validated = false)
+ cb.expectLpWith(agent, lpChangeOnConnect)
+ cb.expectNcWith(agent, ncChangeOnConnect)
+
+ // Below callbacks should be skipped if the processes were frozen, since a single callback
+ // will be sent with the latest state after unfreezing
+ if (expectAllCallbacks) {
+ cb.expectLpWith(agent, lpChange1WhileFrozen)
+ cb.expectNcWith(agent, ncChange1WhileFrozen)
+ }
+
+ cb.expectNcWith(agent, ncChange2WhileFrozen)
+ cb.expectLpWith(agent, lpChange2WhileFrozen)
+
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ fun testFrozenWhileNetworkConnects_SuspendedUnsuspendedWhileFrozen() {
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+ }
+
+ val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+ maybeSetUidsFrozen(true, TEST_UID)
+ val rmCap = agent.sendNcChange { removeCapability(NET_CAPABILITY_NOT_SUSPENDED) }
+ val addCap = agent.sendNcChange { addCapability(NET_CAPABILITY_NOT_SUSPENDED) }
+
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ cb.expectAvailableCallbacks(agent.network, validated = false)
+ if (expectAllCallbacks) {
+ cb.expectNcWith(agent, rmCap)
+ cb.expect<Suspended>(agent)
+ cb.expectNcWith(agent, addCap)
+ cb.expect<Resumed>(agent)
+ } else {
+ // When frozen, a single NetworkCapabilitiesChange will be sent at unfreezing time,
+ // with nc actually identical to the original ones. This is because NetworkCapabilities
+ // callbacks were sent, but CS does not keep initial NetworkCapabilities in memory, so
+ // it cannot detect A->B->A.
+ cb.expect<CapabilitiesChanged>(agent) {
+ it.caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
+ }
+ }
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ fun testFrozenWhileNetworkConnects_UnsuspendedWhileFrozen_GetResumedCallbackWhenUnfrozen() {
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+ }
+
+ val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+ val rmCap = agent.sendNcChange { removeCapability(NET_CAPABILITY_NOT_SUSPENDED) }
+ maybeSetUidsFrozen(true, TEST_UID)
+ val addCap = agent.sendNcChange { addCapability(NET_CAPABILITY_NOT_SUSPENDED) }
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ cb.expectAvailableCallbacks(agent.network, validated = false)
+ cb.expectNcWith(agent, rmCap)
+ cb.expect<Suspended>(agent)
+ cb.expectNcWith(agent, addCap)
+ cb.expect<Resumed>(agent)
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ fun testFrozenWhileNetworkConnects_BlockedUnblockedWhileFrozen_SingleCallbackIfFrozen() {
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+ }
+ val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+
+ maybeSetUidsFrozen(true, TEST_UID)
+ setUidsBlockedForDataSaver(true, TEST_UID)
+ setUidsBlockedForDataSaver(false, TEST_UID)
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ cb.expectAvailableCallbacks(agent.network, validated = false)
+ if (expectAllCallbacks) {
+ cb.expect<BlockedStatus>(agent) { it.blocked }
+ }
+ // The unblocked callback is sent in any case (with the latest blocked reason), as the
+ // blocked reason may have changed, and ConnectivityService cannot know that it is the same
+ // as the original reason as it does not keep pre-freeze blocked reasons in memory.
+ cb.expect<BlockedStatus>(agent) { !it.blocked }
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ fun testFrozenWhileNetworkConnects_BlockedWhileFrozen_GetLastBlockedCallbackOnlyIfFrozen() {
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+ }
+ val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+
+ maybeSetUidsFrozen(true, TEST_UID)
+ setUidsBlockedForDataSaver(true, TEST_UID)
+ setUidsBlockedForDataSaver(false, TEST_UID)
+ setUidsBlockedForDataSaver(true, TEST_UID)
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ cb.expectAvailableCallbacks(agent.network, validated = false)
+ if (expectAllCallbacks) {
+ cb.expect<BlockedStatus>(agent) { it.blocked }
+ cb.expect<BlockedStatus>(agent) { !it.blocked }
+ }
+ cb.expect<BlockedStatus>(agent) { it.blocked }
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ fun testNetworkCallback_NetworkToggledWhileFrozen_NotSeen() {
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+ }
+ val cellAgent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+ maybeSetUidsFrozen(true, TEST_UID)
+ val wifiAgent = Agent(TRANSPORT_WIFI).apply { connect() }
+ wifiAgent.disconnect()
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ cb.expectAvailableCallbacks(cellAgent.network, validated = false)
+ if (expectAllCallbacks) {
+ cb.expectAvailableCallbacks(wifiAgent.network, validated = false)
+ cb.expect<Lost>(wifiAgent)
+ }
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ fun testNetworkCallback_NetworkAppearedWhileFrozen_ReceiveLatestInfoInOnAvailable() {
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+ }
+ maybeSetUidsFrozen(true, TEST_UID)
+ val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+ waitForIdle()
+ agent.makeValidationSuccess()
+ val lpChange = agent.sendLpChange {
+ setLinkAddresses("fe80:db8::123/64")
+ }
+ val suspendedChange = agent.sendNcChange {
+ removeCapability(NET_CAPABILITY_NOT_SUSPENDED)
+ }
+ setUidsBlockedForDataSaver(true, TEST_UID)
+
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ val expectLatestStatusInOnAvailable = !expectAllCallbacks
+ cb.expectAvailableCallbacks(
+ agent.network,
+ suspended = expectLatestStatusInOnAvailable,
+ validated = expectLatestStatusInOnAvailable,
+ blocked = expectLatestStatusInOnAvailable
+ )
+ if (expectAllCallbacks) {
+ cb.expectNcWith(agent) { addCapability(NET_CAPABILITY_VALIDATED) }
+ cb.expectLpWith(agent, lpChange)
+ cb.expectNcWith(agent, suspendedChange)
+ cb.expect<Suspended>(agent)
+ cb.expect<BlockedStatus>(agent) { it.blocked }
+ }
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+ fun testNetworkCallback_LocalNetworkAppearedWhileFrozen_ReceiveLatestInfoInOnAvailable() {
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.registerNetworkCallback(
+ NetworkRequest.Builder().addCapability(NET_CAPABILITY_LOCAL_NETWORK).build(),
+ cb
+ )
+ }
+ val upstreamAgent = Agent(
+ nc = defaultNc()
+ .addTransportType(TRANSPORT_WIFI)
+ .addCapability(NET_CAPABILITY_INTERNET),
+ lp = defaultLp().apply { interfaceName = "wlan0" }
+ ).apply { connect() }
+ maybeSetUidsFrozen(true, TEST_UID)
+
+ val lnc = LocalNetworkConfig.Builder().build()
+ val localAgent = Agent(
+ nc = defaultNc()
+ .addCapability(NET_CAPABILITY_LOCAL_NETWORK)
+ .removeCapability(NET_CAPABILITY_INTERNET),
+ lp = defaultLp().apply { interfaceName = "local42" },
+ lnc = FromS(lnc)
+ ).apply { connect() }
+ localAgent.sendLocalNetworkConfig(
+ LocalNetworkConfig.Builder()
+ .setUpstreamSelector(
+ NetworkRequest.Builder()
+ .addCapability(NET_CAPABILITY_INTERNET)
+ .build()
+ )
+ .build()
+ )
+
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ cb.expectAvailableCallbacks(
+ localAgent.network,
+ validated = false,
+ upstream = if (expectAllCallbacks) null else upstreamAgent.network
+ )
+ if (expectAllCallbacks) {
+ cb.expect<LocalInfoChanged>(localAgent) {
+ it.info.upstreamNetwork == upstreamAgent.network
+ }
+ }
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ fun testNetworkRequest_NetworkSwitchesWhileFrozen_ReceiveLastNetworkUpdatesOnly() {
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.requestNetwork(NetworkRequest.Builder().build(), cb)
+ }
+ val cellAgent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+ maybeSetUidsFrozen(true, TEST_UID)
+ val wifiAgent = Agent(TRANSPORT_WIFI).apply { connect() }
+ val ethAgent = Agent(TRANSPORT_ETHERNET).apply { connect() }
+ waitForIdle()
+ ethAgent.makeValidationSuccess()
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ cb.expectAvailableCallbacks(cellAgent.network, validated = false)
+ if (expectAllCallbacks) {
+ cb.expectAvailableCallbacks(wifiAgent.network, validated = false)
+ cb.expectAvailableCallbacks(ethAgent.network, validated = false)
+ cb.expectNcWith(ethAgent) { addCapability(NET_CAPABILITY_VALIDATED) }
+ } else {
+ cb.expectAvailableCallbacks(ethAgent.network, validated = true)
+ }
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ fun testNetworkRequest_NetworkSwitchesBackWhileFrozen_ReceiveNoAvailableCallback() {
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.requestNetwork(NetworkRequest.Builder().build(), cb)
+ }
+ val cellAgent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+ maybeSetUidsFrozen(true, TEST_UID)
+ val wifiAgent = Agent(TRANSPORT_WIFI).apply { connect() }
+ waitForIdle()
+
+ // CS switches back to validated cell over non-validated Wi-Fi
+ cellAgent.makeValidationSuccess()
+ val cellLpChange = cellAgent.sendLpChange {
+ setLinkAddresses("fe80:db8::123/64")
+ }
+ setUidsBlockedForDataSaver(true, TEST_UID)
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ cb.expectAvailableCallbacks(cellAgent.network, validated = false)
+ if (expectAllCallbacks) {
+ cb.expectAvailableCallbacks(wifiAgent.network, validated = false)
+ // There is an extra "double validated" CapabilitiesChange callback (b/245893397), so
+ // callbacks are (AVAIL, NC, LP), extra NC, then further updates (LP and BLK here).
+ cb.expectAvailableDoubleValidatedCallbacks(cellAgent.network)
+ cb.expectLpWith(cellAgent, cellLpChange)
+ cb.expect<BlockedStatus>(cellAgent) { it.blocked }
+ } else {
+ cb.expectNcWith(cellAgent) {
+ addCapability(NET_CAPABILITY_VALIDATED)
+ }
+ cb.expectLpWith(cellAgent, cellLpChange)
+ cb.expect<BlockedStatus>(cellAgent) { it.blocked }
+ }
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ fun testTrackDefaultRequest_AppFrozenWhilePerAppDefaultRequestFiled_ReceiveChangeCallbacks() {
+ val cellAgent = Agent(TRANSPORT_CELLULAR, baseNc = makeInternetNc()).apply { connect() }
+ waitForIdle()
+
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.registerDefaultNetworkCallback(cb)
+ }
+ maybeSetUidsFrozen(true, TEST_UID)
+
+ // Change LinkProperties twice before the per-app network request is applied
+ val lpChange1 = cellAgent.sendLpChange {
+ setLinkAddresses("fe80:db8::123/64")
+ }
+ val lpChange2 = cellAgent.sendLpChange {
+ setLinkAddresses("fe80:db8::124/64")
+ }
+ setMobileDataPreferredUids(setOf(TEST_UID))
+
+ // Change NetworkCapabilities after the per-app network request is applied
+ val ncChange = cellAgent.sendNcChange {
+ addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+ }
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ // Even if a per-app network request was filed to replace the default network request for
+ // the app, all network change callbacks are received
+ cb.expectAvailableCallbacks(cellAgent.network, validated = false)
+ if (expectAllCallbacks) {
+ cb.expectLpWith(cellAgent, lpChange1)
+ }
+ cb.expectLpWith(cellAgent, lpChange2)
+ cb.expectNcWith(cellAgent, ncChange)
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ @Test
+ fun testTrackDefaultRequest_AppFrozenWhilePerAppDefaultToggled_GetStatusUpdateCallbacksOnly() {
+ // Add validated Wi-Fi and non-validated cell, expect Wi-Fi is preferred by default
+ val wifiAgent = Agent(TRANSPORT_WIFI, baseNc = makeInternetNc()).apply { connect() }
+ wifiAgent.makeValidationSuccess()
+ val cellAgent = Agent(TRANSPORT_CELLULAR, baseNc = makeInternetNc()).apply { connect() }
+ waitForIdle()
+
+ val cb = TestableNetworkCallback(logTag = TAG)
+ withCallingUid(TEST_UID) {
+ cm.registerDefaultNetworkCallback(cb)
+ }
+ maybeSetUidsFrozen(true, TEST_UID)
+
+ // LP change on the original Wi-Fi network
+ val lpChange = wifiAgent.sendLpChange {
+ setLinkAddresses("fe80:db8::123/64")
+ }
+ // Set per-app default to cell, then unset it
+ setMobileDataPreferredUids(setOf(TEST_UID))
+ setMobileDataPreferredUids(emptySet())
+
+ maybeSetUidsFrozen(false, TEST_UID)
+
+ cb.expectAvailableCallbacks(wifiAgent.network)
+ if (expectAllCallbacks) {
+ cb.expectLpWith(wifiAgent, lpChange)
+ cb.expectAvailableCallbacks(cellAgent.network, validated = false)
+ // Cellular stops being background (gains FOREGROUND) since it is now matched for this app
+ cb.expect<CapabilitiesChanged> { it.caps.hasCapability(NET_CAPABILITY_FOREGROUND) }
+ cb.expectAvailableCallbacks(wifiAgent.network)
+ } else {
+ // After switching to cell and back while frozen, only network attribute update
+ // callbacks (and not AVAILABLE) for the original Wi-Fi network should be sent
+ cb.expect<CapabilitiesChanged>(wifiAgent)
+ cb.expectLpWith(wifiAgent, lpChange)
+ cb.expect<BlockedStatus> { !it.blocked }
+ }
+ cb.assertNoCallback(timeoutMs = 0L)
+ }
+
+ private fun setUidsBlockedForDataSaver(blocked: Boolean, vararg uid: Int) {
+ val reason = if (blocked) {
+ BLOCKED_METERED_REASON_DATA_SAVER
+ } else {
+ BLOCKED_REASON_NONE
+ }
+ if (deps.isAtLeastV) {
+ visibleOnHandlerThread(csHandler) {
+ service.handleBlockedReasonsChanged(uid.map { android.util.Pair(it, reason) })
+ }
+ } else {
+ notifyLegacyBlockedReasonChanged(reason, uid)
+ waitForIdle()
+ }
+ }
+
+ @Suppress("DEPRECATION")
+ private fun notifyLegacyBlockedReasonChanged(reason: Int, uids: IntArray) {
+ val cbCaptor = ArgumentCaptor.forClass(NetworkPolicyCallback::class.java)
+ verify(context.networkPolicyManager).registerNetworkPolicyCallback(
+ any(),
+ cbCaptor.capture()
+ )
+ uids.forEach {
+ cbCaptor.value.onUidBlockedReasonChanged(it, reason)
+ }
+ }
+
+ private fun withCallingUid(uid: Int, action: () -> Unit) {
+ deps.callingUid = uid
+ action()
+ deps.callingUid = CALLING_UID_UNMOCKED
+ }
+
+ private fun getUidFrozenStateChangedCallback(): UidFrozenStateChangedCallback {
+ val captor = ArgumentCaptor.forClass(UidFrozenStateChangedCallback::class.java)
+ verify(activityManager).registerUidFrozenStateChangedCallback(any(), captor.capture())
+ return captor.value
+ }
+
+ private fun maybeSetUidsFrozen(frozen: Boolean, vararg uids: Int) {
+ if (!freezeUids) return
+ val state = if (frozen) UID_FROZEN_STATE_FROZEN else UID_FROZEN_STATE_UNFROZEN
+ getUidFrozenStateChangedCallback()
+ .onUidFrozenStateChanged(uids, IntArray(uids.size) { state })
+ waitForIdle()
+ }
+
+ private fun CSAgentWrapper.makeValidationSuccess() {
+ setValidationResult(
+ NETWORK_VALIDATION_RESULT_VALID,
+ probesCompleted = NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTPS,
+ probesSucceeded = NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTPS
+ )
+ cm.reportNetworkConnectivity(network, true)
+ // Ensure validation is scheduled
+ waitForIdle()
+ // Ensure validation completes on mock executor
+ waitForIdleSerialExecutor(CSTestExecutor, HANDLER_TIMEOUT_MS)
+ // Ensure validation results are processed
+ waitForIdle()
+ }
+
+ private fun setMobileDataPreferredUids(uids: Set<Int>) {
+ ConnectivitySettingsManager.setMobileDataPreferredUids(context, uids)
+ service.updateMobileDataPreferredUids()
+ waitForIdle()
+ }
+}
+
+private fun makeInternetNc() = NetworkCapabilities.Builder(defaultNc())
+ .addCapability(NET_CAPABILITY_INTERNET)
+ .build()
+
+private fun CSAgentWrapper.sendLpChange(
+ mutator: LinkProperties.() -> Unit
+): LinkProperties.() -> Unit {
+ lp.mutator()
+ sendLinkProperties(lp)
+ return mutator
+}
+
+private fun CSAgentWrapper.sendNcChange(
+ mutator: NetworkCapabilities.() -> Unit
+): NetworkCapabilities.() -> Unit {
+ nc.mutator()
+ sendNetworkCapabilities(nc)
+ return mutator
+}
+
+private fun TestableNetworkCallback.expectLpWith(
+ agent: CSAgentWrapper,
+ change: LinkProperties.() -> Unit
+) = expect<LinkPropertiesChanged>(agent) {
+ // This test uses changes that are no-op when already applied (idempotent): verify that the
+ // change is already applied.
+ it.lp == LinkProperties(it.lp).apply(change)
+}
+
+private fun TestableNetworkCallback.expectNcWith(
+ agent: CSAgentWrapper,
+ change: NetworkCapabilities.() -> Unit
+) = expect<CapabilitiesChanged>(agent) {
+ it.caps == NetworkCapabilities(it.caps).apply(change)
+}
+
+private fun LinkProperties.setLinkAddresses(vararg addrs: String) {
+ setLinkAddresses(addrs.map { LinkAddress(it) })
+}
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
index 13c5cbc..1f5ee32 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
@@ -192,7 +192,8 @@
connect()
}
- fun setProbesStatus(probesCompleted: Int, probesSucceeded: Int) {
+ fun setValidationResult(result: Int, probesCompleted: Int, probesSucceeded: Int) {
+ nmValidationResult = result
nmProbesCompleted = probesCompleted
nmProbesSucceeded = probesSucceeded
}
@@ -204,8 +205,10 @@
// in the beginning. Because NETWORK_VALIDATION_PROBE_HTTP is the decisive probe for captive
// portal, considering the NETWORK_VALIDATION_PROBE_HTTPS hasn't probed yet and set only
// DNS and HTTP probes completed.
- setProbesStatus(
- NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTP /* probesCompleted */,
- VALIDATION_RESULT_INVALID /* probesSucceeded */)
+ setValidationResult(
+ VALIDATION_RESULT_INVALID,
+ probesCompleted = NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTP,
+ probesSucceeded = NO_PROBE_RESULT
+ )
}
}
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index de56ae5..46c25d2 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -538,8 +538,12 @@
provider: NetworkProvider? = null
) = CSAgentWrapper(context, deps, csHandlerThread, networkStack,
nac, nc, lp, lnc, score, provider)
- fun Agent(vararg transports: Int, lp: LinkProperties = defaultLp()): CSAgentWrapper {
- val nc = NetworkCapabilities.Builder().apply {
+ fun Agent(
+ vararg transports: Int,
+ baseNc: NetworkCapabilities = defaultNc(),
+ lp: LinkProperties = defaultLp()
+ ): CSAgentWrapper {
+ val nc = NetworkCapabilities.Builder(baseNc).apply {
transports.forEach {
addTransportType(it)
}
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
index 7e0a225..3d2f389 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -69,6 +69,8 @@
import static com.android.server.net.NetworkStatsEventLogger.POLL_REASON_RAT_CHANGED;
import static com.android.server.net.NetworkStatsEventLogger.PollEvent.pollReasonNameOf;
import static com.android.server.net.NetworkStatsService.ACTION_NETWORK_STATS_POLL;
+import static com.android.server.net.NetworkStatsService.ACTION_NETWORK_STATS_UPDATED;
+import static com.android.server.net.NetworkStatsService.BROADCAST_NETWORK_STATS_UPDATED_RATE_LIMIT_ENABLED_FLAG;
import static com.android.server.net.NetworkStatsService.DEFAULT_TRAFFIC_STATS_CACHE_EXPIRY_DURATION_MS;
import static com.android.server.net.NetworkStatsService.DEFAULT_TRAFFIC_STATS_CACHE_MAX_ENTRIES;
import static com.android.server.net.NetworkStatsService.NETSTATS_FASTDATAINPUT_FALLBACKS_COUNTER_NAME;
@@ -82,6 +84,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.AdditionalMatchers.aryEq;
@@ -101,8 +104,10 @@
import android.annotation.NonNull;
import android.app.AlarmManager;
+import android.content.BroadcastReceiver;
import android.content.Context;
import android.content.Intent;
+import android.content.IntentFilter;
import android.content.pm.PackageManager;
import android.content.res.Resources;
import android.database.ContentObserver;
@@ -138,6 +143,7 @@
import android.provider.Settings;
import android.system.ErrnoException;
import android.telephony.TelephonyManager;
+import android.testing.TestableLooper;
import android.text.TextUtils;
import android.util.ArrayMap;
import android.util.IndentingPrintWriter;
@@ -150,6 +156,7 @@
import com.android.connectivity.resources.R;
import com.android.internal.util.FileRotator;
import com.android.internal.util.test.BroadcastInterceptingContext;
+import com.android.net.module.util.ArrayTrackRecord;
import com.android.net.module.util.BpfDump;
import com.android.net.module.util.IBpfMap;
import com.android.net.module.util.LocationPermissionChecker;
@@ -618,6 +625,12 @@
}
@Override
+ public boolean enabledBroadcastNetworkStatsUpdatedRateLimiting(Context ctx) {
+ return mFeatureFlags.getOrDefault(
+ BROADCAST_NETWORK_STATS_UPDATED_RATE_LIMIT_ENABLED_FLAG, true);
+ }
+
+ @Override
public int getTrafficStatsRateLimitCacheExpiryDuration() {
return DEFAULT_TRAFFIC_STATS_CACHE_EXPIRY_DURATION_MS;
}
@@ -2617,6 +2630,8 @@
private void mockDefaultSettings() throws Exception {
mockSettings(HOUR_IN_MILLIS, WEEK_IN_MILLIS);
+ mSettings.setBroadcastNetworkStatsUpdateDelayMs(
+ NetworkStatsService.BROADCAST_NETWORK_STATS_UPDATED_DELAY_MS);
}
private void mockSettings(long bucketDuration, long deleteAge) {
@@ -2631,6 +2646,8 @@
@NonNull
private volatile Config mConfig;
private final AtomicBoolean mCombineSubtypeEnabled = new AtomicBoolean();
+ private long mBroadcastNetworkStatsUpdateDelayMs =
+ NetworkStatsService.BROADCAST_NETWORK_STATS_UPDATED_DELAY_MS;
TestNetworkStatsSettings(long bucketDuration, long deleteAge) {
mConfig = new Config(bucketDuration, deleteAge, deleteAge);
@@ -2693,6 +2710,15 @@
public boolean getAugmentEnabled() {
return false;
}
+
+ @Override
+ public long getBroadcastNetworkStatsUpdateDelayMs() {
+ return mBroadcastNetworkStatsUpdateDelayMs;
+ }
+
+ public void setBroadcastNetworkStatsUpdateDelayMs(long broadcastDelay) {
+ mBroadcastNetworkStatsUpdateDelayMs = broadcastDelay;
+ }
}
private void assertStatsFilesExist(boolean exist) {
@@ -3064,4 +3090,91 @@
final String dump = getDump();
assertDumpContains(dump, "Log for testing");
}
+
+ private static class TestNetworkStatsUpdatedReceiver extends BroadcastReceiver {
+ private final ArrayTrackRecord<Intent>.ReadHead mHistory;
+
+ TestNetworkStatsUpdatedReceiver() {
+ mHistory = (new ArrayTrackRecord<Intent>()).newReadHead();
+ }
+
+ @Override
+ public void onReceive(Context context, Intent intent) {
+ mHistory.add(intent);
+ }
+
+ /**
+ * Assert no broadcast intent is currently queued, without blocking to wait for one
+ */
+ public void assertNoBroadcastIntentReceived() {
+ assertNull(mHistory.peek());
+ }
+
+ /**
+ * Assert an intent is received and remove it from queue
+ */
+ public void assertBroadcastIntentReceived() {
+ assertNotNull(mHistory.poll(WAIT_TIMEOUT, number -> true));
+ }
+ }
+
+ @FeatureFlag(name = BROADCAST_NETWORK_STATS_UPDATED_RATE_LIMIT_ENABLED_FLAG)
+ @Test
+ public void testNetworkStatsUpdatedIntentSpam_rateLimitOn() throws Exception {
+ // Use an update delay long enough that delayed broadcasts stay blocked until the test
+ // moves time forward, and a shorter tolerance to check behavior just before that delay.
+ // Constraint: test running time < toleranceMs < update delay time
+ mSettings.setBroadcastNetworkStatsUpdateDelayMs(100_000L);
+ final long toleranceMs = 5000;
+
+ final TestableLooper mTestableLooper = new TestableLooper(mHandlerThread.getLooper());
+ final TestNetworkStatsUpdatedReceiver receiver = new TestNetworkStatsUpdatedReceiver();
+ mServiceContext.registerReceiver(receiver, new IntentFilter(ACTION_NETWORK_STATS_UPDATED));
+
+ try {
+ // Test that before anything, the intent is delivered immediately
+ mService.forceUpdate();
+ mTestableLooper.processAllMessages();
+ receiver.assertBroadcastIntentReceived();
+ receiver.assertNoBroadcastIntentReceived();
+
+ // Test that the next two updates result in exactly one intent delivered
+ for (int i = 0; i < 2; i++) {
+ mService.forceUpdate();
+ }
+ // Test that the delay depends on our set value
+ mTestableLooper.moveTimeForward(mSettings.getBroadcastNetworkStatsUpdateDelayMs()
+ - toleranceMs);
+ mTestableLooper.processAllMessages();
+ receiver.assertNoBroadcastIntentReceived();
+
+ // Unblock messages and test that the second and third updates
+ // are broadcast as a single intent right after the delay
+ mTestableLooper.moveTimeForward(toleranceMs);
+ mTestableLooper.processAllMessages();
+ receiver.assertBroadcastIntentReceived();
+ receiver.assertNoBroadcastIntentReceived();
+
+ } finally {
+ mTestableLooper.destroy();
+ }
+ }
+
+ @FeatureFlag(name = BROADCAST_NETWORK_STATS_UPDATED_RATE_LIMIT_ENABLED_FLAG, enabled = false)
+ @Test
+ public void testNetworkStatsUpdatedIntentSpam_rateLimitOff() throws Exception {
+ // Set a long update delay to verify that, with rate limiting disabled,
+ // intents are still broadcast without waiting for that delay.
+ mSettings.setBroadcastNetworkStatsUpdateDelayMs(100_000L);
+
+ final TestNetworkStatsUpdatedReceiver receiver = new TestNetworkStatsUpdatedReceiver();
+ mServiceContext.registerReceiver(receiver, new IntentFilter(ACTION_NETWORK_STATS_UPDATED));
+
+ for (int i = 0; i < 2; i++) {
+ mService.forceUpdate();
+ waitForIdle();
+ receiver.assertBroadcastIntentReceived();
+ }
+ receiver.assertNoBroadcastIntentReceived();
+ }
}
diff --git a/thread/framework/java/android/net/thread/ChannelMaxPower.aidl b/thread/framework/java/android/net/thread/ChannelMaxPower.aidl
index bcda8a8..abc00b9 100644
--- a/thread/framework/java/android/net/thread/ChannelMaxPower.aidl
+++ b/thread/framework/java/android/net/thread/ChannelMaxPower.aidl
@@ -16,11 +16,11 @@
package android.net.thread;
- /**
- * Mapping from a channel to its max power.
- *
- * {@hide}
- */
+/**
+ * Mapping from a channel to its max power.
+ *
+ * {@hide}
+ */
parcelable ChannelMaxPower {
int channel; // The Thread radio channel.
int maxPower; // The max power in the unit of 0.01dBm. Passing INT16_MAX(32767) will
diff --git a/thread/framework/java/android/net/thread/IOperationalDatasetCallback.aidl b/thread/framework/java/android/net/thread/IOperationalDatasetCallback.aidl
index b576b33..3fece65 100644
--- a/thread/framework/java/android/net/thread/IOperationalDatasetCallback.aidl
+++ b/thread/framework/java/android/net/thread/IOperationalDatasetCallback.aidl
@@ -23,6 +23,8 @@
* @hide
*/
oneway interface IOperationalDatasetCallback {
- void onActiveOperationalDatasetChanged(in @nullable ActiveOperationalDataset activeOpDataset);
- void onPendingOperationalDatasetChanged(in @nullable PendingOperationalDataset pendingOpDataset);
+ void onActiveOperationalDatasetChanged(
+ in @nullable ActiveOperationalDataset activeOpDataset);
+ void onPendingOperationalDatasetChanged(
+ in @nullable PendingOperationalDataset pendingOpDataset);
}
diff --git a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
index f50de74..b7f68c9 100644
--- a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
+++ b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
@@ -28,9 +28,9 @@
import android.net.thread.ThreadConfiguration;
/**
-* Interface for communicating with ThreadNetworkControllerService.
-* @hide
-*/
+ * Interface for communicating with ThreadNetworkControllerService.
+ * @hide
+ */
interface IThreadNetworkController {
void registerStateCallback(in IStateCallback callback);
void unregisterStateCallback(in IStateCallback callback);
@@ -38,10 +38,12 @@
void unregisterOperationalDatasetCallback(in IOperationalDatasetCallback callback);
void join(in ActiveOperationalDataset activeOpDataset, in IOperationReceiver receiver);
- void scheduleMigration(in PendingOperationalDataset pendingOpDataset, in IOperationReceiver receiver);
+ void scheduleMigration(
+ in PendingOperationalDataset pendingOpDataset, in IOperationReceiver receiver);
void leave(in IOperationReceiver receiver);
- void setTestNetworkAsUpstream(in String testNetworkInterfaceName, in IOperationReceiver receiver);
+ void setTestNetworkAsUpstream(
+ in String testNetworkInterfaceName, in IOperationReceiver receiver);
void setChannelMaxPowers(in ChannelMaxPower[] channelMaxPowers, in IOperationReceiver receiver);
int getThreadVersion();
diff --git a/thread/framework/java/android/net/thread/IThreadNetworkManager.aidl b/thread/framework/java/android/net/thread/IThreadNetworkManager.aidl
index 0e394b1..b63cd72 100644
--- a/thread/framework/java/android/net/thread/IThreadNetworkManager.aidl
+++ b/thread/framework/java/android/net/thread/IThreadNetworkManager.aidl
@@ -19,9 +19,9 @@
import android.net.thread.IThreadNetworkController;
/**
-* Interface for communicating with ThreadNetworkService.
-* @hide
-*/
+ * Interface for communicating with ThreadNetworkService.
+ * @hide
+ */
interface IThreadNetworkManager {
List<IThreadNetworkController> getAllThreadNetworkControllers();
}
diff --git a/thread/scripts/make-pretty.sh b/thread/scripts/make-pretty.sh
index c176bfa..e012d41 100755
--- a/thread/scripts/make-pretty.sh
+++ b/thread/scripts/make-pretty.sh
@@ -1,9 +1,35 @@
#!/usr/bin/env bash
-SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ANDROID_ROOT_DIR=$(
+ while [ ! -d ".repo" ] && [ "$PWD" != "/" ]; do cd ..; done
+ pwd
+)
-GOOGLE_JAVA_FORMAT=$SCRIPT_DIR/../../../../../prebuilts/tools/common/google-java-format/google-java-format
-ANDROID_BP_FORMAT=$SCRIPT_DIR/../../../../../prebuilts/build-tools/linux-x86/bin/bpfmt
+if [ ! -d "$ANDROID_ROOT_DIR/.repo" ]; then
+ echo "Error: The script has to run in an Android repo checkout"
+ exit 1
+fi
-$GOOGLE_JAVA_FORMAT --aosp -i $(find $SCRIPT_DIR/../ -name "*.java")
-$ANDROID_BP_FORMAT -w $(find $SCRIPT_DIR/../ -name "*.bp")
+GOOGLE_JAVA_FORMAT=$ANDROID_ROOT_DIR/prebuilts/tools/common/google-java-format/google-java-format
+ANDROID_BP_FORMAT=$ANDROID_ROOT_DIR/prebuilts/build-tools/linux-x86/bin/bpfmt
+AIDL_FORMAT=$ANDROID_ROOT_DIR/system/tools/aidl/aidl-format.sh
+
+CONNECTIVITY_DIR=$ANDROID_ROOT_DIR/packages/modules/Connectivity
+OPENTHREAD_DIR=$ANDROID_ROOT_DIR/external/openthread
+OTBR_POSIX_DIR=$ANDROID_ROOT_DIR/external/ot-br-posix
+
+ALLOWED_CODE_DIRS=($CONNECTIVITY_DIR $OPENTHREAD_DIR $OTBR_POSIX_DIR)
+CODE_DIR=$(git rev-parse --show-toplevel)
+
+if [[ ! " ${ALLOWED_CODE_DIRS[@]} " =~ " ${CODE_DIR} " ]]; then
+ echo "Error: The script has to run in the Git project Connectivity, openthread or ot-br-posix"
+ exit 1
+fi
+
+if [[ $CODE_DIR == $CONNECTIVITY_DIR ]]; then
+ CODE_DIR=$CODE_DIR"/thread"
+fi
+
+$GOOGLE_JAVA_FORMAT --aosp -i $(find $CODE_DIR -name "*.java")
+$ANDROID_BP_FORMAT -w $(find $CODE_DIR -name "*.bp")
+$AIDL_FORMAT -w $(find $CODE_DIR -name "*.aidl")
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 19084c6..b621a6a 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -108,6 +108,7 @@
import android.os.HandlerThread;
import android.os.IBinder;
import android.os.Looper;
+import android.os.ParcelFileDescriptor;
import android.os.RemoteException;
import android.os.SystemClock;
import android.os.UserManager;
@@ -119,7 +120,7 @@
import com.android.server.ServiceManagerWrapper;
import com.android.server.connectivity.ConnectivityResources;
import com.android.server.thread.openthread.BackboneRouterState;
-import com.android.server.thread.openthread.BorderRouterConfigurationParcel;
+import com.android.server.thread.openthread.BorderRouterConfiguration;
import com.android.server.thread.openthread.DnsTxtAttribute;
import com.android.server.thread.openthread.IChannelMasksReceiver;
import com.android.server.thread.openthread.IOtDaemon;
@@ -212,7 +213,7 @@
private boolean mUserRestricted;
private boolean mForceStopOtDaemonEnabled;
- private BorderRouterConfigurationParcel mBorderRouterConfig;
+ private BorderRouterConfiguration mBorderRouterConfig;
@VisibleForTesting
ThreadNetworkControllerService(
@@ -237,7 +238,11 @@
mInfraIfController = infraIfController;
mUpstreamNetworkRequest = newUpstreamNetworkRequest();
mNetworkToInterface = new HashMap<Network, String>();
- mBorderRouterConfig = new BorderRouterConfigurationParcel();
+ mBorderRouterConfig =
+ new BorderRouterConfiguration.Builder()
+ .setIsBorderRoutingEnabled(true)
+ .setInfraInterfaceName(null)
+ .build();
mPersistentSettings = persistentSettings;
mNsdPublisher = nsdPublisher;
mUserManager = userManager;
@@ -1227,38 +1232,54 @@
}
}
- private void enableBorderRouting(String infraIfName) {
- if (mBorderRouterConfig.isBorderRoutingEnabled
- && infraIfName.equals(mBorderRouterConfig.infraInterfaceName)) {
+ private void configureBorderRouter(BorderRouterConfiguration borderRouterConfig) {
+ if (mBorderRouterConfig.equals(borderRouterConfig)) {
return;
}
- Log.i(TAG, "Enable border routing on AIL: " + infraIfName);
+ Log.i(
+ TAG,
+ "Configuring Border Router: " + mBorderRouterConfig + " -> " + borderRouterConfig);
+ mBorderRouterConfig = borderRouterConfig;
+ ParcelFileDescriptor infraIcmp6Socket = null;
+ if (mBorderRouterConfig.infraInterfaceName != null) {
+ try {
+ infraIcmp6Socket =
+ mInfraIfController.createIcmp6Socket(
+ mBorderRouterConfig.infraInterfaceName);
+ } catch (IOException e) {
+ Log.i(TAG, "Failed to create ICMPv6 socket on infra network interface", e);
+ }
+ }
try {
- mBorderRouterConfig.infraInterfaceName = infraIfName;
- mBorderRouterConfig.infraInterfaceIcmp6Socket =
- mInfraIfController.createIcmp6Socket(infraIfName);
- mBorderRouterConfig.isBorderRoutingEnabled = true;
-
getOtDaemon()
.configureBorderRouter(
- mBorderRouterConfig, new ConfigureBorderRouterStatusReceiver());
- } catch (RemoteException | IOException | ThreadNetworkException e) {
- Log.w(TAG, "Failed to enable border routing", e);
+ mBorderRouterConfig,
+ infraIcmp6Socket,
+ new ConfigureBorderRouterStatusReceiver());
+ } catch (RemoteException | ThreadNetworkException e) {
+ Log.w(TAG, "Failed to configure border router " + mBorderRouterConfig, e);
}
}
+ private void enableBorderRouting(String infraIfName) {
+ BorderRouterConfiguration borderRouterConfig =
+ newBorderRouterConfigBuilder(mBorderRouterConfig)
+ .setIsBorderRoutingEnabled(true)
+ .setInfraInterfaceName(infraIfName)
+ .build();
+ Log.i(TAG, "Enable border routing on AIL: " + infraIfName);
+ configureBorderRouter(borderRouterConfig);
+ }
+
private void disableBorderRouting() {
mUpstreamNetwork = null;
- mBorderRouterConfig.infraInterfaceName = null;
- mBorderRouterConfig.infraInterfaceIcmp6Socket = null;
- mBorderRouterConfig.isBorderRoutingEnabled = false;
- try {
- getOtDaemon()
- .configureBorderRouter(
- mBorderRouterConfig, new ConfigureBorderRouterStatusReceiver());
- } catch (RemoteException | ThreadNetworkException e) {
- Log.w(TAG, "Failed to disable border routing", e);
- }
+ BorderRouterConfiguration borderRouterConfig =
+ newBorderRouterConfigBuilder(mBorderRouterConfig)
+ .setIsBorderRoutingEnabled(false)
+ .setInfraInterfaceName(null)
+ .build();
+ Log.i(TAG, "Disabling border routing");
+ configureBorderRouter(borderRouterConfig);
}
private void handleThreadInterfaceStateChanged(boolean isUp) {
@@ -1359,6 +1380,13 @@
return builder.build();
}
+ private static BorderRouterConfiguration.Builder newBorderRouterConfigBuilder(
+ BorderRouterConfiguration brConfig) {
+ return new BorderRouterConfiguration.Builder()
+ .setIsBorderRoutingEnabled(brConfig.isBorderRoutingEnabled)
+ .setInfraInterfaceName(brConfig.infraInterfaceName);
+ }
+
private static final class CallbackMetadata {
private static long gId = 0;
diff --git a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
index b6d9aa3..9e8dc3a 100644
--- a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
+++ b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
@@ -17,17 +17,20 @@
package android.net.thread;
import static android.Manifest.permission.MANAGE_TEST_NETWORKS;
+import static android.net.InetAddresses.parseNumericAddress;
import static android.net.thread.utils.IntegrationTestUtils.DEFAULT_DATASET;
import static android.net.thread.utils.IntegrationTestUtils.getIpv6LinkAddresses;
+import static android.net.thread.utils.IntegrationTestUtils.isExpectedIcmpv4Packet;
import static android.net.thread.utils.IntegrationTestUtils.isExpectedIcmpv6Packet;
-import static android.net.thread.utils.IntegrationTestUtils.isFromIpv6Source;
+import static android.net.thread.utils.IntegrationTestUtils.isFrom;
import static android.net.thread.utils.IntegrationTestUtils.isInMulticastGroup;
-import static android.net.thread.utils.IntegrationTestUtils.isToIpv6Destination;
+import static android.net.thread.utils.IntegrationTestUtils.isTo;
import static android.net.thread.utils.IntegrationTestUtils.joinNetworkAndWaitForOmr;
import static android.net.thread.utils.IntegrationTestUtils.newPacketReader;
import static android.net.thread.utils.IntegrationTestUtils.pollForPacket;
import static android.net.thread.utils.IntegrationTestUtils.sendUdpMessage;
import static android.net.thread.utils.IntegrationTestUtils.waitFor;
+import static android.system.OsConstants.ICMP_ECHO;
import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REPLY_TYPE;
import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REQUEST_TYPE;
@@ -44,11 +47,11 @@
import static java.util.Objects.requireNonNull;
import android.content.Context;
-import android.net.InetAddresses;
import android.net.IpPrefix;
import android.net.LinkAddress;
import android.net.LinkProperties;
import android.net.MacAddress;
+import android.net.RouteInfo;
import android.net.thread.utils.FullThreadDevice;
import android.net.thread.utils.InfraNetworkDevice;
import android.net.thread.utils.OtDaemonController;
@@ -74,7 +77,9 @@
import org.junit.Test;
import org.junit.runner.RunWith;
+import java.net.Inet4Address;
import java.net.Inet6Address;
+import java.net.InetAddress;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
@@ -89,11 +94,14 @@
private static final String TAG = BorderRoutingTest.class.getSimpleName();
private static final int NUM_FTD = 2;
private static final Inet6Address GROUP_ADDR_SCOPE_5 =
- (Inet6Address) InetAddresses.parseNumericAddress("ff05::1234");
+ (Inet6Address) parseNumericAddress("ff05::1234");
private static final Inet6Address GROUP_ADDR_SCOPE_4 =
- (Inet6Address) InetAddresses.parseNumericAddress("ff04::1234");
+ (Inet6Address) parseNumericAddress("ff04::1234");
private static final Inet6Address GROUP_ADDR_SCOPE_3 =
- (Inet6Address) InetAddresses.parseNumericAddress("ff03::1234");
+ (Inet6Address) parseNumericAddress("ff03::1234");
+ private static final Inet4Address IPV4_SERVER_ADDR =
+ (Inet4Address) parseNumericAddress("8.8.8.8");
+ private static final String NAT64_CIDR = "192.168.255.0/24";
@Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
@@ -165,7 +173,7 @@
mInfraDevice.sendEchoRequest(ftd.getOmrAddress());
// Infra device receives an echo reply sent by FTD.
- assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
+ assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
}
@Test
@@ -186,7 +194,7 @@
mInfraDevice.sendEchoRequest(ftd.getOmrAddress());
- assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
+ assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
}
@Test
@@ -213,7 +221,7 @@
mInfraDevice.sendEchoRequest(ftd.getOmrAddress());
- assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftdOmr));
+ assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftdOmr));
} finally {
runAsShell(MANAGE_TEST_NETWORKS, () -> oldInfraNetworkTracker.teardown());
}
@@ -322,7 +330,7 @@
mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
- assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
+ assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
}
@Test
@@ -354,7 +362,7 @@
mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_3);
- assertNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
+ assertNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
}
@Test
@@ -375,7 +383,7 @@
mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_4);
- assertNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
+ assertNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
}
@Test
@@ -405,13 +413,15 @@
mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
- assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd1.getOmrAddress()));
+ assertNotNull(
+ pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd1.getOmrAddress()));
// Verify ping reply from ftd1 and ftd2 separately as the order of replies can't be
// predicted.
mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_4);
- assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd2.getOmrAddress()));
+ assertNotNull(
+ pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd2.getOmrAddress()));
}
@Test
@@ -441,12 +451,14 @@
mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
- assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd1.getOmrAddress()));
+ assertNotNull(
+ pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd1.getOmrAddress()));
// Send the request twice as the order of replies from ftd1 and ftd2 are not guaranteed
mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
- assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd2.getOmrAddress()));
+ assertNotNull(
+ pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd2.getOmrAddress()));
}
@Test
@@ -469,9 +481,11 @@
ftd.ping(GROUP_ADDR_SCOPE_4);
assertNotNull(
- pollForPacketOnInfraNetwork(ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_5));
+ pollForIcmpPacketOnInfraNetwork(
+ ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_5));
assertNotNull(
- pollForPacketOnInfraNetwork(ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_4));
+ pollForIcmpPacketOnInfraNetwork(
+ ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_4));
}
@Test
@@ -493,7 +507,7 @@
ftd.ping(GROUP_ADDR_SCOPE_3);
assertNull(
- pollForPacketOnInfraNetwork(
+ pollForIcmpPacketOnInfraNetwork(
ICMPV6_ECHO_REQUEST_TYPE, ftd.getOmrAddress(), GROUP_ADDR_SCOPE_3));
}
@@ -517,7 +531,8 @@
ftd.ping(GROUP_ADDR_SCOPE_4, ftdLla);
assertNull(
- pollForPacketOnInfraNetwork(ICMPV6_ECHO_REQUEST_TYPE, ftdLla, GROUP_ADDR_SCOPE_4));
+ pollForIcmpPacketOnInfraNetwork(
+ ICMPV6_ECHO_REQUEST_TYPE, ftdLla, GROUP_ADDR_SCOPE_4));
}
@Test
@@ -541,7 +556,7 @@
ftd.ping(GROUP_ADDR_SCOPE_4, ftdMla);
assertNull(
- pollForPacketOnInfraNetwork(
+ pollForIcmpPacketOnInfraNetwork(
ICMPV6_ECHO_REQUEST_TYPE, ftdMla, GROUP_ADDR_SCOPE_4));
}
}
@@ -572,7 +587,7 @@
mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
- assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftdOmr));
+ assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftdOmr));
}
@Test
@@ -600,18 +615,40 @@
ftd.ping(GROUP_ADDR_SCOPE_4);
assertNotNull(
- pollForPacketOnInfraNetwork(ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_4));
+ pollForIcmpPacketOnInfraNetwork(
+ ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_4));
+ }
+
+ @Test
+ public void nat64_threadDevicePingIpv4InfraDevice_outboundPacketIsForwarded() throws Exception {
+ FullThreadDevice ftd = mFtds.get(0);
+ joinNetworkAndWaitForOmr(ftd, DEFAULT_DATASET);
+ // TODO: enable NAT64 via ThreadNetworkController API instead of ot-ctl
+ mOtCtl.setNat64Cidr(NAT64_CIDR);
+ mOtCtl.setNat64Enabled(true);
+ waitFor(() -> mOtCtl.hasNat64PrefixInNetdata(), Duration.ofSeconds(10));
+
+ ftd.ping(IPV4_SERVER_ADDR);
+
+ assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMP_ECHO, null, IPV4_SERVER_ADDR));
}
private void setUpInfraNetwork() throws Exception {
+ LinkProperties lp = new LinkProperties();
+ // NAT64 feature requires the infra network to have an IPv4 default route.
+ lp.addRoute(
+ new RouteInfo(
+ new IpPrefix("0.0.0.0/0") /* destination */,
+ null /* gateway */,
+ null,
+ RouteInfo.RTN_UNICAST,
+ 1500 /* mtu */));
mInfraNetworkTracker =
runAsShell(
MANAGE_TEST_NETWORKS,
- () ->
- initTestNetwork(
- mContext, new LinkProperties(), 5000 /* timeoutMs */));
- mController.setTestNetworkAsUpstreamAndWait(
- mInfraNetworkTracker.getTestIface().getInterfaceName());
+ () -> initTestNetwork(mContext, lp, 5000 /* timeoutMs */));
+ String infraNetworkName = mInfraNetworkTracker.getTestIface().getInterfaceName();
+ mController.setTestNetworkAsUpstreamAndWait(infraNetworkName);
}
private void tearDownInfraNetwork() {
@@ -648,20 +685,28 @@
assertInfraLinkMemberOfGroup(address);
}
- private byte[] pollForPacketOnInfraNetwork(int type, Inet6Address srcAddress) {
- return pollForPacketOnInfraNetwork(type, srcAddress, null);
+ private byte[] pollForIcmpPacketOnInfraNetwork(int type, InetAddress srcAddress) {
+ return pollForIcmpPacketOnInfraNetwork(type, srcAddress, null /* destAddress */);
}
- private byte[] pollForPacketOnInfraNetwork(
- int type, Inet6Address srcAddress, Inet6Address destAddress) {
- Predicate<byte[]> filter;
- filter =
+ private byte[] pollForIcmpPacketOnInfraNetwork(
+ int type, InetAddress srcAddress, InetAddress destAddress) {
+ if (srcAddress == null && destAddress == null) {
+ throw new IllegalArgumentException("srcAddress and destAddress cannot be both null");
+ }
+ if (srcAddress != null && destAddress != null) {
+ if ((srcAddress instanceof Inet4Address) != (destAddress instanceof Inet4Address)) {
+ throw new IllegalArgumentException(
+ "srcAddress and destAddress must be both IPv4 or both IPv6");
+ }
+ }
+ boolean isIpv4 =
+ (srcAddress instanceof Inet4Address) || (destAddress instanceof Inet4Address);
+ final Predicate<byte[]> filter =
p ->
- (isExpectedIcmpv6Packet(p, type)
- && (srcAddress == null ? true : isFromIpv6Source(p, srcAddress))
- && (destAddress == null
- ? true
- : isToIpv6Destination(p, destAddress)));
+ (isIpv4 ? isExpectedIcmpv4Packet(p, type) : isExpectedIcmpv6Packet(p, type))
+ && (srcAddress == null || isFrom(p, srcAddress))
+ && (destAddress == null || isTo(p, destAddress));
return pollForPacket(mInfraNetworkReader, filter);
}
}
diff --git a/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java b/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
index 8440bbc..083a841 100644
--- a/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
+++ b/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
@@ -417,7 +417,7 @@
executeCommand("ipmaddr add " + address.getHostAddress());
}
- public void ping(Inet6Address address, Inet6Address source) {
+ public void ping(InetAddress address, Inet6Address source) {
ping(
address,
source,
@@ -428,7 +428,7 @@
PING_TIMEOUT_0_1_SECOND);
}
- public void ping(Inet6Address address) {
+ public void ping(InetAddress address) {
ping(
address,
null,
@@ -440,7 +440,7 @@
}
/** Returns the number of ping reply packets received. */
- public int ping(Inet6Address address, int count) {
+ public int ping(InetAddress address, int count) {
List<String> output =
ping(
address,
@@ -454,7 +454,7 @@
}
private List<String> ping(
- Inet6Address address,
+ InetAddress address,
Inet6Address source,
int size,
int count,
diff --git a/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java b/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java
index 7b0c415..82e9332 100644
--- a/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java
+++ b/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java
@@ -16,6 +16,7 @@
package android.net.thread.utils;
import static android.net.NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK;
+import static android.system.OsConstants.IPPROTO_ICMP;
import static android.system.OsConstants.IPPROTO_ICMPV6;
import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
@@ -49,7 +50,9 @@
import androidx.test.core.app.ApplicationProvider;
import com.android.net.module.util.Struct;
+import com.android.net.module.util.structs.Icmpv4Header;
import com.android.net.module.util.structs.Icmpv6Header;
+import com.android.net.module.util.structs.Ipv4Header;
import com.android.net.module.util.structs.Ipv6Header;
import com.android.net.module.util.structs.PrefixInformationOption;
import com.android.net.module.util.structs.RaHeader;
@@ -62,6 +65,7 @@
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
+import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
@@ -192,16 +196,36 @@
return null;
}
- /** Returns {@code true} if {@code packet} is an ICMPv6 packet of given {@code type}. */
- public static boolean isExpectedIcmpv6Packet(byte[] packet, int type) {
- if (packet == null) {
+ /** Returns {@code true} if {@code packet} is an ICMPv4 packet of given {@code type}. */
+ public static boolean isExpectedIcmpv4Packet(byte[] packet, int type) {
+ ByteBuffer buf = makeByteBuffer(packet);
+ Ipv4Header header = extractIpv4Header(buf);
+ if (header == null) {
return false;
}
- ByteBuffer buf = ByteBuffer.wrap(packet);
+ if (header.protocol != (byte) IPPROTO_ICMP) {
+ return false;
+ }
try {
- if (Struct.parse(Ipv6Header.class, buf).nextHeader != (byte) IPPROTO_ICMPV6) {
- return false;
- }
+ return Struct.parse(Icmpv4Header.class, buf).type == (short) type;
+ } catch (IllegalArgumentException ignored) {
+ // It's fine that the passed-in packet is malformed because it could be sent
+ // by anybody.
+ }
+ return false;
+ }
+
+ /** Returns {@code true} if {@code packet} is an ICMPv6 packet of given {@code type}. */
+ public static boolean isExpectedIcmpv6Packet(byte[] packet, int type) {
+ ByteBuffer buf = makeByteBuffer(packet);
+ Ipv6Header header = extractIpv6Header(buf);
+ if (header == null) {
+ return false;
+ }
+ if (header.nextHeader != (byte) IPPROTO_ICMPV6) {
+ return false;
+ }
+ try {
return Struct.parse(Icmpv6Header.class, buf).type == (short) type;
} catch (IllegalArgumentException ignored) {
// It's fine that the passed in packet is malformed because it's could be sent
@@ -210,32 +234,66 @@
return false;
}
- public static boolean isFromIpv6Source(byte[] packet, Inet6Address src) {
- if (packet == null) {
- return false;
- }
- ByteBuffer buf = ByteBuffer.wrap(packet);
- try {
- return Struct.parse(Ipv6Header.class, buf).srcIp.equals(src);
- } catch (IllegalArgumentException ignored) {
- // It's fine that the passed in packet is malformed because it's could be sent
- // by anybody.
+ public static boolean isFrom(byte[] packet, InetAddress src) {
+ if (src instanceof Inet4Address) {
+ return isFromIpv4Source(packet, (Inet4Address) src);
+ } else if (src instanceof Inet6Address) {
+ return isFromIpv6Source(packet, (Inet6Address) src);
}
return false;
}
- public static boolean isToIpv6Destination(byte[] packet, Inet6Address dest) {
- if (packet == null) {
- return false;
+ public static boolean isTo(byte[] packet, InetAddress dest) {
+ if (dest instanceof Inet4Address) {
+ return isToIpv4Destination(packet, (Inet4Address) dest);
+ } else if (dest instanceof Inet6Address) {
+ return isToIpv6Destination(packet, (Inet6Address) dest);
}
- ByteBuffer buf = ByteBuffer.wrap(packet);
+ return false;
+ }
+
+ private static boolean isFromIpv4Source(byte[] packet, Inet4Address src) {
+ Ipv4Header header = extractIpv4Header(makeByteBuffer(packet));
+ return header != null && header.srcIp.equals(src);
+ }
+
+ private static boolean isFromIpv6Source(byte[] packet, Inet6Address src) {
+ Ipv6Header header = extractIpv6Header(makeByteBuffer(packet));
+ return header != null && header.srcIp.equals(src);
+ }
+
+ private static boolean isToIpv4Destination(byte[] packet, Inet4Address dest) {
+ Ipv4Header header = extractIpv4Header(makeByteBuffer(packet));
+ return header != null && header.dstIp.equals(dest);
+ }
+
+ private static boolean isToIpv6Destination(byte[] packet, Inet6Address dest) {
+ Ipv6Header header = extractIpv6Header(makeByteBuffer(packet));
+ return header != null && header.dstIp.equals(dest);
+ }
+
+ private static ByteBuffer makeByteBuffer(byte[] packet) {
+ return packet == null ? null : ByteBuffer.wrap(packet);
+ }
+
+ private static Ipv4Header extractIpv4Header(ByteBuffer buf) {
try {
- return Struct.parse(Ipv6Header.class, buf).dstIp.equals(dest);
+ return Struct.parse(Ipv4Header.class, buf);
} catch (IllegalArgumentException ignored) {
// It's fine that the passed in packet is malformed because it's could be sent
// by anybody.
}
- return false;
+ return null;
+ }
+
+ private static Ipv6Header extractIpv6Header(ByteBuffer buf) {
+ try {
+ return Struct.parse(Ipv6Header.class, buf);
+ } catch (IllegalArgumentException ignored) {
+ // It's fine that the passed-in packet is malformed because it could be sent
+ // by anybody.
+ }
+ return null;
}
/** Returns the Prefix Information Options (PIO) extracted from an ICMPv6 RA message. */
diff --git a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
index b3175fd..15a3f5c 100644
--- a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
+++ b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
@@ -105,6 +105,29 @@
return prefixes.isEmpty() ? null : prefixes.get(0);
}
+ /** Enables/Disables NAT64 feature. */
+ public void setNat64Enabled(boolean enabled) {
+ executeCommand("nat64 " + (enabled ? "enable" : "disable"));
+ }
+
+ /** Sets the NAT64 CIDR. */
+ public void setNat64Cidr(String cidr) {
+ executeCommand("nat64 cidr " + cidr);
+ }
+
+ /** Returns whether there is a NAT64 prefix in the network data. */
+ public boolean hasNat64PrefixInNetdata() {
+ // Example (in the 'Routes' section):
+ // fdb2:bae3:5b59:2:0:0::/96 sn low c000
+ List<String> outputLines = executeCommandAndParse("netdata show");
+ for (String line : outputLines) {
+ if (line.contains(" sn")) {
+ return true;
+ }
+ }
+ return false;
+ }
+
public String executeCommand(String cmd) {
return SystemUtil.runShellCommand(OT_CTL + " " + cmd);
}
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index df1a65b..be32764 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -17,6 +17,11 @@
package com.android.server.thread;
import static android.Manifest.permission.NETWORK_SETTINGS;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VPN;
+import static android.net.NetworkCapabilities.TRANSPORT_ETHERNET;
+import static android.net.NetworkCapabilities.TRANSPORT_THREAD;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
import static android.net.thread.ActiveOperationalDataset.CHANNEL_PAGE_24_GHZ;
import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
@@ -39,6 +44,7 @@
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
@@ -58,7 +64,6 @@
import android.content.res.Resources;
import android.net.ConnectivityManager;
import android.net.NetworkAgent;
-import android.net.NetworkCapabilities;
import android.net.NetworkProvider;
import android.net.NetworkRequest;
import android.net.thread.ActiveOperationalDataset;
@@ -746,6 +751,30 @@
}
@Test
+ public void initialize_upstreamNetworkRequestHasCertainTransportTypesAndCapabilities() {
+ mService.initialize();
+ mTestLooper.dispatchAll();
+
+ ArgumentCaptor<NetworkRequest> networkRequestCaptor =
+ ArgumentCaptor.forClass(NetworkRequest.class);
+ verify(mMockConnectivityManager, atLeastOnce())
+ .registerNetworkCallback(
+ networkRequestCaptor.capture(),
+ any(ConnectivityManager.NetworkCallback.class),
+ any(Handler.class));
+ List<NetworkRequest> upstreamNetworkRequests =
+ networkRequestCaptor.getAllValues().stream()
+ .filter(nr -> !nr.hasTransport(TRANSPORT_THREAD))
+ .toList();
+ assertThat(upstreamNetworkRequests.size()).isEqualTo(1);
+ NetworkRequest upstreamNetworkRequest = upstreamNetworkRequests.get(0);
+ assertThat(upstreamNetworkRequest.hasTransport(TRANSPORT_WIFI)).isTrue();
+ assertThat(upstreamNetworkRequest.hasTransport(TRANSPORT_ETHERNET)).isTrue();
+ assertThat(upstreamNetworkRequest.hasCapability(NET_CAPABILITY_NOT_VPN)).isTrue();
+ assertThat(upstreamNetworkRequest.hasCapability(NET_CAPABILITY_INTERNET)).isTrue();
+ }
+
+ @Test
public void setTestNetworkAsUpstream_upstreamNetworkRequestAlwaysDisallowsVpn() {
mService.initialize();
mTestLooper.dispatchAll();
@@ -768,10 +797,8 @@
NetworkRequest networkRequest1 = networkRequestCaptor.getAllValues().get(0);
NetworkRequest networkRequest2 = networkRequestCaptor.getAllValues().get(1);
assertThat(networkRequest1.getNetworkSpecifier()).isNotNull();
- assertThat(networkRequest1.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN))
- .isTrue();
+ assertThat(networkRequest1.hasCapability(NET_CAPABILITY_NOT_VPN)).isTrue();
assertThat(networkRequest2.getNetworkSpecifier()).isNull();
- assertThat(networkRequest2.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN))
- .isTrue();
+ assertThat(networkRequest2.hasCapability(NET_CAPABILITY_NOT_VPN)).isTrue();
}
}