Merge "Uses identical hostName across all interface"
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
index 977478a..ec3e997 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
@@ -31,6 +31,7 @@
 
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.BiPredicate;
 import java.util.function.Consumer;
 
@@ -43,6 +44,9 @@
     private static final String TAG = MdnsAdvertiser.class.getSimpleName();
     static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
 
+    // Top-level domain for link-local queries, as per RFC6762 3.
+    private static final String LOCAL_TLD = "local";
+
     private final Looper mLooper;
     private final AdvertiserCallback mCb;
 
@@ -60,6 +64,8 @@
     private final SparseArray<Registration> mRegistrations = new SparseArray<>();
     private final Dependencies mDeps;
 
+    private String[] mDeviceHostName;
+
     /**
      * Dependencies for {@link MdnsAdvertiser}, useful for testing.
      */
@@ -71,11 +77,32 @@
         public MdnsInterfaceAdvertiser makeAdvertiser(@NonNull MdnsInterfaceSocket socket,
                 @NonNull List<LinkAddress> initialAddresses,
                 @NonNull Looper looper, @NonNull byte[] packetCreationBuffer,
-                @NonNull MdnsInterfaceAdvertiser.Callback cb) {
+                @NonNull MdnsInterfaceAdvertiser.Callback cb,
+                @NonNull String[] deviceHostName) {
             // Note NetworkInterface is final and not mockable
             final String logTag = socket.getInterface().getName();
             return new MdnsInterfaceAdvertiser(logTag, socket, initialAddresses, looper,
-                    packetCreationBuffer, cb);
+                    packetCreationBuffer, cb, deviceHostName);
+        }
+
+        /**
+         * Generates a unique hostname to be used by the device.
+         */
+        @NonNull
+        public String[] generateHostname() {
+            // Generate a very-probably-unique hostname. This allows minimizing possible conflicts
+            // to the point that probing for it is no longer necessary (as per RFC6762 8.1 last
+            // paragraph), and does not leak more information than what could already be obtained by
+            // looking at the mDNS packets source address.
+            // This differs from historical behavior that just used "Android.local" for many
+            // devices, creating a lot of conflicts.
+            // The same hostname is used on all interfaces; RFC6762 14 permits either approach.
+            // A new hostname is generated when all registrations are removed, limiting how long
+            // it can be used to track the device.
+            // TODO: consider deriving a hostname from other sources, such as the IPv6 addresses
+            // (reusing the same privacy-protecting mechanics).
+            return new String[] {
+                    "Android_" + UUID.randomUUID().toString().replace("-", ""), LOCAL_TLD };
         }
     }
 
@@ -260,7 +287,7 @@
             MdnsInterfaceAdvertiser advertiser = mAllAdvertisers.get(socket);
             if (advertiser == null) {
                 advertiser = mDeps.makeAdvertiser(socket, addresses, mLooper, mPacketCreationBuffer,
-                        mInterfaceAdvertiserCb);
+                        mInterfaceAdvertiserCb, mDeviceHostName);
                 mAllAdvertisers.put(socket, advertiser);
                 advertiser.start();
             }
@@ -389,6 +416,7 @@
         mCb = cb;
         mSocketProvider = socketProvider;
         mDeps = deps;
+        mDeviceHostName = deps.generateHostname();
     }
 
     private void checkThread() {
@@ -453,6 +481,10 @@
             advertiser.removeService(id);
         }
         mRegistrations.remove(id);
