Add MdnsProber

MdnsProber is an implementation of MdnsPacketRepeater that will be used
to send probes for service names before advertising them, to know if
they are already in use.

Bug: 241738458
Test: atest
Change-Id: I4e5f779b891e2c665ba7f752fb5fbd4255070725
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
new file mode 100644
index 0000000..db7049e
--- /dev/null
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
@@ -0,0 +1,156 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import android.annotation.NonNull;
+import android.os.Looper;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.CollectionUtils;
+
+import java.net.SocketAddress;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Supplier;
+
+/**
+ * Sends mDns probe requests to verify service records are unique on the network.
+ *
+ * TODO: implement receiving replies and handling conflicts.
+ */
+public class MdnsProber extends MdnsPacketRepeater<MdnsProber.ProbingInfo> {
+    @NonNull
+    private final String mLogTag;
+
+    public MdnsProber(@NonNull String interfaceTag, @NonNull Looper looper,
+            @NonNull MdnsReplySender replySender,
+            @NonNull PacketRepeaterCallback<ProbingInfo> cb) {
+        // 3 packets as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
+        super(looper, replySender, cb);
+        mLogTag = MdnsProber.class.getSimpleName() + "/" + interfaceTag;
+    }
+
+    static class ProbingInfo implements Request {
+
+        private final int mServiceId;
+        @NonNull
+        private final MdnsPacket mPacket;
+        @NonNull
+        private final Supplier<Iterable<SocketAddress>> mDestinationsSupplier;
+
+        /**
+         * Create a new ProbingInfo
+         * @param serviceId Service to probe for.
+         * @param probeRecords Records to be probed for uniqueness.
+         * @param destinationsSupplier Supplier for the probe destinations. Will be called on the
+         *                             probe handler thread for each probe.
+         */
+        ProbingInfo(int serviceId, @NonNull List<MdnsRecord> probeRecords,
+                @NonNull Supplier<Iterable<SocketAddress>> destinationsSupplier) {
+            mServiceId = serviceId;
+            mPacket = makePacket(probeRecords);
+            mDestinationsSupplier = destinationsSupplier;
+        }
+
+        public int getServiceId() {
+            return mServiceId;
+        }
+
+        @NonNull
+        @Override
+        public MdnsPacket getPacket(int index) {
+            return mPacket;
+        }
+
+        @NonNull
+        @Override
+        public Iterable<SocketAddress> getDestinations(int index) {
+            return mDestinationsSupplier.get();
+        }
+
+        @Override
+        public long getDelayMs(int nextIndex) {
+            // As per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
+            return 250L;
+        }
+
+        @Override
+        public int getNumSends() {
+            // 3 packets as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
+            return 3;
+        }
+
+        private static MdnsPacket makePacket(@NonNull List<MdnsRecord> records) {
+            final ArrayList<MdnsRecord> questions = new ArrayList<>(records.size());
+            for (final MdnsRecord record : records) {
+                if (containsName(questions, record.getName())) {
+                    // Already added this name
+                    continue;
+                }
+
+                // TODO: legacy Android mDNS used to send the first probe (only) as unicast, even
+                //  though https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 says they
+                // SHOULD all be. rfc6762 15.1 says that if the port is shared with another
+                // responder unicast questions should not be used, and the legacy mdnsresponder may
+                // be running, so not using unicast at all may be better. Consider using legacy
+                // behavior if this causes problems.
+                questions.add(new MdnsAnyRecord(record.getName(), false /* unicast */));
+            }
+
+            return new MdnsPacket(
+                    MdnsConstants.FLAGS_QUERY,
+                    questions,
+                    Collections.emptyList() /* answers */,
+                    records /* authorityRecords */,
+                    Collections.emptyList() /* additionalRecords */);
+        }
+
+        /**
+         * Return whether the specified name is present in the list of records.
+         */
+        private static boolean containsName(@NonNull List<MdnsRecord> records,
+                @NonNull String[] name) {
+            return CollectionUtils.any(records, r -> Arrays.equals(name, r.getName()));
+        }
+    }
+
+    @NonNull
+    @Override
+    protected String getTag() {
+        return mLogTag;
+    }
+
+    @VisibleForTesting
+    protected long getInitialDelay() {
+        // First wait for a random time in 0-250ms
+        // as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
+        return (long) (Math.random() * 250);
+    }
+
+    /**
+     * Start sending packets for probing.
+     */
+    public void startProbing(@NonNull ProbingInfo info) {
+        startProbing(info, getInitialDelay());
+    }
+
+    private void startProbing(@NonNull ProbingInfo info, long delay) {
+        startSending(info.getServiceId(), info, delay);
+    }
+}
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsRecord.java b/service/mdns/com/android/server/connectivity/mdns/MdnsRecord.java
index 10b8825..00871ea 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsRecord.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsRecord.java
@@ -201,6 +201,17 @@
     protected abstract void readData(MdnsPacketReader reader) throws IOException;
 
     /**
+     * Write the first fields of the record, which are common fields for questions and answers.
+     *
+     * @param writer The writer to use.
+     */
+    public final void writeHeaderFields(MdnsPacketWriter writer) throws IOException {
+        writer.writeLabels(name);
+        writer.writeUInt16(type);
+        writer.writeUInt16(cls);
+    }
+
+    /**
      * Writes the record to a packet.
      *
      * @param writer The writer to use.
@@ -208,9 +219,7 @@
      */
     @VisibleForTesting
     public final void write(MdnsPacketWriter writer, long now) throws IOException {
-        writer.writeLabels(name);
-        writer.writeUInt16(type);
-        writer.writeUInt16(cls);
+        writeHeaderFields(writer);
 
         writer.writeUInt32(MILLISECONDS.toSeconds(getRemainingTTL(now)));
 
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java b/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
index 2acd789..1fdbc5c 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
@@ -67,7 +67,8 @@
         writer.writeUInt16(packet.additionalRecords.size()); // additional records count
 
         for (MdnsRecord record : packet.questions) {
-            record.write(writer, 0L);
+            // Questions do not have TTL or data
+            record.writeHeaderFields(writer);
         }
         for (MdnsRecord record : packet.answers) {
             record.write(writer, 0L);
diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp
index 8ed735a..209430a 100644
--- a/tests/unit/Android.bp
+++ b/tests/unit/Android.bp
@@ -74,6 +74,7 @@
         "java/com/android/server/connectivity/VpnTest.java",
         "java/com/android/server/net/ipmemorystore/*.java",
         "java/com/android/server/connectivity/mdns/**/*.java",
+        "java/com/android/server/connectivity/mdns/**/*.kt",
     ]
 }
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
new file mode 100644
index 0000000..cc75191
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
@@ -0,0 +1,201 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns
+
+import android.os.Build
+import android.os.Handler
+import android.os.HandlerThread
+import android.os.Looper
+import com.android.internal.util.HexDump
+import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
+import com.android.testutils.DevSdkIgnoreRunner
+import java.net.DatagramPacket
+import java.net.InetSocketAddress
+import java.net.MulticastSocket
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.TimeUnit
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.Mockito.any
+import org.mockito.Mockito.atLeast
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.never
+import org.mockito.Mockito.timeout
+import org.mockito.Mockito.times
+import org.mockito.Mockito.verify
+
+private val destinationsSupplier = {
+    listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) }
+
+private const val TEST_TIMEOUT_MS = 10_000L
+private const val SHORT_TIMEOUT_MS = 200L
+
+private val TEST_SERVICE_NAME_1 = arrayOf("testservice", "_nmt", "_tcp", "local")
+private val TEST_SERVICE_NAME_2 = arrayOf("testservice2", "_nmt", "_tcp", "local")
+
+@RunWith(DevSdkIgnoreRunner::class)
+@IgnoreUpTo(Build.VERSION_CODES.S_V2)
+class MdnsProberTest {
+    private val thread = HandlerThread(MdnsProberTest::class.simpleName)
+    private val socket = mock(MulticastSocket::class.java)
+    @Suppress("UNCHECKED_CAST")
+    private val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
+        as MdnsPacketRepeater.PacketRepeaterCallback<ProbingInfo>
+    private val buffer = ByteArray(1500)
+
+    @Before
+    fun setUp() {
+        thread.start()
+    }
+
+    @After
+    fun tearDown() {
+        thread.quitSafely()
+    }
+
+    private class TestProbeInfo(probeRecords: List<MdnsRecord>, private val delayMs: Long = 1L) :
+            ProbingInfo(1 /* serviceId */, probeRecords, destinationsSupplier) {
+        // Just send the packets quickly. Timing-related tests for MdnsPacketRepeater are already
+        // done in MdnsAnnouncerTest.
+        override fun getDelayMs(nextIndex: Int) = delayMs
+    }
+
+    private class TestProber(
+        looper: Looper,
+        replySender: MdnsReplySender,
+        cb: PacketRepeaterCallback<ProbingInfo>
+    ) : MdnsProber("testiface", looper, replySender, cb) {
+        override fun getInitialDelay() = 0L
+    }
+
+    private fun assertProbesSent(probeInfo: TestProbeInfo, expectedHex: String) {
+        repeat(probeInfo.numSends) { i ->
+            verify(cb, timeout(TEST_TIMEOUT_MS)).onSent(i, probeInfo)
+            // If the probe interval is short, more than (i+1) probes may have been sent already
+            verify(socket, atLeast(i + 1)).send(any())
+        }
+
+        val captor = ArgumentCaptor.forClass(DatagramPacket::class.java)
+        // There should be exactly numSends probes sent at the end
+        verify(socket, times(probeInfo.numSends)).send(captor.capture())
+
+        captor.allValues.forEach {
+            assertEquals(expectedHex, HexDump.toHexString(it.data))
+        }
+        verify(cb, timeout(TEST_TIMEOUT_MS)).onFinished(probeInfo)
+    }
+
+    private fun makeServiceRecord(name: Array<String>, port: Int) = MdnsServiceRecord(
+            name,
+            0L /* receiptTimeMillis */,
+            false /* cacheFlush */,
+            120_000L /* ttlMillis */,
+            0 /* servicePriority */,
+            0 /* serviceWeight */,
+            port,
+            arrayOf("myhostname", "local"))
+
+    @Test
+    fun testProbe() {
+        val replySender = MdnsReplySender(thread.looper, socket, buffer)
+        val prober = TestProber(thread.looper, replySender, cb)
+        val probeInfo = TestProbeInfo(
+                listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)))
+        prober.startProbing(probeInfo)
+
+        // Inspect with python3:
+        // import scapy.all as scapy; scapy.DNS(bytes.fromhex('[bytes]')).show2()
+        val expected = "0000000000010000000100000B7465737473657276696365045F6E6D74045F746370056C" +
+                "6F63616C0000FF0001C00C002100010000007800130000000094020A6D79686F73746E616D65C022"
+        assertProbesSent(probeInfo, expected)
+    }
+
+    @Test
+    fun testProbeMultipleRecords() {
+        val replySender = MdnsReplySender(thread.looper, socket, buffer)
+        val prober = TestProber(thread.looper, replySender, cb)
+        val probeInfo = TestProbeInfo(listOf(
+                makeServiceRecord(TEST_SERVICE_NAME_1, 37890),
+                makeServiceRecord(TEST_SERVICE_NAME_2, 37891),
+                MdnsTextRecord(
+                        // Same name as the first record; there should not be 2 duplicated questions
+                        TEST_SERVICE_NAME_1,
+                        0L /* receiptTimeMillis */,
+                        false /* cacheFlush */,
+                        120_000L /* ttlMillis */,
+                        listOf(MdnsServiceInfo.TextEntry("testKey", "testValue")))))
+        prober.startProbing(probeInfo)
+
+        /*
+        Expected data obtained with:
+        scapy.raw(scapy.dns_compress(scapy.DNS(rd=0,
+            qd =
+                scapy.DNSQR(qname='testservice._nmt._tcp.local.', qtype='ALL') /
+                scapy.DNSQR(qname='testservice2._nmt._tcp.local.', qtype='ALL'),
+            ns=
+                scapy.DNSRRSRV(rrname='testservice._nmt._tcp.local.', type='SRV', ttl=120,
+                    port=37890, target='myhostname.local.') /
+                scapy.DNSRRSRV(rrname='testservice2._nmt._tcp.local.', type='SRV', ttl=120,
+                    port=37891, target='myhostname.local.') /
+                scapy.DNSRR(type='TXT', ttl=120, rrname='testservice._nmt._tcp.local.',
+                    rdata='testKey=testValue'))
+        )).hex().upper()
+        // NOTE: due to a bug the second "myhostname" is not getting DNS compressed in the current
+        // actual probe, so data below is slightly different. Fix compression so it gets compressed.
+         */
+        val expected = "0000000000020000000300000B7465737473657276696365045F6E6D74045F746370056C6" +
+                "F63616C0000FF00010C746573747365727669636532C01800FF0001C00C002100010000007800130" +
+                "000000094020A6D79686F73746E616D65C0220C746573747365727669636532C0180021000100000" +
+                "07800130000000094030A6D79686F73746E616D65C022C00C0010000100000078001211746573744" +
+                "B65793D7465737456616C7565"
+        assertProbesSent(probeInfo, expected)
+    }
+
+    @Test
+    fun testStopProbing() {
+        val replySender = MdnsReplySender(thread.looper, socket, buffer)
+        val prober = TestProber(thread.looper, replySender, cb)
+        val probeInfo = TestProbeInfo(
+                listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)),
+                // delayMs is the delay between each probe, so does not apply to the first one
+                delayMs = SHORT_TIMEOUT_MS)
+        prober.startProbing(probeInfo)
+
+        // Expect the initial probe
+        verify(cb, timeout(TEST_TIMEOUT_MS)).onSent(0, probeInfo)
+
+        // Stop probing
+        val stopResult = CompletableFuture<Boolean>()
+        Handler(thread.looper).post { stopResult.complete(prober.stop(probeInfo.serviceId)) }
+        assertTrue(stopResult.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS),
+                "stop should return true when probing was in progress")
+
+        // Wait for a bit (more than the probe delay) to ensure no more probes were sent
+        Thread.sleep(SHORT_TIMEOUT_MS * 2)
+        verify(cb, never()).onSent(1, probeInfo)
+        verify(cb, never()).onFinished(probeInfo)
+
+        // Only one sent packet
+        verify(socket, times(1)).send(any())
+    }
+}