Renew the SRV/TXT records if half of TTL passed

As mentioned in RFC6762 7.1. The records only needed to be renewed if
at least half of the TTL passed. Usually A/AAAA records are included in
the response to the SRV record query, they are not refreshed individually.

Bug: 285260665
Bug: 285261577
Test: atest CtsNetTest FrameworksNetTests
(cherry picked from https://android-review.googlesource.com/q/commit:f2cc01dc126ba1bd8c89add0853546ab4627c3aa)
Merged-In: Ifd7140de0d733191256184c5481412e1822d279b
Change-Id: Ifd7140de0d733191256184c5481412e1822d279b
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index 866ecba..84faf12 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -24,6 +24,7 @@
 import android.util.Pair;
 
 import com.android.server.connectivity.mdns.util.MdnsLogger;
+import com.android.server.connectivity.mdns.util.MdnsUtils;
 
 import java.io.IOException;
 import java.lang.ref.WeakReference;
@@ -75,6 +76,8 @@
     private final boolean sendDiscoveryQueries;
     @NonNull
     private final List<MdnsResponse> servicesToResolve;
+    @NonNull
+    private final MdnsResponseDecoder.Clock clock;
 
     EnqueueMdnsQueryCallable(
             @NonNull MdnsSocketClientBase requestSender,
@@ -85,7 +88,8 @@
             int transactionId,
             @Nullable Network network,
             boolean sendDiscoveryQueries,
-            @NonNull Collection<MdnsResponse> servicesToResolve) {
+            @NonNull Collection<MdnsResponse> servicesToResolve,
+            @NonNull MdnsResponseDecoder.Clock clock) {
         weakRequestSender = new WeakReference<>(requestSender);
         this.packetWriter = packetWriter;
         serviceTypeLabels = TextUtils.split(serviceType, "\\.");
@@ -95,6 +99,7 @@
         this.network = network;
         this.sendDiscoveryQueries = sendDiscoveryQueries;
         this.servicesToResolve = new ArrayList<>(servicesToResolve);
+        this.clock = clock;
     }
 
     // Incompatible return type for override of Callable#call().
@@ -119,22 +124,24 @@
 
             // List of (name, type) to query
             final ArrayList<Pair<String[], Integer>> missingKnownAnswerRecords = new ArrayList<>();
+            final long now = clock.elapsedRealtime();
             for (MdnsResponse response : servicesToResolve) {
-                // TODO: also send queries to renew record TTL (as per RFC6762 7.1 no need to query
-                // if remaining TTL is more than half the original one, so send the queries if half
-                // the TTL has passed).
-                if (response.isComplete()) continue;
                 final String[] serviceName = response.getServiceName();
                 if (serviceName == null) continue;
-                if (!response.hasTextRecord()) {
+                if (!response.hasTextRecord() || MdnsUtils.isRecordRenewalNeeded(
+                        response.getTextRecord(), now)) {
                     missingKnownAnswerRecords.add(new Pair<>(serviceName, MdnsRecord.TYPE_TXT));
                 }
-                if (!response.hasServiceRecord()) {
+                if (!response.hasServiceRecord() || MdnsUtils.isRecordRenewalNeeded(
+                        response.getServiceRecord(), now)) {
                     missingKnownAnswerRecords.add(new Pair<>(serviceName, MdnsRecord.TYPE_SRV));
                     // The hostname is not yet known, so queries for address records will be sent
                     // the next time the EnqueueMdnsQueryCallable is enqueued if the reply does not
                     // contain them. In practice, advertisers should include the address records
                     // when queried for SRV, although it's not a MUST requirement (RFC6763 12.2).
+                    // TODO: Figure out how to renew the A/AAAA record. Usually A/AAAA record will
+                    //  be included in the response to the SRV record so in high chances there is
+                    //  no need to renew them individually.
                 } else if (!response.hasInet4AddressRecord() && !response.hasInet6AddressRecord()) {
                     final String[] host = response.getServiceRecord().getServiceHost();
                     missingKnownAnswerRecords.add(new Pair<>(host, MdnsRecord.TYPE_A));
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index d7eaea5..809750d 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -532,7 +532,8 @@
                                 config.transactionId,
                                 config.network,
                                 sendDiscoveryQueries,
-                                servicesToResolve)
+                                servicesToResolve,
+                                clock)
                                 .call();
             } catch (RuntimeException e) {
                 sharedLog.e(String.format("Failed to run EnqueueMdnsQueryCallable for subtype: %s",
diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
index 5413956..3added6 100644
--- a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
+++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
@@ -108,4 +108,15 @@
         encoder.encode(CharBuffer.wrap(originalName), out, true /* endOfInput */);
         return new String(out.array(), 0, out.position(), utf8);
     }
+
+    /**
+     * Checks if the MdnsRecord needs to be renewed or not.
+     *
+     * <p>As per RFC6762 7.1 no need to query if remaining TTL is more than half the original one,
+     * so send the queries if half the TTL has passed.
+     */
+    public static boolean isRecordRenewalNeeded(@NonNull MdnsRecord mdnsRecord, final long now) {
+        return mdnsRecord.getTtl() > 0
+                && mdnsRecord.getRemainingTTL(now) <= mdnsRecord.getTtl() / 2;
+    }
 }
\ No newline at end of file
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 43f3ace..f51079b 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -932,15 +932,11 @@
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
-        final List<MdnsRecord> srvTxtQuestions = srvTxtQueryPacket.questions;
 
-        final String[] serviceName = Stream.concat(Stream.of(instanceName),
-                Arrays.stream(SERVICE_TYPE_LABELS)).toArray(String[]::new);
-        assertFalse(srvTxtQuestions.stream().anyMatch(q -> q.getType() == MdnsRecord.TYPE_PTR));
-        assertTrue(srvTxtQuestions.stream().anyMatch(q ->
-                q.getType() == MdnsRecord.TYPE_SRV && Arrays.equals(q.name, serviceName)));
-        assertTrue(srvTxtQuestions.stream().anyMatch(q ->
-                q.getType() == MdnsRecord.TYPE_TXT && Arrays.equals(q.name, serviceName)));
+        final String[] serviceName = getTestServiceName(instanceName);
+        assertFalse(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_PTR));
+        assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_SRV, serviceName));
+        assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_TXT, serviceName));
 
         // Process a response with SRV+TXT
         final MdnsPacket srvTxtResponse = new MdnsPacket(
@@ -967,11 +963,8 @@
 
         final MdnsPacket addressQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(addressQueryCaptor.getValue()));
