Merge "Update record receipt time on records updated"
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsResponse.java b/service-t/src/com/android/server/connectivity/mdns/MdnsResponse.java
index 28aa640..ec1e462 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsResponse.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsResponse.java
@@ -84,16 +84,15 @@
     private <T extends MdnsRecord> boolean addOrReplaceRecord(@NonNull T record,
             @NonNull List<T> recordsList) {
         final int existing = recordsList.indexOf(record);
+        boolean isSame = false;
         if (existing >= 0) {
-            if (recordsAreSame(record, recordsList.get(existing))) {
-                return false;
-            }
+            isSame = recordsAreSame(record, recordsList.get(existing));
             final MdnsRecord existedRecord = recordsList.remove(existing);
             records.remove(existedRecord);
         }
         recordsList.add(record);
         records.add(record);
-        return true;
+        return !isSame;
     }
 
     /**
@@ -163,9 +162,7 @@
 
     /** Sets the service record. */
     public synchronized boolean setServiceRecord(MdnsServiceRecord serviceRecord) {
-        if (recordsAreSame(this.serviceRecord, serviceRecord)) {
-            return false;
-        }
+        boolean isSame = recordsAreSame(this.serviceRecord, serviceRecord);
         if (this.serviceRecord != null) {
             records.remove(this.serviceRecord);
         }
@@ -173,7 +170,7 @@
         if (this.serviceRecord != null) {
             records.add(this.serviceRecord);
         }
-        return true;
+        return !isSame;
     }
 
     /** Gets the service record. */
@@ -187,9 +184,7 @@
 
     /** Sets the text record. */
     public synchronized boolean setTextRecord(MdnsTextRecord textRecord) {
-        if (recordsAreSame(this.textRecord, textRecord)) {
-            return false;
-        }
+        boolean isSame = recordsAreSame(this.textRecord, textRecord);
         if (this.textRecord != null) {
             records.remove(this.textRecord);
         }
@@ -197,7 +192,7 @@
         if (this.textRecord != null) {
             records.add(this.textRecord);
         }
-        return true;
+        return !isSame;
     }
 
     /** Gets the text record. */
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsResponseDecoder.java b/service-t/src/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
index 77b5c58..42f6107 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
@@ -21,6 +21,7 @@
 import android.net.Network;
 import android.os.SystemClock;
 import android.util.ArraySet;
+import android.util.Pair;
 
 import com.android.server.connectivity.mdns.util.MdnsLogger;
 import com.android.server.connectivity.mdns.util.MdnsUtils;
@@ -120,9 +121,14 @@
      * @param interfaceIndex the network interface index (or
      * {@link MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if not known) at which the packet was received
      * @param network the network at which the packet was received, or null if it is unknown.
-     * @return The set of response instances that were modified or newly added.
+     * @return A pair of 1) the set of response instances that were modified or newly added,
+     *                     *not* including those whose records were only updated with newer
+     *                     receipt timestamps; and
+     *                     2) a copy of the original responses, some of which have updated
+     *                     records or only updated receipt times.
      */
-    public ArraySet<MdnsResponse> augmentResponses(@NonNull MdnsPacket mdnsPacket,
+    public Pair<ArraySet<MdnsResponse>, ArrayList<MdnsResponse>> augmentResponses(
+            @NonNull MdnsPacket mdnsPacket,
             @NonNull Collection<MdnsResponse> existingResponses, int interfaceIndex,
             @Nullable Network network) {
         final ArrayList<MdnsRecord> records = new ArrayList<>(
@@ -177,7 +183,6 @@
                                 network);
                         responses.add(response);
                     }
-
                     if (response.addPointerRecord((MdnsPointerRecord) record)) {
                         modified.add(response);
                     }
@@ -269,7 +274,7 @@
             }
         }
 