+        // Regenerate the hostname once all registrations have been removed.
+        if (mRegistrations.size() == 0) {
+            mDeviceHostName = mDeps.generateHostname();
+        }
     }
 
     private static <K, V> boolean any(@NonNull ArrayMap<K, V> map,
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index c616e01..79cddce 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -141,8 +141,9 @@
     public static class Dependencies {
         /** @see MdnsRecordRepository */
         @NonNull
-        public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper) {
-            return new MdnsRecordRepository(looper);
+        public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper,
+                @NonNull String[] deviceHostName) {
+            return new MdnsRecordRepository(looper, deviceHostName);
         }
 
         /** @see MdnsReplySender */
@@ -169,17 +170,18 @@
 
     public MdnsInterfaceAdvertiser(@NonNull String logTag,
             @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses,
-            @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb) {
+            @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb,
+            @NonNull String[] deviceHostName) {
         this(logTag, socket, initialAddresses, looper, packetCreationBuffer, cb,
-                new Dependencies());
+                new Dependencies(), deviceHostName);
     }
 
     public MdnsInterfaceAdvertiser(@NonNull String logTag,
             @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses,
             @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb,
-            @NonNull Dependencies deps) {
+            @NonNull Dependencies deps, @NonNull String[] deviceHostName) {
         mTag = MdnsInterfaceAdvertiser.class.getSimpleName() + "/" + logTag;
-        mRecordRepository = deps.makeRecordRepository(looper);
+        mRecordRepository = deps.makeRecordRepository(looper, deviceHostName);
         mRecordRepository.updateAddresses(initialAddresses);
         mSocket = socket;
         mCb = cb;
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index e975ab4..1329172 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -47,7 +47,6 @@
 import java.util.Random;
 import java.util.Set;
 import java.util.TreeMap;
-import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 
 /**
@@ -90,15 +89,16 @@
     @NonNull
     private final Looper mLooper;
     @NonNull
-    private String[] mDeviceHostname;
+    private final String[] mDeviceHostname;
 
-    public MdnsRecordRepository(@NonNull Looper looper) {
-        this(looper, new Dependencies());
+    public MdnsRecordRepository(@NonNull Looper looper, @NonNull String[] deviceHostname) {
+        this(looper, new Dependencies(), deviceHostname);
     }
 
     @VisibleForTesting
-    public MdnsRecordRepository(@NonNull Looper looper, @NonNull Dependencies deps) {
-        mDeviceHostname = deps.getHostname();
+    public MdnsRecordRepository(@NonNull Looper looper, @NonNull Dependencies deps,
+            @NonNull String[] deviceHostname) {
+        mDeviceHostname = deviceHostname;
         mLooper = looper;
     }
 
@@ -107,25 +107,6 @@
      */
     @VisibleForTesting
     public static class Dependencies {
-        /**
-         * Get a unique hostname to be used by the device.
-         */
-        @NonNull
-        public String[] getHostname() {
-            // Generate a very-probably-unique hostname. This allows minimizing possible conflicts
-            // to the point that probing for it is no longer necessary (as per RFC6762 8.1 last
-            // paragraph), and does not leak more information than what could already be obtained by
-            // looking at the mDNS packets source address.
-            // This differs from historical behavior that just used "Android.local" for many
-            // devices, creating a lot of conflicts.
-            // Having a different hostname per interface is an acceptable option as per RFC6762 14.
-            // This hostname will change every time the interface is reconnected, so this does not
-            // allow tracking the device.
-            // TODO: consider deriving a hostname from other sources, such as the IPv6 addresses
-            // (reusing the same privacy-protecting mechanics).
-            return new String[] {
-                    "Android_" + UUID.randomUUID().toString().replace("-", ""), LOCAL_TLD };
-        }
 
         /**
          * @see NetworkInterface#getInetAddresses().
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index 1febe6d..375c150 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -42,6 +42,7 @@
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.never
+import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
 
 private const val SERVICE_ID_1 = 1
@@ -51,6 +52,7 @@
 private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
 private val TEST_NETWORK_1 = mock(Network::class.java)
 private val TEST_NETWORK_2 = mock(Network::class.java)
+private val TEST_HOSTNAME = arrayOf("Android_test", "local")
 
 private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
     port = 12345
@@ -81,10 +83,13 @@
     @Before
     fun setUp() {
         thread.start()
+        doReturn(TEST_HOSTNAME).`when`(mockDeps).generateHostname()
         doReturn(mockInterfaceAdvertiser1).`when`(mockDeps).makeAdvertiser(eq(mockSocket1),
-                any(), any(), any(), any())
+                any(), any(), any(), any(), eq(TEST_HOSTNAME)
+        )
         doReturn(mockInterfaceAdvertiser2).`when`(mockDeps).makeAdvertiser(eq(mockSocket2),
-                any(), any(), any(), any())
+                any(), any(), any(), any(), eq(TEST_HOSTNAME)
+        )
         doReturn(true).`when`(mockInterfaceAdvertiser1).isProbing(anyInt())
         doReturn(true).`when`(mockInterfaceAdvertiser2).isProbing(anyInt())
     }
@@ -106,8 +111,14 @@
         postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
 
         val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
-        verify(mockDeps).makeAdvertiser(eq(mockSocket1),
-                eq(listOf(TEST_LINKADDR)), eq(thread.looper), any(), intAdvCbCaptor.capture())
+        verify(mockDeps).makeAdvertiser(
+            eq(mockSocket1),
+            eq(listOf(TEST_LINKADDR)),
+            eq(thread.looper),
+            any(),
+            intAdvCbCaptor.capture(),
+            eq(TEST_HOSTNAME)
+        )
 
         doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
         postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
@@ -134,9 +145,11 @@
         val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
         val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
         verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
-                eq(thread.looper), any(), intAdvCbCaptor1.capture())
+                eq(thread.looper), any(), intAdvCbCaptor1.capture(), eq(TEST_HOSTNAME)
+        )
         verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)),
-                eq(thread.looper), any(), intAdvCbCaptor2.capture())
+                eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME)
+        )
 
         doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
         postSync { intAdvCbCaptor1.value.onRegisterServiceSucceeded(
@@ -192,7 +205,8 @@
 
         val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
         verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
-                eq(thread.looper), any(), intAdvCbCaptor.capture())
+                eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME)
+        )
         verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
                 argThat { it.matches(SERVICE_1) })
         verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
@@ -216,6 +230,15 @@
         verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
     }
 
+    @Test
+    fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() {
+        val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps)
+        verify(mockDeps, times(1)).generateHostname()
+        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
+        postSync { advertiser.removeService(SERVICE_ID_1) }
+        verify(mockDeps, times(2)).generateHostname()
+    }
+
     private fun postSync(r: () -> Unit) {
         handler.post(r)
         handler.waitForIdle(TIMEOUT_MS)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 4a806b1..2d8d8f3 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -55,6 +55,7 @@
 
 private val TEST_ADDRS = listOf(LinkAddress(parseNumericAddress("2001:db8::123"), 64))
 private val TEST_BUFFER = ByteArray(1300)
+private val TEST_HOSTNAME = arrayOf("Android_test", "local")
 
 private const val TEST_SERVICE_ID_1 = 42
 private val TEST_SERVICE_1 = NsdServiceInfo().apply {
@@ -88,12 +89,23 @@
     private val packetHandler get() = packetHandlerCaptor.value
 
     private val advertiser by lazy {
-        MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
+        MdnsInterfaceAdvertiser(
+            LOG_TAG,
+            socket,
+            TEST_ADDRS,
+            thread.looper,
+            TEST_BUFFER,
+            cb,
+            deps,
+            TEST_HOSTNAME
+        )
     }
 
     @Before
     fun setUp() {
-        doReturn(repository).`when`(deps).makeRecordRepository(any())
+        doReturn(repository).`when`(deps).makeRecordRepository(any(),
+            eq(TEST_HOSTNAME)
+        )
         doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any())
         doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any())
         doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any())
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index ecc11ec..5665091 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -67,7 +67,6 @@
 class MdnsRecordRepositoryTest {
     private val thread = HandlerThread(MdnsRecordRepositoryTest::class.simpleName)
     private val deps = object : Dependencies() {
-        override fun getHostname() = TEST_HOSTNAME
         override fun getInterfaceInetAddresses(iface: NetworkInterface) =
                 Collections.enumeration(TEST_ADDRESSES.map { it.address })
     }
@@ -84,7 +83,7 @@
 
     @Test
     fun testAddServiceAndProbe() {
-        val repository = MdnsRecordRepository(thread.looper, deps)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         assertEquals(0, repository.servicesCount)
         assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1))
         assertEquals(1, repository.servicesCount)
@@ -117,7 +116,7 @@
 
     @Test
     fun testAddAndConflicts() {
-        val repository = MdnsRecordRepository(thread.looper, deps)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         assertFailsWith(NameConflictException::class) {
             repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1)
@@ -126,7 +125,7 @@
 
     @Test
     fun testInvalidReuseOfServiceId() {
-        val repository = MdnsRecordRepository(thread.looper, deps)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         assertFailsWith(IllegalArgumentException::class) {
             repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2)
@@ -135,7 +134,7 @@
 
     @Test
     fun testHasActiveService() {
-        val repository = MdnsRecordRepository(thread.looper, deps)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
 
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
@@ -152,7 +151,7 @@
 
     @Test
     fun testExitAnnouncements() {
-        val repository = MdnsRecordRepository(thread.looper, deps)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1)
 
@@ -181,7 +180,7 @@
 
     @Test
     fun testExitingServiceReAdded() {
-        val repository = MdnsRecordRepository(thread.looper, deps)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1)
         repository.exitService(TEST_SERVICE_ID_1)
@@ -195,7 +194,7 @@
 
     @Test
     fun testOnProbingSucceeded() {
-        val repository = MdnsRecordRepository(thread.looper, deps)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1)
         val packet = announcementInfo.getPacket(0)
@@ -319,7 +318,7 @@
 
     @Test
     fun testGetReply() {
-        val repository = MdnsRecordRepository(thread.looper, deps)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"),
                 0L /* receiptTimeMillis */,
@@ -404,7 +403,7 @@
 
     @Test
     fun testGetConflictingServices() {
-        val repository = MdnsRecordRepository(thread.looper, deps)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
 
@@ -432,7 +431,7 @@
 
     @Test
     fun testGetConflictingServices_IdenticalService() {
-        val repository = MdnsRecordRepository(thread.looper, deps)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)