Merge "Dump tcpdump on failure in testUidTagStateDetails" into main
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/ConnectivityDiagnosticsCollector.kt b/staticlibs/testutils/devicetests/com/android/testutils/ConnectivityDiagnosticsCollector.kt
index e5b8471..0624e5f 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/ConnectivityDiagnosticsCollector.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/ConnectivityDiagnosticsCollector.kt
@@ -43,9 +43,14 @@
 import androidx.test.platform.app.InstrumentationRegistry
 import com.android.modules.utils.build.SdkLevel.isAtLeastS
 import java.io.ByteArrayOutputStream
+import java.io.CharArrayWriter
 import java.io.File
 import java.io.FileOutputStream
+import java.io.FileReader
+import java.io.OutputStream
+import java.io.OutputStreamWriter
 import java.io.PrintWriter
+import java.io.Reader
 import java.time.ZonedDateTime
 import java.time.format.DateTimeFormatter
 import java.util.concurrent.CompletableFuture
@@ -80,7 +85,38 @@
         var instance: ConnectivityDiagnosticsCollector? = null
     }
 
+    /**
+     * Marks a test during which tcpdump runs; its output goes to the diagnostics file on failure.
+     */
+    annotation class CollectTcpdumpOnFailure
+
+    // Drains reader into an in-memory buffer on its own thread until the reader hits EOF
+    private class DumpThread(
+        // Keep a reference to the ParcelFileDescriptor otherwise GC would close it
+        private val fd: ParcelFileDescriptor,
+        private val reader: Reader
+    ) : Thread() {
+        private val writer = CharArrayWriter()
+        override fun run() { reader.copyTo(writer) }
+
+        /** Waits for EOF, closes [fd], then writes the capture to [output] unless it is null. */
+        fun closeAndWriteTo(output: OutputStream?) {
+            join()
+            fd.close()
+            val outputWriter = OutputStreamWriter(output ?: return)
+            outputWriter.write("--- tcpdump stopped at ${ZonedDateTime.now()} ---\n")
+            writer.writeTo(outputWriter)
+            // OutputStreamWriter buffers encoded bytes: without a flush they never reach output.
+            // Flush rather than close, as output is owned (and closed) by the caller.
+            outputWriter.flush()
+        }
+    }
+
+    private data class TcpdumpRun(val pid: Int, val reader: DumpThread) // reader drains its stdout
+
     private var failureHeader: String? = null
+
+    // Accessed from the test listener methods which are synchronized by junit (see TestListener)
+    private var tcpdumpRun: TcpdumpRun? = null
     private val buffer = ByteArrayOutputStream()
     private val failureHeaderExtras = mutableMapOf<String, Any>()
     private val collectorDir: File by lazy {
@@ -157,7 +193,57 @@
         flushBufferToFileMetric(testData, baseFilename)
     }
 
+    override fun onTestStart(testData: DataRecord, description: Description) {
+        // tcpdump is opt-in: it is only started for test cases that carry the
+        // CollectTcpdumpOnFailure annotation.
+        val wantsTcpdump = description.annotations.any { it is CollectTcpdumpOnFailure }
+        if (!wantsTcpdump) return
+        startTcpdumpForTestcaseIfSupported()
+    }
+
+    /** Starts tcpdump as root on all interfaces; a [DumpThread] buffers its output. */
+    private fun startTcpdumpForTestcaseIfSupported() {
+        if (!DeviceInfoUtils.isDebuggable()) {
+            Log.d(TAG, "Cannot start tcpdump, build is not debuggable")
+            return
+        }
+        if (tcpdumpRun != null) {
+            Log.e(TAG, "Cannot start tcpdump: it is already running")
+            return
+        }
+        // executeShellCommand won't tokenize quoted arguments containing spaces (like pcap filters)
+        // properly, so pass in the command in stdin instead of using sh -c 'command'
+        val (stdout, stdin) = instrumentation.uiAutomation.executeShellCommandRw("sh")
+        ParcelFileDescriptor.AutoCloseOutputStream(stdin).use { writer ->
+            // Echo the current pid, and replace it (with exec) with the tcpdump process, so the
+            // tcpdump pid is known.
+            writer.write("echo $$; exec su 0 tcpdump -n -i any -U -xx".encodeToByteArray())
+        }
+        val reader = FileReader(stdout.fileDescriptor).buffered()
+        // The first line is the echoed pid; it is absent or malformed if the shell did not run
+        val tcpdumpPid = reader.readLine()?.toIntOrNull()
+        if (tcpdumpPid == null) {
+            Log.e(TAG, "Cannot start tcpdump: could not read its pid")
+            stdout.close()
+            return
+        }
+        tcpdumpRun = TcpdumpRun(tcpdumpPid, DumpThread(stdout, reader).also { it.start() })
+    }
+
+    private fun stopTcpdumpIfRunning(output: OutputStream?) {
+        val (tcpdumpPid, dumpThread) = tcpdumpRun ?: return
+        // kill defaults to SIGTERM: tcpdump shuts down gracefully and can flush its output
+        executeCommandBlocking("su 0 kill $tcpdumpPid")
+        dumpThread.closeAndWriteTo(output)
+        tcpdumpRun = null
+    }
+
     override fun onTestEnd(testData: DataRecord, description: Description) {
+        // onTestFail is called before onTestEnd, so if the test failed tcpdump would already have
+        // been stopped and output dumped. Here this stops tcpdump if the test succeeded, throwing
+        // away its output.
+        stopTcpdumpIfRunning(output = null)
+
         // Tests may call methods like collectDumpsysConnectivity to collect diagnostics at any time
         // during the run, for example to observe state at various points to investigate a flake
         // and compare passing/failing cases.
@@ -196,6 +282,7 @@
                 fos.write("\n".toByteArray())
             }
             fos.write(buffer.toByteArray())
+            stopTcpdumpIfRunning(fos)
         }
         failureHeader = null
         buffer.reset()
diff --git a/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java b/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
index e3d7240..005f6ad 100644
--- a/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
@@ -75,6 +75,7 @@
 import com.android.compatibility.common.util.SystemUtil;
 import com.android.modules.utils.build.SdkLevel;
 import com.android.testutils.AutoReleaseNetworkCallbackRule;
+import com.android.testutils.ConnectivityDiagnosticsCollector;
 import com.android.testutils.ConnectivityModuleTest;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
@@ -759,6 +760,7 @@
                 bucket.getRxBytes(), bucket.getTxBytes()));
     }
 
+    @ConnectivityDiagnosticsCollector.CollectTcpdumpOnFailure
     @Test
     public void testUidTagStateDetails() throws Exception {
         for (int i = 0; i < mNetworkInterfacesToTest.length; ++i) {