-        return modified;
+        return Pair.create(modified, responses);
     }
 
     private static boolean assignInetRecord(
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 809750d..14302c2 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -63,6 +63,7 @@
     private final Object lock = new Object();
     private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
             new ArrayMap<>();
+    // TODO: change instanceNameToResponse to TreeMap with case insensitive comparator.
     private final Map<String, MdnsResponse> instanceNameToResponse = new HashMap<>();
     private final boolean removeServiceAfterTtlExpires =
             MdnsConfigs.removeServiceAfterTtlExpires();
@@ -260,15 +261,32 @@
             // Augment the list of current known responses, and generated responses for resolve
             // requests if there is no known response
             final List<MdnsResponse> currentList = new ArrayList<>(instanceNameToResponse.values());
-            currentList.addAll(makeResponsesForResolveIfUnknown(interfaceIndex, network));
-            final ArraySet<MdnsResponse> modifiedResponses = responseDecoder.augmentResponses(
-                    packet, currentList, interfaceIndex, network);
 
-            for (MdnsResponse modified : modifiedResponses) {
-                if (modified.isGoodbye()) {
-                    onGoodbyeReceived(modified.getServiceInstanceName());
-                } else {
-                    onResponseModified(modified);
+            List<MdnsResponse> additionalResponses = makeResponsesForResolve(interfaceIndex,
+                    network);
+            for (MdnsResponse additionalResponse : additionalResponses) {
+                if (!instanceNameToResponse.containsKey(
+                        additionalResponse.getServiceInstanceName())) {
+                    currentList.add(additionalResponse);
+                }
+            }
+            final Pair<ArraySet<MdnsResponse>, ArrayList<MdnsResponse>> augmentedResult =
+                    responseDecoder.augmentResponses(packet, currentList, interfaceIndex, network);
+
+            final ArraySet<MdnsResponse> modifiedResponse = augmentedResult.first;
+            final ArrayList<MdnsResponse> allResponses = augmentedResult.second;
+
+            for (MdnsResponse response : allResponses) {
+                if (modifiedResponse.contains(response)) {
+                    if (response.isGoodbye()) {
+                        onGoodbyeReceived(response.getServiceInstanceName());
+                    } else {
+                        onResponseModified(response);
+                    }
+                } else if (instanceNameToResponse.containsKey(response.getServiceInstanceName())) {
+                    // If the response is not modified but is already in the cache, the cache
+                    // entry still needs to be updated to refresh the last receipt time.
+                    instanceNameToResponse.put(response.getServiceInstanceName(), response);
                 }
             }
         }
@@ -474,7 +492,7 @@
         }
     }
 
-    private List<MdnsResponse> makeResponsesForResolveIfUnknown(int interfaceIndex,
+    private List<MdnsResponse> makeResponsesForResolve(int interfaceIndex,
             @NonNull Network network) {
         final List<MdnsResponse> resolveResponses = new ArrayList<>();
         for (int i = 0; i < listeners.size(); i++) {
@@ -516,7 +534,7 @@
                 // queried to complete it.
                 // Only the names are used to know which queries to send, other parameters like
                 // interfaceIndex do not matter.
-                servicesToResolve = makeResponsesForResolveIfUnknown(
+                servicesToResolve = makeResponsesForResolve(
                         0 /* interfaceIndex */, config.network);
                 sendDiscoveryQueries = servicesToResolve.size() < listeners.size();
             }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
index e16c448..6eb83da 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
@@ -344,7 +344,7 @@
         final Network network = mock(Network.class);
         responses = decoder.augmentResponses(parsedPacket,
                 /* existingResponses= */ Collections.emptyList(),
-                /* interfaceIndex= */ 10, network /* expireOnExit= */);
+                /* interfaceIndex= */ 10, network /* expireOnExit= */).first;
 
         assertEquals(responses.size(), 1);
         assertEquals(responses.valueAt(0).getInterfaceIndex(), 10);
@@ -593,6 +593,6 @@
 
         return decoder.augmentResponses(parsedPacket,
                 existingResponses,
-                MdnsSocket.INTERFACE_INDEX_UNSPECIFIED, mock(Network.class));
+                MdnsSocket.INTERFACE_INDEX_UNSPECIFIED, mock(Network.class)).first;
     }
 }
\ No newline at end of file
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
index dc0e646..3e189f1 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
@@ -95,24 +95,24 @@
         }
     }
 
