Update record TTL in set/add methods
MdnsResponse would not replace records if the TTL changed, but this is
necessary in particular in case of exit announcements (RFC6762 10.1)
where the TTL is updated to zero.
Update MdnsResponse to update records if the TTL changed. The
receiptTimeMillis is still not considered to update records, but doing
so would generate service change notifications every time a reply is
received, so this is not covered here.
Bug: 267570781
Test: atest MdnsResponseTests
Change-Id: Ied55829cba6b043d13603a8230bda457c07de42f
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsResponse.java b/service-t/src/com/android/server/connectivity/mdns/MdnsResponse.java
index 6b17d83..ee0a3d8 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsResponse.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsResponse.java
@@ -28,6 +28,7 @@
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
+import java.util.Objects;
/** An mDNS response. */
public class MdnsResponse {
@@ -62,27 +63,35 @@
network = base.network;
}
- // This generic typed helper compares records for equality.
- // Returns True if records are the same.
- private <T> boolean recordsAreSame(T a, T b) {
- return ((a == null) && (b == null)) || ((a != null) && (b != null) && a.equals(b));
+ /**
+ * Compare records for equality, including their TTL.
+ *
+ * MdnsRecord#equals ignores TTL and receiptTimeMillis, but methods in this class need to update
+ * records when the TTL changes (especially for goodbye announcements).
+ */
+ private boolean recordsAreSame(MdnsRecord a, MdnsRecord b) {
+ if (!Objects.equals(a, b)) return false;
+ return a == null || a.getTtl() == b.getTtl();
}
/**
* Adds a pointer record.
*
* @return <code>true</code> if the record was added, or <code>false</code> if a matching
- * pointer
- * record is already present in the response.
+ * pointer record is already present in the response with the same TTL.
*/
public synchronized boolean addPointerRecord(MdnsPointerRecord pointerRecord) {
- if (!pointerRecords.contains(pointerRecord)) {
- pointerRecords.add(pointerRecord);
- records.add(pointerRecord);
- return true;
+ final int existing = pointerRecords.indexOf(pointerRecord);
+ if (existing >= 0) {
+ if (recordsAreSame(pointerRecord, pointerRecords.get(existing))) {
+ return false;
+ }
+ pointerRecords.remove(existing);
+ records.remove(existing);
}
-
- return false;
+ pointerRecords.add(pointerRecord);
+ records.add(pointerRecord);
+ return true;
}
/** Gets the pointer records. */
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
index 9148ac3..22f4812 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
@@ -101,6 +101,7 @@
+ "21402324");
private static final int INTERFACE_INDEX = 999;
+ private static final int TEST_TTL_MS = 120_000;
private final Network mNetwork = mock(Network.class);
// The following helper classes act as wrappers so that IPv4 and IPv6 address records can
@@ -163,6 +164,27 @@
return response;
}
+ private MdnsResponse makeCompleteResponse(int recordsTtlMillis) {
+ final MdnsResponse response = new MdnsResponse(/* now= */ 0, INTERFACE_INDEX, mNetwork);
+ final String[] hostname = new String[] { "MyHostname" };
+ final String[] serviceName = new String[] { "MyService", "_type", "_tcp", "local" };
+ final String[] serviceType = new String[] { "_type", "_tcp", "local" };
+ response.addPointerRecord(new MdnsPointerRecord(serviceType, 0L /* receiptTimeMillis */,
+ false /* cacheFlush */, recordsTtlMillis, serviceName));
+ response.setServiceRecord(new MdnsServiceRecord(serviceName, 0L /* receiptTimeMillis */,
+ true /* cacheFlush */, recordsTtlMillis, 0 /* servicePriority */,
+ 0 /* serviceWeight */, 0 /* servicePort */, hostname));
+ response.setTextRecord(new MdnsTextRecord(serviceName, 0L /* receiptTimeMillis */,
+ true /* cacheFlush */, recordsTtlMillis, emptyList() /* entries */));
+ response.setInet4AddressRecord(new MdnsInetAddressRecord(
+ hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+ recordsTtlMillis, parseNumericAddress("192.0.2.123")));
+ response.setInet6AddressRecord(new MdnsInetAddressRecord(
+ hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+ recordsTtlMillis, parseNumericAddress("2001:db8::123")));
+ return response;
+ }
+
@Test
public void getInet4AddressRecord_returnsAddedRecord() throws IOException {
DatagramPacket packet = new DatagramPacket(dataIn_ipv4_1, dataIn_ipv4_1.length);
@@ -337,23 +359,7 @@
@Test
public void copyConstructor() {
- final MdnsResponse response = new MdnsResponse(/* now= */ 0, INTERFACE_INDEX, mNetwork);
- final String[] hostname = new String[] { "MyHostname" };
- final String[] serviceName = new String[] { "MyService", "_type", "_tcp", "local" };
- final String[] serviceType = new String[] { "_type", "_tcp", "local" };
- response.addPointerRecord(new MdnsPointerRecord(serviceType, 0L /* receiptTimeMillis */,
- false /* cacheFlush */, 1234L /* ttlMillis */, serviceName));
- response.setServiceRecord(new MdnsServiceRecord(serviceName, 0L /* receiptTimeMillis */,
- true /* cacheFlush */, 1234L /* ttlMillis */, 0 /* servicePriority */,
- 0 /* serviceWeight */, 0 /* servicePort */, hostname));
- response.setTextRecord(new MdnsTextRecord(serviceName, 0L /* receiptTimeMillis */,
- true /* cacheFlush */, 1234L /* ttlMillis */, emptyList() /* entries */));
- response.setInet4AddressRecord(new MdnsInetAddressRecord(
- hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
- 1234L /* ttlMillis */, parseNumericAddress("192.0.2.123")));
- response.setInet6AddressRecord(new MdnsInetAddressRecord(
- hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
- 1234L /* ttlMillis */, parseNumericAddress("2001:db8::123")));
+ final MdnsResponse response = makeCompleteResponse(TEST_TTL_MS);
final MdnsResponse copy = new MdnsResponse(response);
assertEquals(response.getInet6AddressRecord(), copy.getInet6AddressRecord());
@@ -365,4 +371,50 @@
assertEquals(response.getNetwork(), copy.getNetwork());
assertEquals(response.getInterfaceIndex(), copy.getInterfaceIndex());
}
+
+ @Test
+ public void addRecords_noChange() {
+ final MdnsResponse response = makeCompleteResponse(TEST_TTL_MS);
+
+ assertFalse(response.addPointerRecord(response.getPointerRecords().get(0)));
+ assertFalse(response.setInet6AddressRecord(response.getInet6AddressRecord()));
+ assertFalse(response.setInet4AddressRecord(response.getInet4AddressRecord()));
+ assertFalse(response.setServiceRecord(response.getServiceRecord()));
+ assertFalse(response.setTextRecord(response.getTextRecord()));
+ }
+
+ @Test
+ public void addRecords_ttlChange() {
+ final MdnsResponse response = makeCompleteResponse(TEST_TTL_MS);
+ final MdnsResponse ttlZeroResponse = makeCompleteResponse(0);
+
+ assertTrue(response.addPointerRecord(ttlZeroResponse.getPointerRecords().get(0)));
+ assertEquals(1, response.getPointerRecords().size());
+ assertEquals(0, response.getPointerRecords().get(0).getTtl());
+ assertTrue(response.getRecords().stream().anyMatch(r ->
+ r == response.getPointerRecords().get(0)));
+
+ assertTrue(response.setInet6AddressRecord(ttlZeroResponse.getInet6AddressRecord()));
+ assertEquals(0, response.getInet6AddressRecord().getTtl());
+ assertTrue(response.getRecords().stream().anyMatch(r ->
+ r == response.getInet6AddressRecord()));
+
+ assertTrue(response.setInet4AddressRecord(ttlZeroResponse.getInet4AddressRecord()));
+ assertEquals(0, response.getInet4AddressRecord().getTtl());
+ assertTrue(response.getRecords().stream().anyMatch(r ->
+ r == response.getInet4AddressRecord()));
+
+ assertTrue(response.setServiceRecord(ttlZeroResponse.getServiceRecord()));
+ assertEquals(0, response.getServiceRecord().getTtl());
+ assertTrue(response.getRecords().stream().anyMatch(r ->
+ r == response.getServiceRecord()));
+
+ assertTrue(response.setTextRecord(ttlZeroResponse.getTextRecord()));
+ assertEquals(0, response.getTextRecord().getTtl());
+ assertTrue(response.getRecords().stream().anyMatch(r ->
+ r == response.getTextRecord()));
+
+ // All records were replaced, not added
+ assertEquals(ttlZeroResponse.getRecords().size(), response.getRecords().size());
+ }
}
\ No newline at end of file