Move utilities to libs/net
These files used to be in the network stack directory, but the libs
directory is a much more suitable place for them.
Also fix a typo : ConcurrentIntepreter → ConcurrentInterpreter
Also move {FdEvents,Packet}Reader to internal annotations. That's
what they should have been using in the first place anyway.
Note that this does not fix preupload issues reported by
checkstyle to make review easier. The fixes are in a followup
patch to this one.
Test: checkbuild
Change-Id: I675077fd42cbb092c0e6bd56571f2fc022e582fd
diff --git a/staticlibs/Android.bp b/staticlibs/Android.bp
index 0f8dda0..4b15b16 100644
--- a/staticlibs/Android.bp
+++ b/staticlibs/Android.bp
@@ -27,10 +27,86 @@
// included in the bootclasspath, they could incorrectly be included in the SDK documentation even
// though they are not in the current.txt files.
+java_library {
+ name: "net-utils-device-common",
+ srcs: [
+ "device/**/*.java",
+ // This library is used by system modules, for which the system health impact of Kotlin
+ // has not yet been evaluated.
+ // "src_devicecommon/**/*.kt",
+ ":framework-annotations",
+ ],
+ sdk_version: "system_current",
+ min_sdk_version: "29",
+ target_sdk_version: "30",
+ apex_available: [
+ "//apex_available:anyapex",
+ "//apex_available:platform",
+ ],
+ visibility: [
+ "//frameworks/base/packages/Tethering",
+ "//frameworks/opt/net/ike",
+ "//frameworks/opt/net/wifi/service",
+ "//frameworks/opt/net/telephony",
+ "//packages/modules/NetworkStack",
+ "//packages/modules/CaptivePortalLogin",
+ "//frameworks/libs/net/common/tests:__subpackages__",
+ ],
+ libs: [
+ "androidx.annotation_annotation",
+ ],
+}
+
+java_library {
+ // Consider using net-tests-utils instead if writing device code.
+ // That library has a lot more useful tools into it for users that
+ // work on Android and includes this lib.
+ name: "net-tests-utils-host-device-common",
+ srcs: [
+ "hostdevice/**/*.java",
+ "hostdevice/**/*.kt",
+ ],
+ host_supported: true,
+ visibility: [
+ "//frameworks/libs/net/common/tests:__subpackages__",
+ ],
+ static_libs: [
+ "kotlin-test"
+ ]
+}
+
+java_defaults {
+ name: "lib_mockito_extended",
+ static_libs: [
+ "mockito-target-extended-minus-junit4"
+ ],
+ jni_libs: [
+ "libdexmakerjvmtiagent",
+ "libstaticjvmtiagent",
+ ],
+}
+
+java_library {
+ name: "net-tests-utils",
+ srcs: [
+ "devicetests/**/*.java",
+ "devicetests/**/*.kt",
+ ],
+ defaults: ["lib_mockito_extended"],
+ libs: [
+ "androidx.annotation_annotation",
+ ],
+ static_libs: [
+ "androidx.test.ext.junit",
+ "net-tests-utils-host-device-common",
+ "net-utils-device-common",
+ ],
+}
+
filegroup {
name: "net-utils-framework-common-srcs",
- srcs: ["src_frameworkcommon/**/*.java"],
- path: "src_frameworkcommon",
+ srcs: ["framework/**/*.java"],
+ path: "framework",
visibility: ["//frameworks/base"],
}
@@ -49,7 +125,7 @@
"//frameworks/base/packages/Tethering",
"//frameworks/opt/net/ike",
"//frameworks/opt/telephony",
- "//packages/modules/NetworkStack",
+ "//packages/modules/NetworkStack:__subpackages__",
"//packages/modules/CaptivePortalLogin",
"//frameworks/libs/net/common/tests:__subpackages__",
]
@@ -58,17 +134,12 @@
java_library {
name: "net-utils-services-common",
srcs: [
- "src_servicescommon/**/*.java",
+ "device/android/net/NetworkFactory.java",
":framework-annotations",
],
sdk_version: "system_current",
visibility: [
"//frameworks/base/services/net",
- "//frameworks/base/packages/Tethering",
- "//frameworks/opt/net/ike",
- "//packages/modules/NetworkStack",
- "//packages/modules/CaptivePortalLogin",
- "//frameworks/libs/net/common/tests:__subpackages__",
],
}
@@ -78,9 +149,9 @@
name: "net-utils-telephony-common-srcs",
srcs: [
// Any class here *must* have a corresponding jarjar rule in the telephony build rules.
- "src_servicescommon/android/net/NetworkFactory.java",
+ "device/android/net/NetworkFactory.java",
],
- path: "src_servicescommon",
+ path: "device",
visibility: [
"//frameworks/opt/telephony",
],
@@ -92,11 +163,11 @@
filegroup {
name: "net-utils-framework-wifi-common-srcs",
srcs: [
- "src_frameworkcommon/android/net/util/nsd/DnsSdTxtRecord.java",
- "src_frameworkcommon/android/net/util/MacAddressUtils.java",
- "src_frameworkcommon/com/android/net/module/util/**/*.java",
+ "framework/android/net/util/nsd/DnsSdTxtRecord.java",
+ "framework/android/net/util/MacAddressUtils.java",
+ "framework/com/android/net/module/util/**/*.java",
],
- path: "src_frameworkcommon",
+ path: "framework",
visibility: [
"//frameworks/base",
],
@@ -108,8 +179,8 @@
filegroup {
name: "net-utils-wifi-service-common-srcs",
srcs: [
- "src_frameworkcommon/android/net/util/NetUtils.java",
- "src_servicescommon/android/net/NetworkFactory.java",
+ "device/android/net/NetworkFactory.java",
+ "framework/android/net/util/NetUtils.java",
],
visibility: [
"//frameworks/opt/net/wifi/service",
diff --git a/staticlibs/src_servicescommon/android/net/NetworkFactory.java b/staticlibs/device/android/net/NetworkFactory.java
similarity index 100%
rename from staticlibs/src_servicescommon/android/net/NetworkFactory.java
rename to staticlibs/device/android/net/NetworkFactory.java
diff --git a/staticlibs/device/android/net/util/FdEventsReader.java b/staticlibs/device/android/net/util/FdEventsReader.java
new file mode 100644
index 0000000..a5714aa
--- /dev/null
+++ b/staticlibs/device/android/net/util/FdEventsReader.java
@@ -0,0 +1,269 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// TODO : move this and PacketReader to com.android.net.module.util.
+package android.net.util;
+
+import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_ERROR;
+import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_INPUT;
+
+import android.os.Handler;
+import android.os.Looper;
+import android.os.MessageQueue;
+import android.system.ErrnoException;
+import android.system.OsConstants;
+import android.util.Log;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+
+import com.android.internal.annotations.VisibleForTesting;
+
+import java.io.FileDescriptor;
+import java.io.IOException;
+
+
+/**
+ * This class encapsulates the mechanics of registering a file descriptor
+ * with a thread's Looper and handling read events (and errors).
+ *
+ * Subclasses MUST implement createFd() and SHOULD override handlePacket(). They MAY override
+ * onStop() and onStart().
+ *
+ * Subclasses can expect a call life-cycle like the following:
+ *
+ * [1] when a client calls start(), createFd() is called, followed by the onStart() hook if all
+ * goes well. Implementations may override onStart() for additional initialization.
+ *
+ * [2] yield, waiting for read event or error notification:
+ *
+ * [a] readPacket() && handlePacket()
+ *
+ * [b] if (no error):
+ * goto 2
+ * else:
+ * goto 3
+ *
+ * [3] when a client calls stop(), the onStop() hook is called (unless already stopped or never
+ * started). Implementations may override onStop() for additional cleanup.
+ *
+ * The packet receive buffer is recycled on every read call, so subclasses
+ * should make any copies they would like inside their handlePacket()
+ * implementation.
+ *
+ * All public methods MUST only be called from the same thread with which
+ * the Handler constructor argument is associated.
+ *
+ * @param <BufferType> the type of the buffer used to read data.
+ */
+public abstract class FdEventsReader<BufferType> {
+ private static final int FD_EVENTS = EVENT_INPUT | EVENT_ERROR;
+ private static final int UNREGISTER_THIS_FD = 0;
+
+ @NonNull
+ private final Handler mHandler;
+ @NonNull
+ private final MessageQueue mQueue;
+ @NonNull
+ private final BufferType mBuffer;
+ @Nullable
+ private FileDescriptor mFd;
+ private long mPacketsReceived;
+
+ protected static void closeFd(FileDescriptor fd) {
+ try {
+ SocketUtils.closeSocket(fd);
+ } catch (IOException ignored) {
+ }
+ }
+
+ protected FdEventsReader(@NonNull Handler h, @NonNull BufferType buffer) {
+ mHandler = h;
+ mQueue = mHandler.getLooper().getQueue();
+ mBuffer = buffer;
+ }
+
+ @VisibleForTesting
+ @NonNull
+ protected MessageQueue getMessageQueue() {
+ return mQueue;
+ }
+
+ /** Start this FdEventsReader. */
+ public boolean start() {
+ if (!onCorrectThread()) {
+ throw new IllegalStateException("start() called from off-thread");
+ }
+
+ return createAndRegisterFd();
+ }
+
+ /** Stop this FdEventsReader and destroy the file descriptor. */
+ public void stop() {
+ if (!onCorrectThread()) {
+ throw new IllegalStateException("stop() called from off-thread");
+ }
+
+ unregisterAndDestroyFd();
+ }
+
+ @NonNull
+ public Handler getHandler() {
+ return mHandler;
+ }
+
+ protected abstract int recvBufSize(@NonNull BufferType buffer);
+
+ /** Returns the size of the receive buffer. */
+ public int recvBufSize() {
+ return recvBufSize(mBuffer);
+ }
+
+ /**
+ * Get the number of successful calls to {@link #readPacket(FileDescriptor, Object)}.
+ *
+ * <p>A call was successful if {@link #readPacket(FileDescriptor, Object)} returned a value > 0.
+ */
+ public final long numPacketsReceived() {
+ return mPacketsReceived;
+ }
+
+ /**
+ * Subclasses MUST create the listening socket here, including setting all desired socket
+ * options, interface or address/port binding, etc. The socket MUST be created nonblocking.
+ */
+ @Nullable
+ protected abstract FileDescriptor createFd();
+
+ /**
+ * Implementations MUST return the bytes read or throw an Exception.
+ *
+ * <p>The caller may throw a {@link ErrnoException} with {@link OsConstants#EAGAIN} or
+ * {@link OsConstants#EINTR}, in which case {@link FdEventsReader} will ignore the buffer
+ * contents and respectively wait for further input or retry the read immediately. For all other
+ * exceptions, the {@link FdEventsReader} will be stopped with no more interactions with this
+ * method.
+ */
+ protected abstract int readPacket(@NonNull FileDescriptor fd, @NonNull BufferType buffer)
+ throws Exception;
+
+ /**
+ * Called by the main loop for every packet. Any desired copies of
+ * |recvbuf| should be made in here, as the underlying byte array is
+ * reused across all reads.
+ */
+ protected void handlePacket(@NonNull BufferType recvbuf, int length) {}
+
+ /**
+ * Called by the main loop to log errors. In some cases |e| may be null.
+ */
+ protected void logError(@NonNull String msg, @Nullable Exception e) {}
+
+ /**
+ * Called by start(), if successful, just prior to returning.
+ */
+ protected void onStart() {}
+
+ /**
+ * Called by stop() just prior to returning.
+ */
+ protected void onStop() {}
+
+ private boolean createAndRegisterFd() {
+ if (mFd != null) return true;
+
+ try {
+ mFd = createFd();
+ } catch (Exception e) {
+ logError("Failed to create socket: ", e);
+ closeFd(mFd);
+ mFd = null;
+ }
+
+ if (mFd == null) return false;
+
+ getMessageQueue().addOnFileDescriptorEventListener(
+ mFd,
+ FD_EVENTS,
+ (fd, events) -> {
+ // Always call handleInput() so read/recvfrom are given
+ // a proper chance to encounter a meaningful errno and
+ // perhaps log a useful error message.
+ if (!isRunning() || !handleInput()) {
+ unregisterAndDestroyFd();
+ return UNREGISTER_THIS_FD;
+ }
+ return FD_EVENTS;
+ });
+ onStart();
+ return true;
+ }
+
+ private boolean isRunning() {
+ return (mFd != null) && mFd.valid();
+ }
+
+ // Keep trying to read until we get EAGAIN/EWOULDBLOCK or some fatal error.
+ private boolean handleInput() {
+ while (isRunning()) {
+ final int bytesRead;
+
+ try {
+ bytesRead = readPacket(mFd, mBuffer);
+ if (bytesRead < 1) {
+ if (isRunning()) logError("Socket closed, exiting", null);
+ break;
+ }
+ mPacketsReceived++;
+ } catch (ErrnoException e) {
+ if (e.errno == OsConstants.EAGAIN) {
+ // We've read everything there is to read this time around.
+ return true;
+ } else if (e.errno == OsConstants.EINTR) {
+ continue;
+ } else {
+ if (isRunning()) logError("readPacket error: ", e);
+ break;
+ }
+ } catch (Exception e) {
+ if (isRunning()) logError("readPacket error: ", e);
+ break;
+ }
+
+ try {
+ handlePacket(mBuffer, bytesRead);
+ } catch (Exception e) {
+ logError("handlePacket error: ", e);
+ Log.wtf(FdEventsReader.class.getSimpleName(), "Error handling packet", e);
+ }
+ }
+
+ return false;
+ }
+
+ private void unregisterAndDestroyFd() {
+ if (mFd == null) return;
+
+ getMessageQueue().removeOnFileDescriptorEventListener(mFd);
+ closeFd(mFd);
+ mFd = null;
+ onStop();
+ }
+
+ private boolean onCorrectThread() {
+ return (mHandler.getLooper() == Looper.myLooper());
+ }
+}
diff --git a/staticlibs/device/android/net/util/PacketReader.java b/staticlibs/device/android/net/util/PacketReader.java
new file mode 100644
index 0000000..61b3d6e
--- /dev/null
+++ b/staticlibs/device/android/net/util/PacketReader.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// TODO : move this and FdEventsReader to com.android.net.module.util.
+package android.net.util;
+
+import static java.lang.Math.max;
+
+import android.os.Handler;
+import android.system.Os;
+
+import java.io.FileDescriptor;
+
+/**
+ * Specialization of {@link FdEventsReader} that reads packets into a byte array.
+ *
+ * TODO: rename this class to something more correctly descriptive (something
+ * like [or less horrible than] FdReadEventsHandler?).
+ */
+public abstract class PacketReader extends FdEventsReader<byte[]> {
+
+ public static final int DEFAULT_RECV_BUF_SIZE = 2 * 1024;
+
+ protected PacketReader(Handler h) {
+ this(h, DEFAULT_RECV_BUF_SIZE);
+ }
+
+ protected PacketReader(Handler h, int recvBufSize) {
+ super(h, new byte[max(recvBufSize, DEFAULT_RECV_BUF_SIZE)]);
+ }
+
+ @Override
+ protected final int recvBufSize(byte[] buffer) {
+ return buffer.length;
+ }
+
+ /**
+ * Subclasses MAY override this to change the default read() implementation
+ * in favour of, say, recvfrom().
+ *
+ * Implementations MUST return the bytes read or throw an Exception.
+ */
+ @Override
+ protected int readPacket(FileDescriptor fd, byte[] packetBuffer) throws Exception {
+ return Os.read(fd, packetBuffer, 0, packetBuffer.length);
+ }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/ConcurrentInterpreter.kt b/staticlibs/devicetests/com/android/testutils/ConcurrentInterpreter.kt
new file mode 100644
index 0000000..98464eb
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/ConcurrentInterpreter.kt
@@ -0,0 +1,188 @@
+package com.android.testutils
+
+import android.os.SystemClock
+import java.util.concurrent.CyclicBarrier
+import kotlin.system.measureTimeMillis
+import kotlin.test.assertEquals
+import kotlin.test.assertFails
+import kotlin.test.assertNull
+import kotlin.test.assertTrue
+
+// The table contains pairs associating a regexp with the code to run. The statement is matched
+// against each matcher in sequence and when a match is found the associated code is run, passing
+// it the TrackRecord under test and the result of the regexp match.
+typealias InterpretMatcher<T> = Pair<Regex, (ConcurrentInterpreter<T>, T, MatchResult) -> Any?>
+
+// The default unit of time for interpreted tests
+val INTERPRET_TIME_UNIT = 40L // ms
+
+/**
+ * A small interpreter for testing parallel code. The interpreter will read a list of lines
+ * consisting of "|"-separated statements. Each column runs in a different concurrent thread
+ * and all threads wait for each other in between lines. Each statement is split on ";" then
+ * matched with regular expressions in the instructionTable constant, which contains the
+ * code associated with each statement. The interpreter supports an object being passed to
+ * the interpretTestSpec() method to be passed in each lambda (think about the object under
+ * test), and an optional transform function to be executed on the object at the start of
+ * every thread.
+ *
+ * The time unit is defined in milliseconds by the interpretTimeUnit member, which has a default
+ * value but can be passed to the constructor. Whitespace is ignored.
+ *
+ * The interpretation table has to be passed as an argument. It's a table associating a regexp
+ * with the code that should execute, as a function taking three arguments : the interpreter,
+ * the regexp match, and the object. See the individual tests for the DSL of that test.
+ * Implementors for new interpreting languages are encouraged to look at the defaultInterpretTable
+ * constant below for an example of how to write an interpreting table.
+ * Some expressions already exist by default and can be used by all interpreters. They include :
+ * sleep(x) : sleeps for x time units and returns Unit ; sleep alone means sleep(1)
+ * EXPR = VALUE : asserts that EXPR equals VALUE. EXPR is interpreted. VALUE can either be the
+ * string "null" or an int. Returns Unit.
+ * EXPR time x..y : measures the time taken by EXPR and asserts it took at least x and at most
+ * y time units.
+ * EXPR // any text : comments are ignored.
+ */
+open class ConcurrentInterpreter<T>(
+ localInterpretTable: List<InterpretMatcher<T>>,
+ val interpretTimeUnit: Long = INTERPRET_TIME_UNIT
+) {
+ private val interpretTable: List<InterpretMatcher<T>> =
+ localInterpretTable + getDefaultInstructions()
+
+ // Split the line into multiple statements separated by ";" and execute them. Return whatever
+ // the last statement returned.
+ fun interpretMultiple(instr: String, r: T): Any? {
+ return instr.split(";").map { interpret(it.trim(), r) }.last()
+ }
+
+ // Match the statement to a regex and interpret it.
+ fun interpret(instr: String, r: T): Any? {
+ val (matcher, code) =
+ interpretTable.find { instr matches it.first } ?: throw SyntaxException(instr)
+ val match = matcher.matchEntire(instr) ?: throw SyntaxException(instr)
+ return code(this, r, match)
+ }
+
+ // Spins as many threads as needed by the test spec and interpret each program concurrently,
+ // having all threads waiting on a CyclicBarrier after each line.
+ // |lineShift| says how many lines after the call the spec starts. This is used for error
+ // reporting. Unfortunately AFAICT there is no way to get the line of an argument rather
+ // than the line at which the expression starts.
+ fun interpretTestSpec(
+ spec: String,
+ initial: T,
+ lineShift: Int = 0,
+ threadTransform: (T) -> T = { it }
+ ) {
+ // For nice stack traces
+ val callSite = getCallingMethod()
+ val lines = spec.trim().trim('\n').split("\n").map { it.split("|") }
+ // |threads| contains arrays of strings that make up the statements of a thread : in other
+ // words, it's an array that contains a list of statements for each column in the spec.
+ val threadCount = lines[0].size
+ assertTrue(lines.all { it.size == threadCount })
+ val threadInstructions = (0 until threadCount).map { i -> lines.map { it[i].trim() } }
+ val barrier = CyclicBarrier(threadCount)
+ var crash: InterpretException? = null
+ threadInstructions.mapIndexed { threadIndex, instructions ->
+ Thread {
+ val threadLocal = threadTransform(initial)
+ barrier.await()
+ var lineNum = 0
+ instructions.forEach {
+ if (null != crash) return@Thread
+ lineNum += 1
+ try {
+ interpretMultiple(it, threadLocal)
+ } catch (e: Throwable) {
+ // If fail() or some exception was called, the thread will come here ; if
+ // the exception isn't caught the process will crash, which is not nice for
+ // testing. Instead, catch the exception, cancel other threads, and report
+ // nicely. Catch throwable because fail() is AssertionError, which inherits
+ // from Error.
+ crash = InterpretException(threadIndex, it,
+ callSite.lineNumber + lineNum + lineShift,
+ callSite.className, callSite.methodName, callSite.fileName, e)
+ }
+ barrier.await()
+ }
+ }.also { it.start() }
+ }.forEach { it.join() }
+ // If the test failed, crash with line number
+ crash?.let { throw it }
+ }
+
+ // Helper to get the stack trace for a calling method
+ private fun getCallingStackTrace(): Array<StackTraceElement> {
+ try {
+ throw RuntimeException()
+ } catch (e: RuntimeException) {
+ return e.stackTrace
+ }
+ }
+
+ // Find the calling method. This is the first method in the stack trace that is annotated
+ // with @Test.
+ fun getCallingMethod(): StackTraceElement {
+ val stackTrace = getCallingStackTrace()
+ return stackTrace.find { element ->
+ val clazz = Class.forName(element.className)
+ // Because the stack trace doesn't list the formal arguments, find all methods with
+ // this name and return this name if any of them is annotated with @Test.
+ clazz.declaredMethods
+ .filter { method -> method.name == element.methodName }
+ .any { method -> method.getAnnotation(org.junit.Test::class.java) != null }
+ } ?: stackTrace[3]
+ // If no method is annotated return the 4th one, because that's what it usually is :
+ // 0 is getCallingStackTrace, 1 is this method, 2 is ConcurrentInterpreter#interpretTestSpec
+ }
+}
+
+private fun <T> getDefaultInstructions() = listOf<InterpretMatcher<T>>(
+ // Interpret an empty line as doing nothing.
+ Regex("") to { _, _, _ -> null },
+ // Ignore comments.
+ Regex("(.*)//.*") to { i, t, r -> i.interpret(r.strArg(1), t) },
+ // Interpret "XXX time x..y" : run XXX and check it took at least x and not more than y
+ Regex("""(.*)\s*time\s*(\d+)\.\.(\d+)""") to { i, t, r ->
+ val time = measureTimeMillis { i.interpret(r.strArg(1), t) }
+ assertTrue(time in r.timeArg(2)..r.timeArg(3), "$time not in ${r.timeArg(2)..r.timeArg(3)}")
+ },
+ // Interpret "XXX = YYY" : run XXX and assert its return value is equal to YYY. "null" supported
+ Regex("""(.*)\s*=\s*(null|\d+)""") to { i, t, r ->
+ i.interpret(r.strArg(1), t).also {
+ if ("null" == r.strArg(2)) assertNull(it) else assertEquals(r.intArg(2), it)
+ }
+ },
+ // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
+ Regex("""sleep(\((\d+)\))?""") to { i, t, r ->
+ SystemClock.sleep(if (r.strArg(2).isEmpty()) i.interpretTimeUnit else r.timeArg(2))
+ },
+ Regex("""(.*)\s*fails""") to { i, t, r ->
+ assertFails { i.interpret(r.strArg(1), t) }
+ }
+)
+
+class SyntaxException(msg: String, cause: Throwable? = null) : RuntimeException(msg, cause)
+class InterpretException(
+ threadIndex: Int,
+ instr: String,
+ lineNum: Int,
+ className: String,
+ methodName: String,
+ fileName: String,
+ cause: Throwable
+) : RuntimeException("Failure: $instr", cause) {
+ init {
+ stackTrace = arrayOf(StackTraceElement(
+ className,
+ "$methodName:thread$threadIndex",
+ fileName,
+ lineNum)) + super.getStackTrace()
+ }
+}
+
+// Some small helpers to avoid to say the large ".groupValues[index].trim()" every time
+fun MatchResult.strArg(index: Int) = this.groupValues[index].trim()
+fun MatchResult.intArg(index: Int) = strArg(index).toInt()
+fun MatchResult.timeArg(index: Int) = INTERPRET_TIME_UNIT * intArg(index)
diff --git a/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRule.kt b/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRule.kt
new file mode 100644
index 0000000..4a83f6f
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRule.kt
@@ -0,0 +1,90 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.os.Build
+import org.junit.Assume.assumeTrue
+import org.junit.rules.TestRule
+import org.junit.runner.Description
+import org.junit.runners.model.Statement
+
+/**
+ * Returns true if the development SDK version of the device is in the provided range.
+ *
+ * If the device is not using a release SDK, the development SDK is considered to be higher than
+ * [Build.VERSION.SDK_INT].
+ */
+fun isDevSdkInRange(minExclusive: Int?, maxInclusive: Int?): Boolean {
+ // In-development API n+1 will have SDK_INT == n and CODENAME != REL.
+ // Stable API n has SDK_INT == n and CODENAME == REL.
+ val release = "REL" == Build.VERSION.CODENAME
+ val sdkInt = Build.VERSION.SDK_INT
+ val devApiLevel = sdkInt + if (release) 0 else 1
+
+ return (minExclusive == null || devApiLevel > minExclusive) &&
+ (maxInclusive == null || devApiLevel <= maxInclusive)
+}
+
+/**
+ * A test rule to ignore tests based on the development SDK level.
+ *
+ * If the device is not using a release SDK, the development SDK is considered to be higher than
+ * [Build.VERSION.SDK_INT].
+ *
+ * @param ignoreClassUpTo Skip all tests in the class if the device dev SDK is <= this value.
+ * @param ignoreClassAfter Skip all tests in the class if the device dev SDK is > this value.
+ */
+class DevSdkIgnoreRule @JvmOverloads constructor(
+ private val ignoreClassUpTo: Int? = null,
+ private val ignoreClassAfter: Int? = null
+) : TestRule {
+ override fun apply(base: Statement, description: Description): Statement {
+ return IgnoreBySdkStatement(base, description)
+ }
+
+ /**
+ * Ignore the test for any development SDK that is strictly after [value].
+ *
+ * If the device is not using a release SDK, the development SDK is considered to be higher
+ * than [Build.VERSION.SDK_INT].
+ */
+ annotation class IgnoreAfter(val value: Int)
+
+ /**
+ * Ignore the test for any development SDK that lower than or equal to [value].
+ *
+ * If the device is not using a release SDK, the development SDK is considered to be higher
+ * than [Build.VERSION.SDK_INT].
+ */
+ annotation class IgnoreUpTo(val value: Int)
+
+ private inner class IgnoreBySdkStatement(
+ private val base: Statement,
+ private val description: Description
+ ) : Statement() {
+ override fun evaluate() {
+ val ignoreAfter = description.getAnnotation(IgnoreAfter::class.java)
+ val ignoreUpTo = description.getAnnotation(IgnoreUpTo::class.java)
+
+ val message = "Skipping test for build ${Build.VERSION.CODENAME} " +
+ "with SDK ${Build.VERSION.SDK_INT}"
+ assumeTrue(message, isDevSdkInRange(ignoreClassUpTo, ignoreClassAfter))
+ assumeTrue(message, isDevSdkInRange(ignoreUpTo?.value, ignoreAfter?.value))
+ base.evaluate()
+ }
+ }
+}
\ No newline at end of file
diff --git a/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt b/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
new file mode 100644
index 0000000..73b2843
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
+import org.junit.runner.Description
+import org.junit.runner.Runner
+import org.junit.runner.notification.RunNotifier
+
+/**
+ * A runner that can skip tests based on the development SDK as defined in [DevSdkIgnoreRule].
+ *
+ * Generally [DevSdkIgnoreRule] should be used for that purpose (using rules is preferable over
+ * replacing the test runner), however JUnit runners inspect all methods in the test class before
+ * processing test rules. This may cause issues if the test methods are referencing classes that do
+ * not exist on the SDK of the device the test is run on.
+ *
+ * This runner inspects [IgnoreAfter] and [IgnoreUpTo] annotations on the test class, and will skip
+ * the whole class if they do not match the development SDK as defined in [DevSdkIgnoreRule].
+ * Otherwise, it will delegate to [AndroidJUnit4] to run the test as usual.
+ *
+ * Example usage:
+ *
+ * @RunWith(DevSdkIgnoreRunner::class)
+ * @IgnoreUpTo(Build.VERSION_CODES.Q)
+ * class MyTestClass { ... }
+ */
+class DevSdkIgnoreRunner(private val klass: Class<*>) : Runner() {
+ private val baseRunner = klass.let {
+ val ignoreAfter = it.getAnnotation(IgnoreAfter::class.java)
+ val ignoreUpTo = it.getAnnotation(IgnoreUpTo::class.java)
+
+ if (isDevSdkInRange(ignoreUpTo?.value, ignoreAfter?.value)) AndroidJUnit4(klass) else null
+ }
+
+ override fun run(notifier: RunNotifier) {
+ if (baseRunner != null) {
+ baseRunner.run(notifier)
+ return
+ }
+
+ // Report a single, skipped placeholder test for this class, so that the class is still
+ // visible as skipped in test results.
+ notifier.fireTestIgnored(
+ Description.createTestDescription(klass, "skippedClassForDevSdkMismatch"))
+ }
+
+ override fun getDescription(): Description {
+ return baseRunner?.description ?: Description.createSuiteDescription(klass)
+ }
+
+ override fun testCount(): Int {
+ // When ignoring the tests, a skipped placeholder test is reported, so test count is 1.
+ return baseRunner?.testCount() ?: 1
+ }
+}
\ No newline at end of file
diff --git a/staticlibs/devicetests/com/android/testutils/FakeDns.kt b/staticlibs/devicetests/com/android/testutils/FakeDns.kt
new file mode 100644
index 0000000..825d748
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/FakeDns.kt
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.DnsResolver
+import android.net.InetAddresses
+import android.os.Looper
+import android.os.Handler
+import com.android.internal.annotations.GuardedBy
+import java.net.InetAddress
+import java.util.concurrent.Executor
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.Mockito.any
+import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.doAnswer
+
+const val TYPE_UNSPECIFIED = -1
+// TODO: Integrate with NetworkMonitorTest.
+class FakeDns(val mockResolver: DnsResolver) {
+ class DnsEntry(val hostname: String, val type: Int, val addresses: List<InetAddress>) {
+ fun match(host: String, type: Int) = hostname.equals(host) && type == type
+ }
+
+ @GuardedBy("answers")
+ val answers = ArrayList<DnsEntry>()
+
+ fun getAnswer(hostname: String, type: Int): DnsEntry? = synchronized(answers) {
+ return answers.firstOrNull { it.match(hostname, type) }
+ }
+
+ fun setAnswer(hostname: String, answer: Array<String>, type: Int) = synchronized(answers) {
+ val ans = DnsEntry(hostname, type, generateAnswer(answer))
+ // Replace or remove the existing one.
+ when (val index = answers.indexOfFirst { it.match(hostname, type) }) {
+ -1 -> answers.add(ans)
+ else -> answers[index] = ans
+ }
+ }
+
+ private fun generateAnswer(answer: Array<String>) =
+ answer.filterNotNull().map { InetAddresses.parseNumericAddress(it) }
+
+ fun startMocking() {
+ // Mock DnsResolver.query() w/o type
+ doAnswer {
+ mockAnswer(it, 1, -1, 3, 5)
+ }.`when`(mockResolver).query(any() /* network */, any() /* domain */, anyInt() /* flags */,
+ any() /* executor */, any() /* cancellationSignal */, any() /*callback*/)
+ // Mock DnsResolver.query() w/ type
+ doAnswer {
+ mockAnswer(it, 1, 2, 4, 6)
+ }.`when`(mockResolver).query(any() /* network */, any() /* domain */, anyInt() /* nsType */,
+ anyInt() /* flags */, any() /* executor */, any() /* cancellationSignal */,
+ any() /*callback*/)
+ }
+
+ private fun mockAnswer(
+ it: InvocationOnMock,
+ posHos: Int,
+ posType: Int,
+ posExecutor: Int,
+ posCallback: Int
+ ) {
+ val hostname = it.arguments[posHos] as String
+ val executor = it.arguments[posExecutor] as Executor
+ val callback = it.arguments[posCallback] as DnsResolver.Callback<List<InetAddress>>
+ var type = if (posType != -1) it.arguments[posType] as Int else TYPE_UNSPECIFIED
+ val answer = getAnswer(hostname, type)
+
+ if (!answer?.addresses.isNullOrEmpty()) {
+ Handler(Looper.getMainLooper()).post({ executor.execute({
+ callback.onAnswer(answer?.addresses, 0); }) })
+ }
+ }
+
+ /** Clears all entries. */
+ fun clearAll() = synchronized(answers) {
+ answers.clear()
+ }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/HandlerUtils.kt b/staticlibs/devicetests/com/android/testutils/HandlerUtils.kt
new file mode 100644
index 0000000..fa36209
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/HandlerUtils.kt
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.os.ConditionVariable
+import android.os.Handler
+import android.os.HandlerThread
+import java.util.concurrent.Executor
+import kotlin.system.measureTimeMillis
+import kotlin.test.fail
+
+/**
+ * Block until the specified Handler or HandlerThread becomes idle, or until timeoutMs has passed.
+ */
+fun HandlerThread.waitForIdle(timeoutMs: Int) = threadHandler.waitForIdle(timeoutMs.toLong())
+fun HandlerThread.waitForIdle(timeoutMs: Long) = threadHandler.waitForIdle(timeoutMs)
+fun Handler.waitForIdle(timeoutMs: Int) = waitForIdle(timeoutMs.toLong())
+fun Handler.waitForIdle(timeoutMs: Long) {
+ val cv = ConditionVariable(false)
+ post(cv::open)
+ if (!cv.block(timeoutMs)) {
+ fail("Handler did not become idle after ${timeoutMs}ms")
+ }
+}
+
+/**
+ * Block until the given Serial Executor becomes idle, or until timeoutMs has passed.
+ */
+fun waitForIdleSerialExecutor(executor: Executor, timeoutMs: Long) {
+ val cv = ConditionVariable()
+ executor.execute(cv::open)
+ if (!cv.block(timeoutMs)) {
+ fail("Executor did not become idle after ${timeoutMs}ms")
+ }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/NetworkStatsUtils.kt b/staticlibs/devicetests/com/android/testutils/NetworkStatsUtils.kt
new file mode 100644
index 0000000..8324b25
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/NetworkStatsUtils.kt
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.NetworkStats
+import kotlin.test.assertTrue
+
+@JvmOverloads
+fun orderInsensitiveEquals(
+ leftStats: NetworkStats,
+ rightStats: NetworkStats,
+ compareTime: Boolean = false
+): Boolean {
+ if (leftStats == rightStats) return true
+ if (compareTime && leftStats.getElapsedRealtime() != rightStats.getElapsedRealtime()) {
+ return false
+ }
+
+ // While operations such as add/subtract will preserve empty entries. This will make
+ // the result be hard to verify during test. Remove them before comparing since they
+ // are not really affect correctness.
+ // TODO (b/152827872): Remove empty entries after addition/subtraction.
+ val leftTrimmedEmpty = leftStats.removeEmptyEntries()
+ val rightTrimmedEmpty = rightStats.removeEmptyEntries()
+
+ if (leftTrimmedEmpty.size() != rightTrimmedEmpty.size()) return false
+ val left = NetworkStats.Entry()
+ val right = NetworkStats.Entry()
+ // Order insensitive compare.
+ for (i in 0 until leftTrimmedEmpty.size()) {
+ leftTrimmedEmpty.getValues(i, left)
+ val j: Int = rightTrimmedEmpty.findIndexHinted(left.iface, left.uid, left.set, left.tag,
+ left.metered, left.roaming, left.defaultNetwork, i)
+ if (j == -1) return false
+ rightTrimmedEmpty.getValues(j, right)
+ if (left != right) return false
+ }
+ return true
+}
+
+/**
+ * Assert that two {@link NetworkStats} are equals, assuming the order of the records are not
+ * necessarily the same.
+ *
+ * @note {@code elapsedRealtime} is not compared by default, given that in test cases that is not
+ * usually used.
+ */
+@JvmOverloads
+fun assertNetworkStatsEquals(
+ expected: NetworkStats,
+ actual: NetworkStats,
+ compareTime: Boolean = false
+) {
+ assertTrue(orderInsensitiveEquals(expected, actual, compareTime),
+ "expected: " + expected + " but was: " + actual)
+}
+
+/**
+ * Assert that after being parceled then unparceled, {@link NetworkStats} is equal to the original
+ * object.
+ */
+fun assertParcelingIsLossless(stats: NetworkStats) {
+ assertParcelingIsLossless(stats, { a, b -> orderInsensitiveEquals(a, b) })
+}
diff --git a/staticlibs/devicetests/com/android/testutils/ParcelUtils.kt b/staticlibs/devicetests/com/android/testutils/ParcelUtils.kt
new file mode 100644
index 0000000..5784f7c
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/ParcelUtils.kt
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.os.Parcel
+import android.os.Parcelable
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+/**
+ * Return a new instance of `T` after being parceled then unparceled.
+ */
+fun <T : Parcelable> parcelingRoundTrip(source: T): T {
+ val creator: Parcelable.Creator<T>
+ try {
+ creator = source.javaClass.getField("CREATOR").get(null) as Parcelable.Creator<T>
+ } catch (e: IllegalAccessException) {
+ fail("Missing CREATOR field: " + e.message)
+ } catch (e: NoSuchFieldException) {
+ fail("Missing CREATOR field: " + e.message)
+ }
+
+ var p = Parcel.obtain()
+ source.writeToParcel(p, /* flags */ 0)
+ p.setDataPosition(0)
+ val marshalled = p.marshall()
+ p = Parcel.obtain()
+ p.unmarshall(marshalled, 0, marshalled.size)
+ p.setDataPosition(0)
+ return creator.createFromParcel(p)
+}
+
+/**
+ * Assert that after being parceled then unparceled, `source` is equal to the original
+ * object. If a customized equals function is provided, uses the provided one.
+ */
+@JvmOverloads
+fun <T : Parcelable> assertParcelingIsLossless(
+ source: T,
+ equals: (T, T) -> Boolean = { a, b -> a == b }
+) {
+ val actual = parcelingRoundTrip(source)
+ assertTrue(equals(source, actual), "Expected $source, but was $actual")
+}
+
+@JvmOverloads
+fun <T : Parcelable> assertParcelSane(
+ obj: T,
+ fieldCount: Int,
+ equals: (T, T) -> Boolean = { a, b -> a == b }
+) {
+ assertFieldCountEquals(fieldCount, obj::class.java)
+ assertParcelingIsLossless(obj, equals)
+}
diff --git a/staticlibs/devicetests/com/android/testutils/TapPacketReader.java b/staticlibs/devicetests/com/android/testutils/TapPacketReader.java
new file mode 100644
index 0000000..fea1cd9
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TapPacketReader.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils;
+
+import android.net.util.PacketReader;
+import android.os.Handler;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+
+import com.android.net.module.util.ArrayTrackRecord;
+
+import java.io.FileDescriptor;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.function.Predicate;
+
+import kotlin.Lazy;
+import kotlin.LazyKt;
+
+public class TapPacketReader extends PacketReader {
+ private final FileDescriptor mTapFd;
+ private final ArrayTrackRecord<byte[]> mReceivedPackets = new ArrayTrackRecord<>();
+ private final Lazy<ArrayTrackRecord<byte[]>.ReadHead> mReadHead =
+ LazyKt.lazy(mReceivedPackets::newReadHead);
+
+ public TapPacketReader(Handler h, FileDescriptor tapFd, int maxPacketSize) {
+ super(h, maxPacketSize);
+ mTapFd = tapFd;
+ }
+
+ @Override
+ protected FileDescriptor createFd() {
+ return mTapFd;
+ }
+
+ @Override
+ protected void handlePacket(byte[] recvbuf, int length) {
+ final byte[] newPacket = Arrays.copyOf(recvbuf, length);
+ if (!mReceivedPackets.add(newPacket)) {
+ throw new AssertionError("More than " + Integer.MAX_VALUE + " packets outstanding!");
+ }
+ }
+
+ /**
+ * Get the next packet that was received on the interface.
+ */
+ @Nullable
+ public byte[] popPacket(long timeoutMs) {
+ return mReadHead.getValue().poll(timeoutMs, packet -> true);
+ }
+
+ /**
+ * Get the next packet that was received on the interface and matches the specified filter.
+ */
+ @Nullable
+ public byte[] popPacket(long timeoutMs, @NonNull Predicate<byte[]> filter) {
+ return mReadHead.getValue().poll(timeoutMs, filter::test);
+ }
+
+ /**
+ * Get the {@link ArrayTrackRecord} that records all packets received by the reader since its
+ * creation.
+ */
+ public ArrayTrackRecord<byte[]> getReceivedPackets() {
+ return mReceivedPackets;
+ }
+
+ public void sendResponse(final ByteBuffer packet) throws IOException {
+ try (FileOutputStream out = new FileOutputStream(mTapFd)) {
+ byte[] packetBytes = new byte[packet.limit()];
+ packet.get(packetBytes);
+ packet.flip(); // So we can reuse it in the future.
+ out.write(packetBytes);
+ }
+ }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/TestNetworkTracker.kt b/staticlibs/devicetests/com/android/testutils/TestNetworkTracker.kt
new file mode 100644
index 0000000..4bd9ae8
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TestNetworkTracker.kt
@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.content.Context
+import android.net.ConnectivityManager
+import android.net.ConnectivityManager.NetworkCallback
+import android.net.LinkAddress
+import android.net.Network
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import android.net.StringNetworkSpecifier
+import android.net.TestNetworkInterface
+import android.net.TestNetworkManager
+import android.os.Binder
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.TimeUnit
+
+/**
+ * Create a test network based on a TUN interface.
+ *
+ * This method will block until the test network is available. Requires
+ * [android.Manifest.permission.CHANGE_NETWORK_STATE] and
+ * [android.Manifest.permission.MANAGE_TEST_NETWORKS].
+ */
+fun initTestNetwork(context: Context, interfaceAddr: LinkAddress, setupTimeoutMs: Long = 10_000L):
+ TestNetworkTracker {
+ val tnm = context.getSystemService(TestNetworkManager::class.java)
+ val iface = tnm.createTunInterface(arrayOf(interfaceAddr))
+ return TestNetworkTracker(context, iface, tnm, setupTimeoutMs)
+}
+
+/**
+ * Utility class to create and track test networks.
+ *
+ * This class is not thread-safe.
+ */
+class TestNetworkTracker internal constructor(
+ val context: Context,
+ val iface: TestNetworkInterface,
+ tnm: TestNetworkManager,
+ setupTimeoutMs: Long
+) {
+ private val cm = context.getSystemService(ConnectivityManager::class.java)
+ private val binder = Binder()
+
+ private val networkCallback: NetworkCallback
+ val network: Network
+
+ init {
+ val networkFuture = CompletableFuture<Network>()
+ val networkRequest = NetworkRequest.Builder()
+ .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+ // Test networks do not have NOT_VPN or TRUSTED capabilities by default
+ .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
+ .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
+ .setNetworkSpecifier(StringNetworkSpecifier(iface.interfaceName))
+ .build()
+ networkCallback = object : NetworkCallback() {
+ override fun onAvailable(network: Network) {
+ networkFuture.complete(network)
+ }
+ }
+ cm.requestNetwork(networkRequest, networkCallback)
+
+ try {
+ tnm.setupTestNetwork(iface.interfaceName, binder)
+ network = networkFuture.get(setupTimeoutMs, TimeUnit.MILLISECONDS)
+ } catch (e: Throwable) {
+ teardown()
+ throw e
+ }
+ }
+
+ fun teardown() {
+ cm.unregisterNetworkCallback(networkCallback)
+ }
+}
\ No newline at end of file
diff --git a/staticlibs/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/devicetests/com/android/testutils/TestableNetworkCallback.kt
new file mode 100644
index 0000000..959a837
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -0,0 +1,393 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.ConnectivityManager.NetworkCallback
+import android.net.LinkProperties
+import android.net.Network
+import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
+import com.android.net.module.util.ArrayTrackRecord
+import com.android.testutils.RecorderCallback.CallbackEntry.Available
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.Losing
+import com.android.testutils.RecorderCallback.CallbackEntry.Lost
+import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
+import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
+import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
+import kotlin.reflect.KClass
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+object NULL_NETWORK : Network(-1)
+object ANY_NETWORK : Network(-2)
+
+private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this)
+
+open class RecorderCallback private constructor(
+ private val backingRecord: ArrayTrackRecord<CallbackEntry>
+) : NetworkCallback() {
+ public constructor() : this(ArrayTrackRecord())
+ protected constructor(src: RecorderCallback?): this(src?.backingRecord ?: ArrayTrackRecord())
+
+ sealed class CallbackEntry {
+ // To get equals(), hashcode(), componentN() etc for free, the child classes of
+ // this class are data classes. But while data classes can inherit from other classes,
+ // they may only have visible members in the constructors, so they couldn't declare
+ // a constructor with a non-val arg to pass to CallbackEntry. Instead, force all
+ // subclasses to implement a `network' property, which can be done in a data class
+ // constructor by specifying override.
+ abstract val network: Network
+
+ data class Available(override val network: Network) : CallbackEntry()
+ data class CapabilitiesChanged(
+ override val network: Network,
+ val caps: NetworkCapabilities
+ ) : CallbackEntry()
+ data class LinkPropertiesChanged(
+ override val network: Network,
+ val lp: LinkProperties
+ ) : CallbackEntry()
+ data class Suspended(override val network: Network) : CallbackEntry()
+ data class Resumed(override val network: Network) : CallbackEntry()
+ data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()
+ data class Lost(override val network: Network) : CallbackEntry()
+ data class Unavailable private constructor(
+ override val network: Network
+ ) : CallbackEntry() {
+ constructor() : this(NULL_NETWORK)
+ }
+ data class BlockedStatus(
+ override val network: Network,
+ val blocked: Boolean
+ ) : CallbackEntry()
+
+ // Convenience constants for expecting a type
+ companion object {
+ @JvmField
+ val AVAILABLE = Available::class
+ @JvmField
+ val NETWORK_CAPS_UPDATED = CapabilitiesChanged::class
+ @JvmField
+ val LINK_PROPERTIES_CHANGED = LinkPropertiesChanged::class
+ @JvmField
+ val SUSPENDED = Suspended::class
+ @JvmField
+ val RESUMED = Resumed::class
+ @JvmField
+ val LOSING = Losing::class
+ @JvmField
+ val LOST = Lost::class
+ @JvmField
+ val UNAVAILABLE = Unavailable::class
+ @JvmField
+ val BLOCKED_STATUS = BlockedStatus::class
+ }
+ }
+
+ val history = backingRecord.newReadHead()
+ val mark get() = history.mark
+
+ override fun onAvailable(network: Network) {
+ history.add(Available(network))
+ }
+
+ // PreCheck is not used in the tests today. For backward compatibility with existing tests that
+ // expect the callbacks not to record this, do not listen to PreCheck here.
+
+ override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) {
+ history.add(CapabilitiesChanged(network, caps))
+ }
+
+ override fun onLinkPropertiesChanged(network: Network, lp: LinkProperties) {
+ history.add(LinkPropertiesChanged(network, lp))
+ }
+
+ override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
+ history.add(BlockedStatus(network, blocked))
+ }
+
+ override fun onNetworkSuspended(network: Network) {
+ history.add(Suspended(network))
+ }
+
+ override fun onNetworkResumed(network: Network) {
+ history.add(Resumed(network))
+ }
+
+ override fun onLosing(network: Network, maxMsToLive: Int) {
+ history.add(Losing(network, maxMsToLive))
+ }
+
+ override fun onLost(network: Network) {
+ history.add(Lost(network))
+ }
+
+ override fun onUnavailable() {
+ history.add(Unavailable())
+ }
+}
+
+private const val DEFAULT_TIMEOUT = 200L // ms
+
+open class TestableNetworkCallback private constructor(
+ src: TestableNetworkCallback?,
+ val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
+) : RecorderCallback(src) {
+ @JvmOverloads
+ constructor(timeoutMs: Long = DEFAULT_TIMEOUT): this(null, timeoutMs)
+
+ fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)
+
+ // The last available network, or null if any network was lost since the last call to
+ // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
+ val lastAvailableNetwork: Network?
+ get() = when (val it = history.lastOrNull { it is Available || it is Lost }) {
+ is Available -> it.network
+ else -> null
+ }
+
+ fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
+ return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
+ }
+
+ // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
+ @JvmOverloads
+ open fun assertNoCallback(timeoutMs: Long = defaultTimeoutMs) {
+ val cb = history.poll(timeoutMs)
+ if (null != cb) fail("Expected no callback but got $cb")
+ }
+
+ // Expects a callback of the specified type on the specified network within the timeout.
+ // If no callback arrives, or a different callback arrives, fail. Returns the callback.
+ inline fun <reified T : CallbackEntry> expectCallback(
+ network: Network = ANY_NETWORK,
+ timeoutMs: Long = defaultTimeoutMs
+ ): T = pollForNextCallback(timeoutMs).let {
+ if (it !is T || (ANY_NETWORK !== network && it.network != network)) {
+ fail("Unexpected callback : $it, expected ${T::class} with Network[$network]")
+ } else {
+ it
+ }
+ }
+
+ // Expects a callback of the specified type matching the predicate within the timeout.
+ // Any callback that doesn't match the predicate will be skipped. Fails only if
+ // no matching callback is received within the timeout.
+ inline fun <reified T : CallbackEntry> eventuallyExpect(
+ timeoutMs: Long = defaultTimeoutMs,
+ from: Int = mark,
+ crossinline predicate: (T) -> Boolean = { true }
+ ): T = eventuallyExpectOrNull(timeoutMs, from, predicate).also {
+ assertNotNull(it, "Callback ${T::class} not received within ${timeoutMs}ms")
+ } as T
+
+ // TODO (b/157405399) straighten and unify the method names
+ inline fun <reified T : CallbackEntry> eventuallyExpectOrNull(
+ timeoutMs: Long = defaultTimeoutMs,
+ from: Int = mark,
+ crossinline predicate: (T) -> Boolean = { true }
+ ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?
+
+ fun expectCallbackThat(
+ timeoutMs: Long = defaultTimeoutMs,
+ valid: (CallbackEntry) -> Boolean
+ ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }
+
+ fun expectCapabilitiesThat(
+ net: Network,
+ tmt: Long = defaultTimeoutMs,
+ valid: (NetworkCapabilities) -> Boolean
+ ): CapabilitiesChanged {
+ return expectCallback<CapabilitiesChanged>(net, tmt).also {
+ assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}")
+ }
+ }
+
+ fun expectLinkPropertiesThat(
+ net: Network,
+ tmt: Long = defaultTimeoutMs,
+ valid: (LinkProperties) -> Boolean
+ ): LinkPropertiesChanged {
+ return expectCallback<LinkPropertiesChanged>(net, tmt).also {
+ assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}")
+ }
+ }
+
+ // Expects onAvailable and the callbacks that follow it. These are:
+ // - onSuspended, iff the network was suspended when the callbacks fire.
+ // - onCapabilitiesChanged.
+ // - onLinkPropertiesChanged.
+ // - onBlockedStatusChanged.
+ //
+ // @param network the network to expect the callbacks on.
+ // @param suspended whether to expect a SUSPENDED callback.
+ // @param validated the expected value of the VALIDATED capability in the
+ // onCapabilitiesChanged callback.
+ // @param tmt how long to wait for the callbacks.
+ fun expectAvailableCallbacks(
+ net: Network,
+ suspended: Boolean = false,
+ validated: Boolean = true,
+ blocked: Boolean = false,
+ tmt: Long = defaultTimeoutMs
+ ) {
+ expectCallback<Available>(net, tmt)
+ if (suspended) {
+ expectCallback<Suspended>(net, tmt)
+ }
+ expectCapabilitiesThat(net, tmt) { validated == it.hasCapability(NET_CAPABILITY_VALIDATED) }
+ expectCallback<LinkPropertiesChanged>(net, tmt)
+ expectBlockedStatusCallback(blocked, net)
+ }
+
+ // Backward compatibility for existing Java code. Use named arguments instead and remove all
+ // these when there is no user left.
+ fun expectAvailableAndSuspendedCallbacks(
+ net: Network,
+ validated: Boolean,
+ tmt: Long = defaultTimeoutMs
+ ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)
+
+ fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) {
+ expectCallback<BlockedStatus>(net, tmt).also {
+ assertEquals(it.blocked, blocked, "Unexpected blocked status ${it.blocked}")
+ }
+ }
+
+ // Expects the available callbacks (where the onCapabilitiesChanged must contain the
+ // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
+ // one we just sent.
+ // TODO: this is likely a bug. Fix it and remove this method.
+ fun expectAvailableDoubleValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
+ val mark = history.mark
+ expectAvailableCallbacks(net, tmt = tmt)
+ val firstCaps = history.poll(tmt, mark) { it is CapabilitiesChanged }
+ assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt))
+ }
+
+ // Expects the available callbacks where the onCapabilitiesChanged must not have validated,
+ // then expects another onCapabilitiesChanged that has the validated bit set. This is used
+ // when a network connects and satisfies a callback, and then immediately validates.
+ fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
+ expectAvailableCallbacks(net, validated = false, tmt = tmt)
+ expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
+ }
+
+ // Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
+ // calls with networkAgent can be routed through here without moving MockNetworkAgent.
+ // TODO: clean this up, remove this method.
+ interface HasNetwork {
+ val network: Network
+ }
+
+ @JvmOverloads
+ open fun <T : CallbackEntry> expectCallback(
+ type: KClass<T>,
+ n: Network?,
+ timeoutMs: Long = defaultTimeoutMs
+ ) = pollForNextCallback(timeoutMs).also {
+ val network = n ?: NULL_NETWORK
+ // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of
+ // this writing this would be the only use of this library in the tests.
+ assertTrue(type.java.isInstance(it) && it.network == network,
+ "Unexpected callback : $it, expected ${type.java} with Network[$network]")
+ } as T
+
+ @JvmOverloads
+ open fun <T : CallbackEntry> expectCallback(
+ type: KClass<T>,
+ n: HasNetwork?,
+ timeoutMs: Long = defaultTimeoutMs
+ ) = expectCallback(type, n?.network, timeoutMs)
+
+ fun expectAvailableCallbacks(
+ n: HasNetwork,
+ suspended: Boolean,
+ validated: Boolean,
+ blocked: Boolean,
+ timeoutMs: Long
+ ) = expectAvailableCallbacks(n.network, suspended, validated, blocked, timeoutMs)
+
+ fun expectAvailableAndSuspendedCallbacks(n: HasNetwork, expectValidated: Boolean) {
+ expectAvailableAndSuspendedCallbacks(n.network, expectValidated)
+ }
+
+ fun expectAvailableCallbacksValidated(n: HasNetwork) {
+ expectAvailableCallbacks(n.network)
+ }
+
+ fun expectAvailableCallbacksValidatedAndBlocked(n: HasNetwork) {
+ expectAvailableCallbacks(n.network, blocked = true)
+ }
+
+ fun expectAvailableCallbacksUnvalidated(n: HasNetwork) {
+ expectAvailableCallbacks(n.network, validated = false)
+ }
+
+ fun expectAvailableCallbacksUnvalidatedAndBlocked(n: HasNetwork) {
+ expectAvailableCallbacks(n.network, validated = false, blocked = true)
+ }
+
+ fun expectAvailableDoubleValidatedCallbacks(n: HasNetwork) {
+ expectAvailableDoubleValidatedCallbacks(n.network, defaultTimeoutMs)
+ }
+
+ fun expectAvailableThenValidatedCallbacks(n: HasNetwork) {
+ expectAvailableThenValidatedCallbacks(n.network, defaultTimeoutMs)
+ }
+
+ @JvmOverloads
+ fun expectLinkPropertiesThat(
+ n: HasNetwork,
+ tmt: Long = defaultTimeoutMs,
+ valid: (LinkProperties) -> Boolean
+ ) = expectLinkPropertiesThat(n.network, tmt, valid)
+
+ @JvmOverloads
+ fun expectCapabilitiesThat(
+ n: HasNetwork,
+ tmt: Long = defaultTimeoutMs,
+ valid: (NetworkCapabilities) -> Boolean
+ ) = expectCapabilitiesThat(n.network, tmt, valid)
+
+ @JvmOverloads
+ fun expectCapabilitiesWith(
+ capability: Int,
+ n: HasNetwork,
+ timeoutMs: Long = defaultTimeoutMs
+ ): NetworkCapabilities {
+ return expectCapabilitiesThat(n.network, timeoutMs) { it.hasCapability(capability) }.caps
+ }
+
+ @JvmOverloads
+ fun expectCapabilitiesWithout(
+ capability: Int,
+ n: HasNetwork,
+ timeoutMs: Long = defaultTimeoutMs
+ ): NetworkCapabilities {
+ return expectCapabilitiesThat(n.network, timeoutMs) { !it.hasCapability(capability) }.caps
+ }
+
+ fun expectBlockedStatusCallback(expectBlocked: Boolean, n: HasNetwork) {
+ expectBlockedStatusCallback(expectBlocked, n.network, defaultTimeoutMs)
+ }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProvider.kt b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProvider.kt
new file mode 100644
index 0000000..d5c3a2a
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProvider.kt
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.netstats.provider.NetworkStatsProvider
+import com.android.net.module.util.ArrayTrackRecord
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+private const val DEFAULT_TIMEOUT_MS = 200L
+
+open class TestableNetworkStatsProvider(
+ val defaultTimeoutMs: Long = DEFAULT_TIMEOUT_MS
+) : NetworkStatsProvider() {
+ sealed class CallbackType {
+ data class OnRequestStatsUpdate(val token: Int) : CallbackType()
+ data class OnSetLimit(val iface: String?, val quotaBytes: Long) : CallbackType()
+ data class OnSetAlert(val quotaBytes: Long) : CallbackType()
+ }
+
+ val history = ArrayTrackRecord<CallbackType>().newReadHead()
+ // See ReadHead#mark
+ val mark get() = history.mark
+
+ override fun onRequestStatsUpdate(token: Int) {
+ history.add(CallbackType.OnRequestStatsUpdate(token))
+ }
+
+ override fun onSetLimit(iface: String, quotaBytes: Long) {
+ history.add(CallbackType.OnSetLimit(iface, quotaBytes))
+ }
+
+ override fun onSetAlert(quotaBytes: Long) {
+ history.add(CallbackType.OnSetAlert(quotaBytes))
+ }
+
+ fun expectOnRequestStatsUpdate(token: Int, timeout: Long = defaultTimeoutMs) {
+ assertEquals(CallbackType.OnRequestStatsUpdate(token), history.poll(timeout))
+ }
+
+ fun expectOnSetLimit(iface: String?, quotaBytes: Long, timeout: Long = defaultTimeoutMs) {
+ assertEquals(CallbackType.OnSetLimit(iface, quotaBytes), history.poll(timeout))
+ }
+
+ fun expectOnSetAlert(quotaBytes: Long, timeout: Long = defaultTimeoutMs) {
+ assertEquals(CallbackType.OnSetAlert(quotaBytes), history.poll(timeout))
+ }
+
+ fun pollForNextCallback(timeout: Long = defaultTimeoutMs) =
+ history.poll(timeout) ?: fail("Did not receive callback after ${timeout}ms")
+
+ inline fun <reified T : CallbackType> expectCallback(
+ timeout: Long = defaultTimeoutMs,
+ predicate: (T) -> Boolean = { true }
+ ): T {
+ return pollForNextCallback(timeout).also { assertTrue(it is T && predicate(it)) } as T
+ }
+
+ // Expects a callback of the specified type matching the predicate within the timeout.
+ // Any callback that doesn't match the predicate will be skipped. Fails only if
+ // no matching callback is received within the timeout.
+ // TODO : factorize the code for this with the identical call in TestableNetworkCallback.
+ // There should be a common superclass doing this generically.
+ // TODO : have a better error message to have this fail. Right now the failure when no
+ // matching callback arrives comes from the casting to a non-nullable T.
+ // TODO : in fact, completely removing this method and have clients use
+ // history.poll(timeout, index, predicate) directly might be simpler.
+ inline fun <reified T : CallbackType> eventuallyExpect(
+ timeoutMs: Long = defaultTimeoutMs,
+ from: Int = mark,
+ crossinline predicate: (T) -> Boolean = { true }
+ ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T
+
+ fun drainCallbacks() {
+ history.mark = history.size
+ }
+
+ @JvmOverloads
+ fun assertNoCallback(timeout: Long = defaultTimeoutMs) {
+ val cb = history.poll(timeout)
+ cb?.let { fail("Expected no callback but got $cb") }
+ }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderBinder.kt b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderBinder.kt
new file mode 100644
index 0000000..02922d8
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderBinder.kt
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.netstats.provider.INetworkStatsProvider
+import com.android.net.module.util.ArrayTrackRecord
+import kotlin.test.assertEquals
+import kotlin.test.fail
+
+private const val DEFAULT_TIMEOUT_MS = 200L
+
+open class TestableNetworkStatsProviderBinder : INetworkStatsProvider.Stub() {
+ sealed class CallbackType {
+ data class OnRequestStatsUpdate(val token: Int) : CallbackType()
+ data class OnSetLimit(val iface: String?, val quotaBytes: Long) : CallbackType()
+ data class OnSetAlert(val quotaBytes: Long) : CallbackType()
+ }
+
+ private val history = ArrayTrackRecord<CallbackType>().ReadHead()
+
+ override fun onRequestStatsUpdate(token: Int) {
+ history.add(CallbackType.OnRequestStatsUpdate(token))
+ }
+
+ override fun onSetLimit(iface: String?, quotaBytes: Long) {
+ history.add(CallbackType.OnSetLimit(iface, quotaBytes))
+ }
+
+ override fun onSetAlert(quotaBytes: Long) {
+ history.add(CallbackType.OnSetAlert(quotaBytes))
+ }
+
+ fun expectOnRequestStatsUpdate(token: Int) {
+ assertEquals(CallbackType.OnRequestStatsUpdate(token), history.poll(DEFAULT_TIMEOUT_MS))
+ }
+
+ fun expectOnSetLimit(iface: String?, quotaBytes: Long) {
+ assertEquals(CallbackType.OnSetLimit(iface, quotaBytes), history.poll(DEFAULT_TIMEOUT_MS))
+ }
+
+ fun expectOnSetAlert(quotaBytes: Long) {
+ assertEquals(CallbackType.OnSetAlert(quotaBytes), history.poll(DEFAULT_TIMEOUT_MS))
+ }
+
+ @JvmOverloads
+ fun assertNoCallback(timeout: Long = DEFAULT_TIMEOUT_MS) {
+ val cb = history.poll(timeout)
+ cb?.let { fail("Expected no callback but got $cb") }
+ }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderCbBinder.kt b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderCbBinder.kt
new file mode 100644
index 0000000..163473a
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderCbBinder.kt
@@ -0,0 +1,84 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.NetworkStats
+import android.net.netstats.provider.INetworkStatsProviderCallback
+import com.android.net.module.util.ArrayTrackRecord
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+private const val DEFAULT_TIMEOUT_MS = 3000L
+
+open class TestableNetworkStatsProviderCbBinder : INetworkStatsProviderCallback.Stub() {
+ sealed class CallbackType {
+ data class NotifyStatsUpdated(
+ val token: Int,
+ val ifaceStats: NetworkStats,
+ val uidStats: NetworkStats
+ ) : CallbackType()
+ object NotifyLimitReached : CallbackType()
+ object NotifyAlertReached : CallbackType()
+ object Unregister : CallbackType()
+ }
+
+ private val history = ArrayTrackRecord<CallbackType>().ReadHead()
+
+ override fun notifyStatsUpdated(token: Int, ifaceStats: NetworkStats, uidStats: NetworkStats) {
+ history.add(CallbackType.NotifyStatsUpdated(token, ifaceStats, uidStats))
+ }
+
+ override fun notifyLimitReached() {
+ history.add(CallbackType.NotifyLimitReached)
+ }
+
+ override fun notifyAlertReached() {
+ history.add(CallbackType.NotifyAlertReached)
+ }
+
+ override fun unregister() {
+ history.add(CallbackType.Unregister)
+ }
+
+ fun expectNotifyStatsUpdated() {
+ val event = history.poll(DEFAULT_TIMEOUT_MS)
+ assertTrue(event is CallbackType.NotifyStatsUpdated)
+ }
+
+ fun expectNotifyStatsUpdated(ifaceStats: NetworkStats, uidStats: NetworkStats) {
+ val event = history.poll(DEFAULT_TIMEOUT_MS)!!
+ if (event !is CallbackType.NotifyStatsUpdated) {
+ throw Exception("Expected NotifyStatsUpdated callback, but got ${event::class}")
+ }
+ // TODO: verify token.
+ assertNetworkStatsEquals(ifaceStats, event.ifaceStats)
+ assertNetworkStatsEquals(uidStats, event.uidStats)
+ }
+
+ fun expectNotifyLimitReached() =
+ assertEquals(CallbackType.NotifyLimitReached, history.poll(DEFAULT_TIMEOUT_MS))
+
+ fun expectNotifyAlertReached() =
+ assertEquals(CallbackType.NotifyAlertReached, history.poll(DEFAULT_TIMEOUT_MS))
+
+ // Assert there is no callback in current queue.
+ fun assertNoCallback() {
+ val cb = history.poll(0)
+ cb?.let { fail("Expected no callback but got $cb") }
+ }
+}
diff --git a/staticlibs/src_frameworkcommon/android/net/util/IpRange.java b/staticlibs/framework/android/net/util/IpRange.java
similarity index 100%
rename from staticlibs/src_frameworkcommon/android/net/util/IpRange.java
rename to staticlibs/framework/android/net/util/IpRange.java
diff --git a/staticlibs/src_frameworkcommon/android/net/util/LinkPropertiesUtils.java b/staticlibs/framework/android/net/util/LinkPropertiesUtils.java
similarity index 100%
rename from staticlibs/src_frameworkcommon/android/net/util/LinkPropertiesUtils.java
rename to staticlibs/framework/android/net/util/LinkPropertiesUtils.java
diff --git a/staticlibs/src_frameworkcommon/android/net/util/MacAddressUtils.java b/staticlibs/framework/android/net/util/MacAddressUtils.java
similarity index 100%
rename from staticlibs/src_frameworkcommon/android/net/util/MacAddressUtils.java
rename to staticlibs/framework/android/net/util/MacAddressUtils.java
diff --git a/staticlibs/src_frameworkcommon/android/net/util/NetUtils.java b/staticlibs/framework/android/net/util/NetUtils.java
similarity index 100%
rename from staticlibs/src_frameworkcommon/android/net/util/NetUtils.java
rename to staticlibs/framework/android/net/util/NetUtils.java
diff --git a/staticlibs/src_frameworkcommon/android/net/util/nsd/DnsSdTxtRecord.java b/staticlibs/framework/android/net/util/nsd/DnsSdTxtRecord.java
similarity index 100%
rename from staticlibs/src_frameworkcommon/android/net/util/nsd/DnsSdTxtRecord.java
rename to staticlibs/framework/android/net/util/nsd/DnsSdTxtRecord.java
diff --git a/staticlibs/src_frameworkcommon/com/android/net/module/util/DnsPacket.java b/staticlibs/framework/com/android/net/module/util/DnsPacket.java
similarity index 100%
rename from staticlibs/src_frameworkcommon/com/android/net/module/util/DnsPacket.java
rename to staticlibs/framework/com/android/net/module/util/DnsPacket.java
diff --git a/staticlibs/src_frameworkcommon/com/android/net/module/util/Inet4AddressUtils.java b/staticlibs/framework/com/android/net/module/util/Inet4AddressUtils.java
similarity index 100%
rename from staticlibs/src_frameworkcommon/com/android/net/module/util/Inet4AddressUtils.java
rename to staticlibs/framework/com/android/net/module/util/Inet4AddressUtils.java
diff --git a/staticlibs/src_frameworkcommon/com/android/net/module/util/InetAddressUtils.java b/staticlibs/framework/com/android/net/module/util/InetAddressUtils.java
similarity index 100%
rename from staticlibs/src_frameworkcommon/com/android/net/module/util/InetAddressUtils.java
rename to staticlibs/framework/com/android/net/module/util/InetAddressUtils.java
diff --git a/staticlibs/hostdevice/com/android/net/module/util/TrackRecord.kt b/staticlibs/hostdevice/com/android/net/module/util/TrackRecord.kt
new file mode 100644
index 0000000..b647d99
--- /dev/null
+++ b/staticlibs/hostdevice/com/android/net/module/util/TrackRecord.kt
@@ -0,0 +1,247 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util
+
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.locks.Condition
+import java.util.concurrent.locks.ReentrantLock
+import kotlin.concurrent.withLock
+
+/**
+ * A List that additionally offers the ability to append via the add() method, and to retrieve
+ * an element by its index optionally waiting for it to become available.
+ */
+interface TrackRecord<E> : List<E> {
+ /**
+ * Adds an element to this queue, waking up threads waiting for one. Returns true, as
+ * per the contract for List.
+ */
+ fun add(e: E): Boolean
+
+ /**
+ * Returns the first element after {@param pos}, possibly blocking until one is available, or
+ * null if no such element can be found within the timeout.
+ * If a predicate is given, only elements matching the predicate are returned.
+ *
+ * @param timeoutMs how long, in milliseconds, to wait at most (best effort approximation).
+ * @param pos the position at which to start polling.
+ * @param predicate an optional predicate to filter elements to be returned.
+ * @return an element matching the predicate, or null if timeout.
+ */
+ fun poll(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean = { true }): E?
+}
+
+/**
+ * A thread-safe implementation of TrackRecord that is backed by an ArrayList.
+ *
+ * This class also supports the creation of a read-head for easier single-thread access.
+ * Refer to the documentation of {@link ArrayTrackRecord.ReadHead}.
+ */
+class ArrayTrackRecord<E> : TrackRecord<E> {
+ private val lock = ReentrantLock()
+ private val condition = lock.newCondition()
+ // Backing store. This stores the elements in this ArrayTrackRecord.
+ private val elements = ArrayList<E>()
+
+ // The list iterator for RecordingQueue iterates over a snapshot of the collection at the
+ // time the operator is created. Because TrackRecord is only ever mutated by appending,
+ // that makes this iterator thread-safe as it sees an effectively immutable List.
+ class ArrayTrackRecordIterator<E>(
+ private val list: ArrayList<E>,
+ start: Int,
+ private val end: Int
+ ) : ListIterator<E> {
+ var index = start
+ override fun hasNext() = index < end
+ override fun next() = list[index++]
+ override fun hasPrevious() = index > 0
+ override fun nextIndex() = index + 1
+ override fun previous() = list[--index]
+ override fun previousIndex() = index - 1
+ }
+
+ // List<E> implementation
+ override val size get() = lock.withLock { elements.size }
+ override fun contains(element: E) = lock.withLock { elements.contains(element) }
+ override fun containsAll(elements: Collection<E>) = lock.withLock {
+ this.elements.containsAll(elements)
+ }
+ override operator fun get(index: Int) = lock.withLock { elements[index] }
+ override fun indexOf(element: E): Int = lock.withLock { elements.indexOf(element) }
+ override fun lastIndexOf(element: E): Int = lock.withLock { elements.lastIndexOf(element) }
+ override fun isEmpty() = lock.withLock { elements.isEmpty() }
+ override fun listIterator(index: Int) = ArrayTrackRecordIterator(elements, index, size)
+ override fun listIterator() = listIterator(0)
+ override fun iterator() = listIterator()
+ override fun subList(fromIndex: Int, toIndex: Int): List<E> = lock.withLock {
+ elements.subList(fromIndex, toIndex)
+ }
+
+ // TrackRecord<E> implementation
+ override fun add(e: E): Boolean {
+ lock.withLock {
+ elements.add(e)
+ condition.signalAll()
+ }
+ return true
+ }
+ override fun poll(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean) = lock.withLock {
+ elements.getOrNull(pollForIndexReadLocked(timeoutMs, pos, predicate))
+ }
+
+ // For convenience
+ fun getOrNull(pos: Int, predicate: (E) -> Boolean) = lock.withLock {
+ if (pos < 0 || pos > size) null else elements.subList(pos, size).find(predicate)
+ }
+
+ // Returns the index of the next element whose position is >= pos matching the predicate, if
+ // necessary waiting until such a time that such an element is available, with a timeout.
+ // If no such element is found within the timeout -1 is returned.
+ private fun pollForIndexReadLocked(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean): Int {
+ val deadline = System.currentTimeMillis() + timeoutMs
+ var index = pos
+ do {
+ while (index < elements.size) {
+ if (predicate(elements[index])) return index
+ ++index
+ }
+ } while (condition.await(deadline - System.currentTimeMillis()))
+ return -1
+ }
+
+ /**
+ * Returns a ReadHead over this ArrayTrackRecord. The returned ReadHead is tied to the
+ * current thread.
+ */
+ fun newReadHead() = ReadHead()
+
+ /**
+ * ReadHead is an object that helps users of ArrayTrackRecord keep track of how far
+ * it has read this far in the ArrayTrackRecord. A ReadHead is always associated with
+ * a single instance of ArrayTrackRecord. Multiple ReadHeads can be created and used
+ * on the same instance of ArrayTrackRecord concurrently, and the ArrayTrackRecord
+ * instance can also be used concurrently. ReadHead maintains the current index that is
+ * the next to be read, and calls this the "mark".
+ *
+ * A ReadHead delegates all TrackRecord methods to its associated ArrayTrackRecord, and
+ * inherits its thread-safe properties. However, the additional methods that ReadHead
+ * offers on top of TrackRecord do not share these properties and can only be used by
+ * the thread that created the ReadHead. This is because by construction it does not
+ * make sense to use a ReadHead on multiple threads concurrently (see below for details).
+ *
+ * In a ReadHead, {@link poll(Long, (E) -> Boolean)} works similarly to a LinkedBlockingQueue.
+ * It can be called repeatedly and will return the elements as they arrive.
+ *
+ * Intended usage looks something like this :
+ * val TrackRecord<MyObject> record = ArrayTrackRecord().newReadHead()
+ * Thread().start {
+ * // do stuff
+ * record.add(something)
+ * // do stuff
+ * }
+ *
+ * val obj1 = record.poll(timeout)
+ * // do something with obj1
+ * val obj2 = record.poll(timeout)
+ * // do something with obj2
+ *
+ * The point is that the caller does not have to track the mark like it would have to if
+ * it was using ArrayTrackRecord directly.
+ *
+ * Note that if multiple threads were using poll() concurrently on the same ReadHead, what
+ * happens to the mark and the return values could be well defined, but it could not
+ * be useful because there is no way to provide either a guarantee not to skip objects nor
+ * a guarantee about the mark position at the exit of poll(). This is even more true in the
+ * presence of a predicate to filter returned elements, because one thread might be
+ * filtering out the events the other is interested in.
+ * Instead, this use case is supported by creating multiple ReadHeads on the same instance
+ * of ArrayTrackRecord. Each ReadHead is then guaranteed to see all events always and
+ * guarantees are made on the value of the mark upon return. {@see poll(Long, (E) -> Boolean)}
+ * for details. Be careful to create each ReadHead on the thread it is meant to be used on.
+ *
+ * Users of a ReadHead can ask for the current position of the mark at any time. This mark
+ * can be used later to replay the history of events either on this ReadHead, on the associated
+ * ArrayTrackRecord or on another ReadHead associated with the same ArrayTrackRecord. It
+ * might look like this in the reader thread :
+ *
+ * val markAtStart = record.mark
+ * // Start processing interesting events
+ * while (val element = record.poll(timeout) { it.isInteresting() }) {
+ * // Do something with element
+ * }
+ * // Look for stuff that happened while searching for interesting events
+ * val firstElementReceived = record.getOrNull(markAtStart)
+ * val firstSpecialElement = record.getOrNull(markAtStart) { it.isSpecial() }
+ * // Get the first special element since markAtStart, possibly blocking until one is available
+ * val specialElement = record.poll(timeout, markAtStart) { it.isSpecial() }
+ */
+ inner class ReadHead : TrackRecord<E> by this@ArrayTrackRecord {
+ private val owningThread = Thread.currentThread()
+ private var readHead = 0
+
+ /**
+ * @return the current value of the mark.
+ */
+ var mark
+ get() = readHead.also { checkThread() }
+ set(v: Int) = rewind(v)
+ fun rewind(v: Int) {
+ checkThread()
+ readHead = v
+ }
+
+ private fun checkThread() = check(Thread.currentThread() == owningThread) {
+ "Must be called by the thread that created this object"
+ }
+
+ /**
+ * Returns the first element after the mark, optionally blocking until one is available, or
+ * null if no such element can be found within the timeout.
+ * If a predicate is given, only elements matching the predicate are returned.
+ *
+ * Upon return the mark will be set to immediately after the returned element, or after
+ * the last element in the queue if null is returned. This means this method will always
+ * skip elements that do not match the predicate, even if it returns null.
+ *
+ * This method can only be used by the thread that created this ManagedRecordingQueue.
+ * If used on another thread, this throws IllegalStateException.
+ *
+ * @param timeoutMs how long, in milliseconds, to wait at most (best effort approximation).
+ * @param predicate an optional predicate to filter elements to be returned.
+ * @return an element matching the predicate, or null if timeout.
+ */
+ fun poll(timeoutMs: Long, predicate: (E) -> Boolean = { true }): E? {
+ checkThread()
+ lock.withLock {
+ val index = pollForIndexReadLocked(timeoutMs, readHead, predicate)
+ readHead = if (index < 0) size else index + 1
+ return getOrNull(index)
+ }
+ }
+
+ /**
+ * Returns the first element after the mark or null. This never blocks.
+ *
+ * This method can only be used by the thread that created this ManagedRecordingQueue.
+ * If used on another thread, this throws IllegalStateException.
+ */
+ fun peek(): E? = getOrNull(readHead).also { checkThread() }
+ }
+}
+
+// Private helper
+private fun Condition.await(timeoutMs: Long) = this.await(timeoutMs, TimeUnit.MILLISECONDS)
diff --git a/staticlibs/hostdevice/com/android/testutils/ConcurrentUtils.kt b/staticlibs/hostdevice/com/android/testutils/ConcurrentUtils.kt
new file mode 100644
index 0000000..a365af5
--- /dev/null
+++ b/staticlibs/hostdevice/com/android/testutils/ConcurrentUtils.kt
@@ -0,0 +1,26 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+import kotlin.system.measureTimeMillis
+
+// For Java usage
+fun durationOf(fn: Runnable) = measureTimeMillis { fn.run() }
+
+fun CountDownLatch.await(timeoutMs: Long): Boolean = await(timeoutMs, TimeUnit.MILLISECONDS)
diff --git a/staticlibs/hostdevice/com/android/testutils/ExceptionUtils.java b/staticlibs/hostdevice/com/android/testutils/ExceptionUtils.java
new file mode 100644
index 0000000..e7dbed5
--- /dev/null
+++ b/staticlibs/hostdevice/com/android/testutils/ExceptionUtils.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils;
+
+import java.util.function.Supplier;
+
+public class ExceptionUtils {
+ /**
+ * Like a Consumer, but declared to throw an exception.
+ * @param <T>
+ */
+ @FunctionalInterface
+ public interface ThrowingConsumer<T> {
+ void accept(T t) throws Exception;
+ }
+
+ /**
+ * Like a Supplier, but declared to throw an exception.
+ * @param <T>
+ */
+ @FunctionalInterface
+ public interface ThrowingSupplier<T> {
+ T get() throws Exception;
+ }
+
+ /**
+ * Like a Runnable, but declared to throw an exception.
+ */
+ @FunctionalInterface
+ public interface ThrowingRunnable {
+ void run() throws Exception;
+ }
+
+
+ public static <T> Supplier<T> ignoreExceptions(ThrowingSupplier<T> func) {
+ return () -> {
+ try {
+ return func.get();
+ } catch (Exception e) {
+ return null;
+ }
+ };
+ }
+
+ public static Runnable ignoreExceptions(ThrowingRunnable r) {
+ return () -> {
+ try {
+ r.run();
+ } catch (Exception e) {
+ }
+ };
+ }
+}
diff --git a/staticlibs/hostdevice/com/android/testutils/FileUtils.kt b/staticlibs/hostdevice/com/android/testutils/FileUtils.kt
new file mode 100644
index 0000000..678f977
--- /dev/null
+++ b/staticlibs/hostdevice/com/android/testutils/FileUtils.kt
@@ -0,0 +1,27 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+// This function is private because the 2 is hardcoded here, and is not correct if not called
+// directly from __LINE__ or __FILE__.
+private fun callerStackTrace(): StackTraceElement = try {
+ throw RuntimeException()
+} catch (e: RuntimeException) {
+ e.stackTrace[2] // 0 is here, 1 is get() in __FILE__ or __LINE__
+}
+val __FILE__: String get() = callerStackTrace().fileName
+val __LINE__: Int get() = callerStackTrace().lineNumber
diff --git a/staticlibs/hostdevice/com/android/testutils/MiscAsserts.kt b/staticlibs/hostdevice/com/android/testutils/MiscAsserts.kt
new file mode 100644
index 0000000..32c22c2
--- /dev/null
+++ b/staticlibs/hostdevice/com/android/testutils/MiscAsserts.kt
@@ -0,0 +1,103 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import com.android.testutils.ExceptionUtils.ThrowingRunnable
+import java.lang.reflect.Modifier
+import kotlin.system.measureTimeMillis
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+
+private const val TAG = "Connectivity unit test"
+
+fun <T> assertEmpty(ts: Array<T>) = ts.size.let { len ->
+ assertEquals(0, len, "Expected empty array, but length was $len")
+}
+
+fun <T> assertLength(expected: Int, got: Array<T>) = got.size.let { len ->
+ assertEquals(expected, len, "Expected array of length $expected, but was $len for $got")
+}
+
+// Bridge method to help write this in Java. If you're writing Kotlin, consider using native
+// kotlin.test.assertFailsWith instead, as that method is reified and inlined.
+fun <T : Exception> assertThrows(expected: Class<T>, block: ThrowingRunnable): T {
+ return assertFailsWith(expected.kotlin) { block.run() }
+}
+
+fun <T : Exception> assertThrows(msg: String, expected: Class<T>, block: ThrowingRunnable): T {
+ return assertFailsWith(expected.kotlin, msg) { block.run() }
+}
+
+fun <T> assertEqualBothWays(o1: T, o2: T) {
+ assertTrue(o1 == o2)
+ assertTrue(o2 == o1)
+}
+
+fun <T> assertNotEqualEitherWay(o1: T, o2: T) {
+ assertFalse(o1 == o2)
+ assertFalse(o2 == o1)
+}
+
+fun assertStringContains(got: String, want: String) {
+ assertTrue(got.contains(want), "$got did not contain \"${want}\"")
+}
+
+fun assertContainsExactly(actual: IntArray, vararg expected: Int) {
+ // IntArray#sorted() returns a list, so it's fine to test with equals()
+ assertEquals(actual.sorted(), expected.sorted(),
+ "$actual does not contain exactly $expected")
+}
+
+fun assertContainsStringsExactly(actual: Array<String>, vararg expected: String) {
+ assertEquals(actual.sorted(), expected.sorted(),
+ "$actual does not contain exactly $expected")
+}
+
+fun <T> assertContainsAll(list: Collection<T>, vararg elems: T) {
+ assertContainsAll(list, elems.asList())
+}
+
+fun <T> assertContainsAll(list: Collection<T>, elems: Collection<T>) {
+ elems.forEach { assertTrue(list.contains(it), "$it not in list") }
+}
+
+fun assertRunsInAtMost(descr: String, timeLimit: Long, fn: Runnable) {
+ assertRunsInAtMost(descr, timeLimit) { fn.run() }
+}
+
+fun assertRunsInAtMost(descr: String, timeLimit: Long, fn: () -> Unit) {
+ val timeTaken = measureTimeMillis(fn)
+ val msg = String.format("%s: took %dms, limit was %dms", descr, timeTaken, timeLimit)
+ assertTrue(timeTaken <= timeLimit, msg)
+}
+
+/**
+ * Verifies that the number of nonstatic fields in a java class equals a given count.
+ * Note: this is essentially not useful for Kotlin code where fields are not really a thing.
+ *
+ * This assertion serves as a reminder to update test code around it if fields are added
+ * after the test is written.
+ * @param count Expected number of nonstatic fields in the class.
+ * @param clazz Class to test.
+ */
+fun <T> assertFieldCountEquals(count: Int, clazz: Class<T>) {
+ assertEquals(count, clazz.declaredFields.filter {
+ !Modifier.isStatic(it.modifiers) && !Modifier.isTransient(it.modifiers)
+ }.size)
+}
diff --git a/staticlibs/hostdevice/com/android/testutils/PacketFilter.kt b/staticlibs/hostdevice/com/android/testutils/PacketFilter.kt
new file mode 100644
index 0000000..cd8d6a5
--- /dev/null
+++ b/staticlibs/hostdevice/com/android/testutils/PacketFilter.kt
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import java.util.function.Predicate
+
+const val ETHER_TYPE_OFFSET = 12
+const val ETHER_HEADER_LENGTH = 14
+const val IPV4_PROTOCOL_OFFSET = ETHER_HEADER_LENGTH + 9
+const val IPV4_CHECKSUM_OFFSET = ETHER_HEADER_LENGTH + 10
+const val IPV4_HEADER_LENGTH = 20
+const val IPV4_UDP_OFFSET = ETHER_HEADER_LENGTH + IPV4_HEADER_LENGTH
+const val BOOTP_OFFSET = IPV4_UDP_OFFSET + 8
+const val BOOTP_TID_OFFSET = BOOTP_OFFSET + 4
+const val BOOTP_CLIENT_MAC_OFFSET = BOOTP_OFFSET + 28
+const val DHCP_OPTIONS_OFFSET = BOOTP_OFFSET + 240
+
+/**
+ * A [Predicate] that matches a [ByteArray] if it contains the specified [bytes] at the specified
+ * [offset].
+ */
+class OffsetFilter(val offset: Int, vararg val bytes: Byte) : Predicate<ByteArray> {
+ override fun test(packet: ByteArray) =
+ bytes.withIndex().all { it.value == packet[offset + it.index] }
+}
+
+/**
+ * A [Predicate] that matches ethernet-encapped packets that contain an UDP over IPv4 datagram.
+ */
+class IPv4UdpFilter : Predicate<ByteArray> {
+ private val impl = OffsetFilter(ETHER_TYPE_OFFSET, 0x08, 0x00 /* IPv4 */).and(
+ OffsetFilter(IPV4_PROTOCOL_OFFSET, 17 /* UDP */))
+ override fun test(t: ByteArray) = impl.test(t)
+}
+
+/**
+ * A [Predicate] that matches ethernet-encapped DHCP packets sent from a DHCP client.
+ */
+class DhcpClientPacketFilter : Predicate<ByteArray> {
+ private val impl = IPv4UdpFilter()
+ .and(OffsetFilter(IPV4_UDP_OFFSET /* source port */, 0x00, 0x44 /* 68 */))
+ .and(OffsetFilter(IPV4_UDP_OFFSET + 2 /* dest port */, 0x00, 0x43 /* 67 */))
+ override fun test(t: ByteArray) = impl.test(t)
+}
+
+/**
+ * A [Predicate] that matches a [ByteArray] if it contains a ethernet-encapped DHCP packet that
+ * contains the specified option with the specified [bytes] as value.
+ */
+class DhcpOptionFilter(val option: Byte, vararg val bytes: Byte) : Predicate<ByteArray> {
+ override fun test(packet: ByteArray): Boolean {
+ val option = findDhcpOption(packet, option) ?: return false
+ return option.contentEquals(bytes)
+ }
+}
+
+/**
+ * Find a DHCP option in a packet and return its value, if found.
+ */
+fun findDhcpOption(packet: ByteArray, option: Byte): ByteArray? =
+ findOptionOffset(packet, option, DHCP_OPTIONS_OFFSET)?.let {
+ val optionLen = packet[it + 1]
+ return packet.copyOfRange(it + 2 /* type, length bytes */, it + 2 + optionLen)
+ }
+
+private tailrec fun findOptionOffset(packet: ByteArray, option: Byte, searchOffset: Int): Int? {
+ if (packet.size <= searchOffset + 2 /* type, length bytes */) return null
+
+ return if (packet[searchOffset] == option) searchOffset else {
+ val optionLen = packet[searchOffset + 1]
+ findOptionOffset(packet, option, searchOffset + 2 + optionLen)
+ }
+}
diff --git a/staticlibs/hostdevice/com/android/testutils/SkipPresubmit.kt b/staticlibs/hostdevice/com/android/testutils/SkipPresubmit.kt
new file mode 100644
index 0000000..69ed048
--- /dev/null
+++ b/staticlibs/hostdevice/com/android/testutils/SkipPresubmit.kt
@@ -0,0 +1,24 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+/**
+ * Skip the test in presubmit runs for the reason specified in [reason].
+ *
+ * This annotation is typically used to document hardware or test bench limitations.
+ */
+annotation class SkipPresubmit(val reason: String)
\ No newline at end of file
diff --git a/staticlibs/jarjar-rules-shared.txt b/staticlibs/jarjar-rules-shared.txt
index b22771d..9e618db 100644
--- a/staticlibs/jarjar-rules-shared.txt
+++ b/staticlibs/jarjar-rules-shared.txt
@@ -1,3 +1,8 @@
-rule android.net.util.** com.android.net.module.util.@1
+# TODO: move the classes to the target package in java
+rule android.net.util.IpRange* com.android.net.module.util.IpRange@1
+rule android.net.util.MacAddressUtils* com.android.net.module.util.MacAddressUtils@1
+rule android.net.util.LinkPropertiesUtils* com.android.net.module.util.LinkPropertiesUtils@1
+rule android.net.util.NetUtils* com.android.net.module.util.NetUtils@1
+rule android.net.util.nsd.** com.android.net.module.util.nsd.@1
rule android.annotation.** com.android.net.module.annotation.@1
-rule com.android.internal.annotations.** com.android.net.module.annotation.@1
\ No newline at end of file
+rule com.android.internal.annotations.** com.android.net.module.annotation.@1
diff --git a/staticlibs/tests/unit/Android.bp b/staticlibs/tests/unit/Android.bp
index cb15277..f18ffcf 100644
--- a/staticlibs/tests/unit/Android.bp
+++ b/staticlibs/tests/unit/Android.bp
@@ -10,6 +10,8 @@
static_libs: [
"net-utils-framework-common",
"androidx.test.rules",
+ "net-utils-device-common",
+ "net-tests-utils-host-device-common",
],
libs: [
"android.test.runner",
diff --git a/staticlibs/tests/unit/jarjar-rules.txt b/staticlibs/tests/unit/jarjar-rules.txt
index dbb3974..fceccfb 100644
--- a/staticlibs/tests/unit/jarjar-rules.txt
+++ b/staticlibs/tests/unit/jarjar-rules.txt
@@ -1,4 +1,6 @@
-# Ensure that the tests can directly use the version of classes from the library. Otherwise, they
-# will use whatever version is currently in the bootclasspath on the device running the test.
-# These rules must match the jarjar rules used to build the library.
-rule android.net.util.** com.android.net.module.util.@1
+# TODO: move the classes to the target package in java
+rule android.net.util.IpRange* com.android.net.module.util.IpRange@1
+rule android.net.util.MacAddressUtils* com.android.net.module.util.MacAddressUtils@1
+rule android.net.util.LinkPropertiesUtils* com.android.net.module.util.LinkPropertiesUtils@1
+rule android.net.util.NetUtils* com.android.net.module.util.NetUtils@1
+rule android.net.util.nsd.** com.android.net.module.util.nsd.@1
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/PacketReaderTest.java b/staticlibs/tests/unit/src/com/android/net/module/util/PacketReaderTest.java
new file mode 100644
index 0000000..e3af97e
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/PacketReaderTest.java
@@ -0,0 +1,248 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util;
+
+import static android.net.util.PacketReader.DEFAULT_RECV_BUF_SIZE;
+import static android.system.OsConstants.AF_INET6;
+import static android.system.OsConstants.IPPROTO_UDP;
+import static android.system.OsConstants.SOCK_DGRAM;
+import static android.system.OsConstants.SOCK_NONBLOCK;
+import static android.system.OsConstants.SOL_SOCKET;
+import static android.system.OsConstants.SO_SNDTIMEO;
+
+import static com.android.testutils.MiscAssertsKt.assertThrows;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import android.net.util.PacketReader;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.system.ErrnoException;
+import android.system.Os;
+import android.system.StructTimeval;
+
+import androidx.test.filters.SmallTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.io.FileDescriptor;
+import java.net.DatagramPacket;
+import java.net.DatagramSocket;
+import java.net.Inet6Address;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.SocketException;
+import java.util.Arrays;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Tests for PacketReader.
+ *
+ * @hide
+ */
+@RunWith(AndroidJUnit4.class)
+@SmallTest
+public class PacketReaderTest {
+ static final InetAddress LOOPBACK6 = Inet6Address.getLoopbackAddress();
+ static final StructTimeval TIMEO = StructTimeval.fromMillis(500);
+
+ // TODO : reassigning the latch voids its synchronization properties, which means this
+ // scheme just doesn't work. Latches almost always need to be final to work.
+ protected CountDownLatch mLatch;
+ protected FileDescriptor mLocalSocket;
+ protected InetSocketAddress mLocalSockName;
+ protected byte[] mLastRecvBuf;
+ protected boolean mStopped;
+ protected HandlerThread mHandlerThread;
+ protected PacketReader mReceiver;
+
+ class UdpLoopbackReader extends PacketReader {
+ UdpLoopbackReader(Handler h) {
+ super(h);
+ }
+
+ @Override
+ protected FileDescriptor createFd() {
+ FileDescriptor s = null;
+ try {
+ s = Os.socket(AF_INET6, SOCK_DGRAM | SOCK_NONBLOCK, IPPROTO_UDP);
+ Os.bind(s, LOOPBACK6, 0);
+ mLocalSockName = (InetSocketAddress) Os.getsockname(s);
+ Os.setsockoptTimeval(s, SOL_SOCKET, SO_SNDTIMEO, TIMEO);
+ } catch (ErrnoException | SocketException e) {
+ closeFd(s);
+ fail();
+ return null;
+ }
+
+ mLocalSocket = s;
+ return s;
+ }
+
+ @Override
+ protected void handlePacket(byte[] recvbuf, int length) {
+ mLastRecvBuf = Arrays.copyOf(recvbuf, length);
+ mLatch.countDown();
+ }
+
+ @Override
+ protected void onStart() {
+ mStopped = false;
+ mLatch.countDown();
+ }
+
+ @Override
+ protected void onStop() {
+ mStopped = true;
+ mLatch.countDown();
+ }
+ };
+
+ @Before
+ public void setUp() {
+ resetLatch();
+ mLocalSocket = null;
+ mLocalSockName = null;
+ mLastRecvBuf = null;
+ mStopped = false;
+
+ mHandlerThread = new HandlerThread(PacketReaderTest.class.getSimpleName());
+ mHandlerThread.start();
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ if (mReceiver != null) {
+ mHandlerThread.getThreadHandler().post(() -> mReceiver.stop());
+ waitForActivity();
+ }
+ mReceiver = null;
+ mHandlerThread.quit();
+ mHandlerThread = null;
+ }
+
+ void resetLatch() {
+ mLatch = new CountDownLatch(1);
+ }
+
+ void waitForActivity() throws Exception {
+ try {
+ mLatch.await(1000, TimeUnit.MILLISECONDS);
+ } finally {
+ resetLatch();
+ }
+ }
+
+ void sendPacket(byte[] contents) throws Exception {
+ final DatagramSocket sender = new DatagramSocket();
+ sender.connect(mLocalSockName);
+ sender.send(new DatagramPacket(contents, contents.length));
+ sender.close();
+ }
+
+ @Test
+ public void testBasicWorking() throws Exception {
+ final Handler h = mHandlerThread.getThreadHandler();
+ mReceiver = new UdpLoopbackReader(h);
+
+ h.post(() -> mReceiver.start());
+ waitForActivity();
+ assertTrue(mLocalSockName != null);
+ assertEquals(LOOPBACK6, mLocalSockName.getAddress());
+ assertTrue(0 < mLocalSockName.getPort());
+ assertTrue(mLocalSocket != null);
+ assertFalse(mStopped);
+
+ final byte[] one = "one 1".getBytes("UTF-8");
+ sendPacket(one);
+ waitForActivity();
+ assertEquals(1, mReceiver.numPacketsReceived());
+ assertTrue(Arrays.equals(one, mLastRecvBuf));
+ assertFalse(mStopped);
+
+ final byte[] two = "two 2".getBytes("UTF-8");
+ sendPacket(two);
+ waitForActivity();
+ assertEquals(2, mReceiver.numPacketsReceived());
+ assertTrue(Arrays.equals(two, mLastRecvBuf));
+ assertFalse(mStopped);
+
+ h.post(() -> mReceiver.stop());
+ waitForActivity();
+ assertEquals(2, mReceiver.numPacketsReceived());
+ assertTrue(Arrays.equals(two, mLastRecvBuf));
+ assertTrue(mStopped);
+ mReceiver = null;
+ }
+
+ class NullPacketReader extends PacketReader {
+ NullPacketReader(Handler h, int recvbufsize) {
+ super(h, recvbufsize);
+ }
+
+ @Override
+ public FileDescriptor createFd() {
+ return null;
+ }
+ }
+
+ @Test
+ public void testMinimalRecvBufSize() throws Exception {
+ final Handler h = mHandlerThread.getThreadHandler();
+
+ for (int i : new int[] { -1, 0, 1, DEFAULT_RECV_BUF_SIZE - 1 }) {
+ final PacketReader b = new NullPacketReader(h, i);
+ assertEquals(DEFAULT_RECV_BUF_SIZE, b.recvBufSize());
+ }
+ }
+
+ @Test
+ public void testStartingFromWrongThread() throws Exception {
+ final Handler h = mHandlerThread.getThreadHandler();
+ final PacketReader b = new NullPacketReader(h, DEFAULT_RECV_BUF_SIZE);
+ assertThrows(IllegalStateException.class, () -> b.start());
+ }
+
+ @Test
+ public void testStoppingFromWrongThread() throws Exception {
+ final Handler h = mHandlerThread.getThreadHandler();
+ final PacketReader b = new NullPacketReader(h, DEFAULT_RECV_BUF_SIZE);
+ assertThrows(IllegalStateException.class, () -> b.stop());
+ }
+
+ @Test
+ public void testSuccessToCreateSocket() throws Exception {
+ final Handler h = mHandlerThread.getThreadHandler();
+ final PacketReader b = new UdpLoopbackReader(h);
+ h.post(() -> assertTrue(b.start()));
+ }
+
+ @Test
+ public void testFailToCreateSocket() throws Exception {
+ final Handler h = mHandlerThread.getThreadHandler();
+ final PacketReader b = new NullPacketReader(h, DEFAULT_RECV_BUF_SIZE);
+ h.post(() -> assertFalse(b.start()));
+ }
+}