Add diagnostics to testNativeDatagramTransmission

When the test fails, retry it 10 times, and an extra time on the
original source port. Report results in the connectivity diagnostics
file header so they can be parsed across tests.

This should help investigate issues possibly linked to UDP packets not
being routed properly depending on the port.

Test: atest
Bug: 375477810
Bug: 253698734
Bug: 324389407
Bug: 338090529
Change-Id: I8400b40257aa25a373c5a3ec96868d6877e0afb1
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/ConnectivityDiagnosticsCollector.kt b/staticlibs/testutils/devicetests/com/android/testutils/ConnectivityDiagnosticsCollector.kt
index ea86281..9e63910 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/ConnectivityDiagnosticsCollector.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/ConnectivityDiagnosticsCollector.kt
@@ -76,11 +76,13 @@
         private const val MAX_DUMPS = 20
 
         private val TAG = ConnectivityDiagnosticsCollector::class.simpleName
+        @JvmStatic
         var instance: ConnectivityDiagnosticsCollector? = null
     }
 
     private var failureHeader: String? = null
     private val buffer = ByteArrayOutputStream()
+    private val failureHeaderExtras = mutableMapOf<String, Any>()
     private val collectorDir: File by lazy {
         createAndEmptyDirectory(COLLECTOR_DIR)
     }
@@ -218,6 +220,8 @@
         val canUseShell = !isAtLeastS() ||
                 instr.uiAutomation.getAdoptedShellPermissions().isNullOrEmpty()
         val headerObj = JSONObject()
+        failureHeaderExtras.forEach { (k, v) -> headerObj.put(k, v) }
+        failureHeaderExtras.clear()
         if (canUseShell) {
             runAsShell(READ_PRIVILEGED_PHONE_STATE, NETWORK_SETTINGS) {
                 headerObj.apply {
@@ -332,6 +336,15 @@
         }
     }
 
+    /**
+     * Add a key->value attribute to the failure data, to be written to the diagnostics file.
+     *
+     * <p>This is to be called by tests that know they will fail.
+     */
+    fun addFailureAttribute(key: String, value: Any) {
+        failureHeaderExtras[key] = value
+    }
+
     private fun maybeWriteExceptionContext(writer: PrintWriter, exceptionContext: Throwable?) {
         if (exceptionContext == null) return
         writer.println("At: ")
diff --git a/tests/cts/net/jni/NativeMultinetworkJni.cpp b/tests/cts/net/jni/NativeMultinetworkJni.cpp
index f2214a3..1d848ec 100644
--- a/tests/cts/net/jni/NativeMultinetworkJni.cpp
+++ b/tests/cts/net/jni/NativeMultinetworkJni.cpp
@@ -415,9 +415,17 @@
     strlcpy(dst, buf, size);
 }
 
+static jobject create_query_test_result(JNIEnv* env, uint16_t src_port, int attempts, int errnum) {
+    jclass clazz = env->FindClass(
+        "android/net/cts/MultinetworkApiTest$QueryTestResult");
+    jmethodID ctor = env->GetMethodID(clazz, "<init>", "(III)V");
+
+    return env->NewObject(clazz, ctor, src_port, attempts, errnum);
+}
+
 extern "C"
-JNIEXPORT jint Java_android_net_cts_MultinetworkApiTest_runDatagramCheck(
-        JNIEnv*, jclass, jlong nethandle) {
+JNIEXPORT jobject Java_android_net_cts_MultinetworkApiTest_runDatagramCheck(
+        JNIEnv* env, jclass, jlong nethandle, jint src_port) {
     const struct addrinfo kHints = {
         .ai_flags = AI_ADDRCONFIG,
         .ai_family = AF_UNSPEC,
@@ -433,7 +441,7 @@
         LOGD("android_getaddrinfofornetwork(%llu, %s) returned rval=%d errno=%d",
               handle, kHostname, rval, errno);
         freeaddrinfo(res);
-        return -errno;
+        return create_query_test_result(env, 0, 0, errno);
     }
 
     // Rely upon getaddrinfo sorting the best destination to the front.
@@ -442,7 +450,7 @@
         LOGD("socket(%d, %d, %d) failed, errno=%d",
               res->ai_family, res->ai_socktype, res->ai_protocol, errno);
         freeaddrinfo(res);
-        return -errno;
+        return create_query_test_result(env, 0, 0, errno);
     }
 
     rval = android_setsocknetwork(handle, fd);
@@ -451,7 +459,31 @@
     if (rval != 0) {
         close(fd);
         freeaddrinfo(res);
-        return -errno;
+        return create_query_test_result(env, 0, 0, errno);
+    }
+
+    sockaddr_storage src_addr;
+    socklen_t src_addrlen = sizeof(src_addr);
+    if (src_port) {
+        if (res->ai_family == AF_INET6) {
+            *reinterpret_cast<sockaddr_in6*>(&src_addr) = (sockaddr_in6) {
+                .sin6_family = AF_INET6,
+                .sin6_port = htons(src_port),
+                .sin6_addr = in6addr_any,
+            };
+        } else {
+            *reinterpret_cast<sockaddr_in*>(&src_addr) = (sockaddr_in) {
+                .sin_family = AF_INET,
+                .sin_port = htons(src_port),
+                .sin_addr = { .s_addr = INADDR_ANY },
+            };
+        }
+        if (bind(fd, (sockaddr *)&src_addr, src_addrlen) != 0) {
+            LOGD("Error binding to port %d", src_port);
+            close(fd);
+            freeaddrinfo(res);
+            return create_query_test_result(env, 0, 0, errno);
+        }
     }
 
     char addrstr[kSockaddrStrLen+1];
@@ -462,19 +494,28 @@
     if (rval != 0) {
         close(fd);
         freeaddrinfo(res);
-        return -errno;
+        return create_query_test_result(env, 0, 0, errno);
     }
     freeaddrinfo(res);
 
-    struct sockaddr_storage src_addr;
-    socklen_t src_addrlen = sizeof(src_addr);
     if (getsockname(fd, (struct sockaddr *)&src_addr, &src_addrlen) != 0) {
         close(fd);
-        return -errno;
+        return create_query_test_result(env, 0, 0, errno);
     }
     sockaddr_ntop((const struct sockaddr *)&src_addr, sizeof(src_addr), addrstr, sizeof(addrstr));
     LOGD("... from %s", addrstr);
 
+    uint16_t socket_src_port;
+    if (res->ai_family == AF_INET6) {
+        socket_src_port = ntohs(reinterpret_cast<sockaddr_in6*>(&src_addr)->sin6_port);
+    } else if (src_addr.ss_family == AF_INET) {
+        socket_src_port = ntohs(reinterpret_cast<sockaddr_in*>(&src_addr)->sin_port);
+    } else {
+        LOGD("Invalid source address family %d", src_addr.ss_family);
+        close(fd);
+        return create_query_test_result(env, 0, 0, EAFNOSUPPORT);
+    }
+
     // Don't let reads or writes block indefinitely.
     const struct timeval timeo = { 2, 0 };  // 2 seconds
     setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeo, sizeof(timeo));
@@ -503,7 +544,7 @@
             errnum = errno;
             LOGD("send(QUIC packet) returned sent=%zd, errno=%d", sent, errnum);
             close(fd);
-            return -errnum;
+            return create_query_test_result(env, socket_src_port, i + 1, errnum);
         }
 
         rcvd = recv(fd, response, sizeof(response), 0);