-        final List<MdnsRecord> addressQueryQuestions = addressQueryPacket.questions;
-        assertTrue(addressQueryQuestions.stream().anyMatch(q ->
-                q.getType() == MdnsRecord.TYPE_A && Arrays.equals(q.name, hostname)));
-        assertTrue(addressQueryQuestions.stream().anyMatch(q ->
-                q.getType() == MdnsRecord.TYPE_AAAA && Arrays.equals(q.name, hostname)));
+        assertTrue(hasQuestion(addressQueryPacket, MdnsRecord.TYPE_A, hostname));
+        assertTrue(hasQuestion(addressQueryPacket, MdnsRecord.TYPE_AAAA, hostname));
 
         // Process a response with address records
         final MdnsPacket addressResponse = new MdnsPacket(
@@ -1004,6 +997,81 @@
     }
 
     @Test
+    public void testRenewTxtSrvInResolve() throws Exception {
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                mockDecoderClock, mockNetwork, mockSharedLog);
+
+        final String instanceName = "service-instance";
+        final String[] hostname = new String[] { "testhost "};
+        final String ipV4Address = "192.0.2.0";
+        final String ipV6Address = "2001:db8::";
+
+        final MdnsSearchOptions resolveOptions = MdnsSearchOptions.newBuilder()
+                .setResolveInstanceName(instanceName).build();
+
+        client.startSendAndReceive(mockListenerOne, resolveOptions);
+        InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
+
+        // Get the query for SRV/TXT
+        final ArgumentCaptor<DatagramPacket> srvTxtQueryCaptor =
+                ArgumentCaptor.forClass(DatagramPacket.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        // Send twice for IPv4 and IPv6
+        inOrder.verify(mockSocketClient, times(2)).sendUnicastPacket(srvTxtQueryCaptor.capture(),
+                eq(mockNetwork));
+
+        final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
+
+        final String[] serviceName = getTestServiceName(instanceName);
+        assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_SRV, serviceName));
+        assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_TXT, serviceName));
+
+        // Process a response with all records
+        final MdnsPacket srvTxtResponse = new MdnsPacket(
+                0 /* flags */,
+                Collections.emptyList() /* questions */,
+                List.of(
+                        new MdnsServiceRecord(serviceName, TEST_ELAPSED_REALTIME,
+                                true /* cacheFlush */, TEST_TTL, 0 /* servicePriority */,
+                                0 /* serviceWeight */, 1234 /* servicePort */, hostname),
+                        new MdnsTextRecord(serviceName, TEST_ELAPSED_REALTIME,
+                                true /* cacheFlush */, TEST_TTL,
+                                Collections.emptyList() /* entries */),
+                        new MdnsInetAddressRecord(hostname, TEST_ELAPSED_REALTIME,
+                                true /* cacheFlush */, TEST_TTL,
+                                InetAddresses.parseNumericAddress(ipV4Address)),
+                        new MdnsInetAddressRecord(hostname, TEST_ELAPSED_REALTIME,
+                                true /* cacheFlush */, TEST_TTL,
+                                InetAddresses.parseNumericAddress(ipV6Address))),
+                Collections.emptyList() /* authorityRecords */,
+                Collections.emptyList() /* additionalRecords */);
+        client.processResponse(srvTxtResponse, INTERFACE_INDEX, mockNetwork);
+        inOrder.verify(mockListenerOne).onServiceNameDiscovered(any());
+        inOrder.verify(mockListenerOne).onServiceFound(any());
+
+        // Expect no query on the next run
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        inOrder.verifyNoMoreInteractions();
+
+        // Advance time so 75% of TTL passes and re-execute
+        doReturn(TEST_ELAPSED_REALTIME + (long) (TEST_TTL * 0.75))
+                .when(mockDecoderClock).elapsedRealtime();
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+
+        // Expect a renewal query
+        final ArgumentCaptor<DatagramPacket> renewalQueryCaptor =
+                ArgumentCaptor.forClass(DatagramPacket.class);
+        // Second and later sends are sent as "expect multicast response" queries
+        inOrder.verify(mockSocketClient, times(2)).sendMulticastPacket(renewalQueryCaptor.capture(),
+                eq(mockNetwork));
+        final MdnsPacket renewalPacket = MdnsPacket.parse(
+                new MdnsPacketReader(renewalQueryCaptor.getValue()));
+        assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_SRV, serviceName));
+        assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_TXT, serviceName));
+    }
+
+    @Test
     public void testProcessResponse_ResolveExcludesOtherServices() {
         client = new MdnsServiceTypeClient(
                 SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
@@ -1244,6 +1312,20 @@
         }
     }
 
+    private static String[] getTestServiceName(String instanceName) {
+        return Stream.concat(Stream.of(instanceName),
+                Arrays.stream(SERVICE_TYPE_LABELS)).toArray(String[]::new);
+    }
+
+    private static boolean hasQuestion(MdnsPacket packet, int type) {
+        return hasQuestion(packet, type, null);
+    }
+
+    private static boolean hasQuestion(MdnsPacket packet, int type, @Nullable String[] name) {
+        return packet.questions.stream().anyMatch(q -> q.getType() == type
+                && (name == null || Arrays.equals(q.name, name)));
+    }
+
     // A fake ScheduledExecutorService that keeps tracking the last scheduled Runnable and its delay
     // time.
     private class FakeExecutor extends ScheduledThreadPoolExecutor {