-    private MdnsResponse makeCompleteResponse(int recordsTtlMillis) {
+    private MdnsResponse makeCompleteResponse(int recordsTtlMillis, int receiptTimeMillis) {
         final String[] hostname = new String[] { "MyHostname" };
         final String[] serviceName = new String[] { "MyService", "_type", "_tcp", "local" };
         final String[] serviceType = new String[] { "_type", "_tcp", "local" };
         final MdnsResponse response = new MdnsResponse(/* now= */ 0, serviceName, INTERFACE_INDEX,
                 mNetwork);
-        response.addPointerRecord(new MdnsPointerRecord(serviceType, 0L /* receiptTimeMillis */,
+        response.addPointerRecord(new MdnsPointerRecord(serviceType, receiptTimeMillis,
                 false /* cacheFlush */, recordsTtlMillis, serviceName));
-        response.setServiceRecord(new MdnsServiceRecord(serviceName, 0L /* receiptTimeMillis */,
+        response.setServiceRecord(new MdnsServiceRecord(serviceName, receiptTimeMillis,
                 true /* cacheFlush */, recordsTtlMillis, 0 /* servicePriority */,
                 0 /* serviceWeight */, 0 /* servicePort */, hostname));
-        response.setTextRecord(new MdnsTextRecord(serviceName, 0L /* receiptTimeMillis */,
+        response.setTextRecord(new MdnsTextRecord(serviceName, receiptTimeMillis,
                 true /* cacheFlush */, recordsTtlMillis, emptyList() /* entries */));
         response.addInet4AddressRecord(new MdnsInetAddressRecord(
-                hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+                hostname, receiptTimeMillis, true /* cacheFlush */,
                 recordsTtlMillis, parseNumericAddress("192.0.2.123")));
         response.addInet6AddressRecord(new MdnsInetAddressRecord(
-                hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+                hostname, receiptTimeMillis, true /* cacheFlush */,
                 recordsTtlMillis, parseNumericAddress("2001:db8::123")));
         return response;
     }
@@ -210,7 +210,7 @@
 
     @Test
     public void copyConstructor() {
-        final MdnsResponse response = makeCompleteResponse(TEST_TTL_MS);
+        final MdnsResponse response = makeCompleteResponse(TEST_TTL_MS, 0 /* receiptTimeMillis */);
         final MdnsResponse copy = new MdnsResponse(response);
 
         assertEquals(response.getInet6AddressRecord(), copy.getInet6AddressRecord());
@@ -225,7 +225,7 @@
 
     @Test
     public void addRecords_noChange() {
-        final MdnsResponse response = makeCompleteResponse(TEST_TTL_MS);
+        final MdnsResponse response = makeCompleteResponse(TEST_TTL_MS, 0 /* receiptTimeMillis */);
 
         assertFalse(response.addPointerRecord(response.getPointerRecords().get(0)));
         final String[] serviceName = new String[] { "MYSERVICE", "_TYPE", "_tcp", "local" };
@@ -242,8 +242,8 @@
 
     @Test
     public void addRecords_ttlChange() {
-        final MdnsResponse response = makeCompleteResponse(TEST_TTL_MS);
-        final MdnsResponse ttlZeroResponse = makeCompleteResponse(0);
+        final MdnsResponse response = makeCompleteResponse(TEST_TTL_MS, 0 /* receiptTimeMillis */);
+        final MdnsResponse ttlZeroResponse = makeCompleteResponse(0, 0 /* receiptTimeMillis */);
 
         assertTrue(response.addPointerRecord(ttlZeroResponse.getPointerRecords().get(0)));
         assertEquals(1, response.getPointerRecords().size());
@@ -278,6 +278,46 @@
     }
 
     @Test
+    public void addRecords_receiptTimeChange() {
+        final MdnsResponse response = makeCompleteResponse(TEST_TTL_MS, 0 /* receiptTimeMillis */);
+        final MdnsResponse receiptTimeChangedResponse = makeCompleteResponse(TEST_TTL_MS,
+                1 /* receiptTimeMillis */);
+
+        assertFalse(
+                response.addPointerRecord(receiptTimeChangedResponse.getPointerRecords().get(0)));
+        assertEquals(1, response.getPointerRecords().get(0).getReceiptTime());
+        assertTrue(response.getRecords().stream().anyMatch(r ->
+                r == response.getPointerRecords().get(0)));
+
+        assertFalse(
+                response.addInet6AddressRecord(receiptTimeChangedResponse.getInet6AddressRecord()));
+        assertEquals(1, response.getInet6AddressRecords().size());
+        assertEquals(1, response.getInet6AddressRecord().getReceiptTime());
+        assertTrue(response.getRecords().stream().anyMatch(r ->
+                r == response.getInet6AddressRecord()));
+
+        assertFalse(
+                response.addInet4AddressRecord(receiptTimeChangedResponse.getInet4AddressRecord()));
+        assertEquals(1, response.getInet4AddressRecords().size());
+        assertEquals(1, response.getInet4AddressRecord().getReceiptTime());
+        assertTrue(response.getRecords().stream().anyMatch(r ->
+                r == response.getInet4AddressRecord()));
+
+        assertFalse(response.setServiceRecord(receiptTimeChangedResponse.getServiceRecord()));
+        assertEquals(1, response.getServiceRecord().getReceiptTime());
+        assertTrue(response.getRecords().stream().anyMatch(r ->
+                r == response.getServiceRecord()));
+
+        assertFalse(response.setTextRecord(receiptTimeChangedResponse.getTextRecord()));
+        assertEquals(1, response.getTextRecord().getReceiptTime());
+        assertTrue(response.getRecords().stream().anyMatch(r ->
+                r == response.getTextRecord()));
+
+        // All records were replaced, not added
+        assertEquals(receiptTimeChangedResponse.getRecords().size(), response.getRecords().size());
+    }
+
+    @Test
     public void dropUnmatchedAddressRecords_caseInsensitive() {
 
         final String[] hostname = new String[] { "MyHostname" };
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index da51240..a696150 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -25,6 +25,7 @@
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doReturn;
@@ -1065,10 +1066,43 @@
         // Second and later sends are sent as "expect multicast response" queries
         inOrder.verify(mockSocketClient, times(2)).sendMulticastPacket(renewalQueryCaptor.capture(),
                 eq(mockNetwork));
+        inOrder.verify(mockListenerOne).onDiscoveryQuerySent(any(), anyInt());
         final MdnsPacket renewalPacket = MdnsPacket.parse(
                 new MdnsPacketReader(renewalQueryCaptor.getValue()));
         assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_SRV, serviceName));
         assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_TXT, serviceName));
+        inOrder.verifyNoMoreInteractions();
+
+        long updatedReceiptTime =  TEST_ELAPSED_REALTIME + TEST_TTL;
+        final MdnsPacket refreshedSrvTxtResponse = new MdnsPacket(
+                0 /* flags */,
+                Collections.emptyList() /* questions */,
+                List.of(
+                        // TODO: cacheFlush will cause addresses to be cleared and re-added every
+                        //  time, which is considered a change and triggers extra
+                        //  onServiceChanged callbacks. Set the cacheFlush bit to false until
+                        //  the issue is fixed.
+                        new MdnsServiceRecord(serviceName, updatedReceiptTime,
+                                false /* cacheFlush */, TEST_TTL, 0 /* servicePriority */,
+                                0 /* serviceWeight */, 1234 /* servicePort */, hostname),
+                        new MdnsTextRecord(serviceName, updatedReceiptTime,
+                                false /* cacheFlush */, TEST_TTL,
+                                Collections.emptyList() /* entries */),
+                        new MdnsInetAddressRecord(hostname, updatedReceiptTime,
+                                false /* cacheFlush */, TEST_TTL,
+                                InetAddresses.parseNumericAddress(ipV4Address)),
+                        new MdnsInetAddressRecord(hostname, updatedReceiptTime,
+                                false /* cacheFlush */, TEST_TTL,
+                                InetAddresses.parseNumericAddress(ipV6Address))),
+                Collections.emptyList() /* authorityRecords */,
+                Collections.emptyList() /* additionalRecords */);
+        client.processResponse(refreshedSrvTxtResponse, INTERFACE_INDEX, mockNetwork);
+
+        // Advance time to updatedReceiptTime + 1; no refresh query is expected because the
+        // cache should contain the records with the updated last receipt time.
+        doReturn(updatedReceiptTime + 1).when(mockDecoderClock).elapsedRealtime();
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        inOrder.verifyNoMoreInteractions();
     }
 
     @Test