@@ -521,18 +562,19 @@
             LOGD("Does this network block UDP port %s?", kPort);
         }
         close(fd);
-        return -EPROTO;
+        return create_query_test_result(env, socket_src_port, i + 1,
+                rcvd <= 0 ? errnum : EPROTO);
     }
 
     int conn_id_cmp = memcmp(quic_packet + 6, response + 7, 8);
     if (conn_id_cmp != 0) {
         LOGD("sent and received connection IDs do not match");
         close(fd);
-        return -EPROTO;
+        return create_query_test_result(env, socket_src_port, i + 1, EPROTO);
     }
 
     // TODO: Replace this quick 'n' dirty test with proper QUIC-capable code.
 
     close(fd);
-    return 0;
+    return create_query_test_result(env, socket_src_port, i + 1, 0);
 }
diff --git a/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java b/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
index 2c7d5c6..c67443e 100644
--- a/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
+++ b/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
@@ -39,10 +39,13 @@
 import android.system.ErrnoException;
 import android.system.OsConstants;
 import android.util.ArraySet;
+import android.util.Log;
 
 import androidx.test.platform.app.InstrumentationRegistry;
 
+import com.android.net.module.util.CollectionUtils;
 import com.android.testutils.AutoReleaseNetworkCallbackRule;
+import com.android.testutils.ConnectivityDiagnosticsCollector;
 import com.android.testutils.DevSdkIgnoreRunner;
 import com.android.testutils.DeviceConfigRule;
 
@@ -51,6 +54,8 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Set;
 
 @DevSdkIgnoreRunner.RestoreDefaultNetwork
@@ -70,13 +75,34 @@
     private static final String TAG = "MultinetworkNativeApiTest";
     static final String GOOGLE_PRIVATE_DNS_SERVER = "dns.google";
 
+    public static class QueryTestResult {
+        public final int sourcePort;
+        public final int attempts;
+        public final int errNo;
+
+        public QueryTestResult(int sourcePort, int attempts, int errNo) {
+            this.sourcePort = sourcePort;
+            this.attempts = attempts;
+            this.errNo = errNo;
+        }
+
+        @Override
+        public String toString() {
+            return "QueryTestResult{"
+                    + "sourcePort=" + sourcePort
+                    + ", attempts=" + attempts
+                    + ", errNo=" + errNo
+                    + '}';
+        }
+    }
+
     /**
      * @return 0 on success
      */
     private static native int runGetaddrinfoCheck(long networkHandle);
     private static native int runSetprocnetwork(long networkHandle);
     private static native int runSetsocknetwork(long networkHandle);
-    private static native int runDatagramCheck(long networkHandle);
+    private static native QueryTestResult runDatagramCheck(long networkHandle, int sourcePort);
     private static native void runResNapiMalformedCheck(long networkHandle);
     private static native void runResNcancelCheck(long networkHandle);
     private static native void runResNqueryCheck(long networkHandle);
@@ -165,14 +191,69 @@
         }
     }
 
+    private void runNativeDatagramTransmissionDiagnostics(Network network,
+            QueryTestResult failedResult) {
+        final ConnectivityDiagnosticsCollector collector = ConnectivityDiagnosticsCollector
+                .getInstance();
+        if (collector == null) {
+            Log.e(TAG, "Missing ConnectivityDiagnosticsCollector, not adding diagnostics");
+            return;
+        }
+
+        final int numReruns = 10;
+        final ArrayList<QueryTestResult> reruns = new ArrayList<>(numReruns);
+        for (int i = 0; i < numReruns; i++) {
+            final QueryTestResult rerunResult =
+                    runDatagramCheck(network.getNetworkHandle(), 0 /* sourcePort */);
+            Log.d(TAG, "Rerun result " + i + ": " + rerunResult);
+            reruns.add(rerunResult);
+        }
+        // Rerun on the original port after trying the other ports, to check that the results are
+        // consistent, as opposed to the network recovering halfway through.
+        int originalPortFailedReruns = 0;
+        for (int i = 0; i < numReruns; i++) {
+            final QueryTestResult originalPortRerun = runDatagramCheck(network.getNetworkHandle(),
+                    failedResult.sourcePort);
+            Log.d(TAG, "Rerun result " + i + " with original port: " + originalPortRerun);
+            if (originalPortRerun.errNo != 0) {
+                originalPortFailedReruns++;
+            }
+        }
+
+        final int noRetrySuccessResults = reruns.stream()
+                .filter(result -> result.errNo == 0 && result.attempts == 1)
+                .mapToInt(result -> 1)
+                .sum();
+        final int failedResults = reruns.stream()
+                .filter(result -> result.errNo != 0)
+                .mapToInt(result -> 1)
+                .sum();
+        collector.addFailureAttribute("numReruns", numReruns);
+        collector.addFailureAttribute("noRetrySuccessReruns", noRetrySuccessResults);
+        collector.addFailureAttribute("failedReruns", failedResults);
+        collector.addFailureAttribute("originalPortFailedReruns", originalPortFailedReruns);
+    }
+
     @Test
     public void testNativeDatagramTransmission() throws Exception {
         for (Network network : getTestableNetworks()) {
-            int errno = runDatagramCheck(network.getNetworkHandle());
-            if (errno != 0) {
-                throw new ErrnoException(
-                        "DatagramCheck on " + mCM.getNetworkInfo(network), -errno);
+            final QueryTestResult result = runDatagramCheck(network.getNetworkHandle(),
+                    0 /* sourcePort */);
+            if (result.errNo == 0) {
+                continue;
             }
+            final NetworkCapabilities nc = mCM.getNetworkCapabilities(network);
+            final int[] transports = nc != null ? nc.getTransportTypes() : null;
+            if (CollectionUtils.contains(transports, TRANSPORT_WIFI)) {
+                runNativeDatagramTransmissionDiagnostics(network, result);
+            }
+
+            // Log the whole result (with source port and attempts) to logcat, but use only the
+            // errno and transport in the fail message so similar failures have consistent messages
+            final String error = "DatagramCheck on transport " + Arrays.toString(transports)
+                    + " failed: " + result.errNo;
+            Log.e(TAG, error + ", result: " + result);
+            fail(error);
         }
     }