Merge changes I3b1ad1be,Id4c2e610

* changes:
  Implement exit announcements
  Implement announcements on probing success
diff --git a/service-t/src/com/android/server/mdns/MdnsAnnouncer.java b/service-t/src/com/android/server/mdns/MdnsAnnouncer.java
index 7c84323..27fc945 100644
--- a/service-t/src/com/android/server/mdns/MdnsAnnouncer.java
+++ b/service-t/src/com/android/server/mdns/MdnsAnnouncer.java
@@ -30,23 +30,27 @@
  *
  * This allows maintaining other hosts' caches up-to-date. See RFC6762 8.3.
  */
-public class MdnsAnnouncer extends MdnsPacketRepeater<MdnsAnnouncer.AnnouncementInfo> {
+public class MdnsAnnouncer extends MdnsPacketRepeater<MdnsAnnouncer.BaseAnnouncementInfo> {
     private static final long ANNOUNCEMENT_INITIAL_DELAY_MS = 1000L;
     @VisibleForTesting
     static final int ANNOUNCEMENT_COUNT = 8;
 
+    // Matches delay and GoodbyeCount used by the legacy implementation
+    private static final long EXIT_DELAY_MS = 2000L;
+    private static final int EXIT_COUNT = 3;
+
     @NonNull
     private final String mLogTag;
 
-    /** Announcement request to send with {@link MdnsAnnouncer}. */
-    public static class AnnouncementInfo implements MdnsPacketRepeater.Request {
+    /** Base class for announcement requests to send with {@link MdnsAnnouncer}. */
+    public abstract static class BaseAnnouncementInfo implements MdnsPacketRepeater.Request {
+        private final int mServiceId;
         @NonNull
         private final MdnsPacket mPacket;
 
-        AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords) {
-            // Records to announce (as answers)
-            // Records to place in the "Additional records", with NSEC negative responses
-            // to mark records that have been verified unique
+        protected BaseAnnouncementInfo(int serviceId, @NonNull List<MdnsRecord> announcedRecords,
+                @NonNull List<MdnsRecord> additionalRecords) {
+            mServiceId = serviceId;
             final int flags = 0x8400; // Response, authoritative (rfc6762 18.4)
             mPacket = new MdnsPacket(flags,
                     Collections.emptyList() /* questions */,
@@ -55,10 +59,22 @@
                     additionalRecords);
         }
 
+        public int getServiceId() {
+            return mServiceId;
+        }
+
         @Override
         public MdnsPacket getPacket(int index) {
             return mPacket;
         }
+    }
+
+    /** Announcement request to send with {@link MdnsAnnouncer}. */
+    public static class AnnouncementInfo extends BaseAnnouncementInfo {
+        AnnouncementInfo(int serviceId, List<MdnsRecord> announcedRecords,
+                List<MdnsRecord> additionalRecords) {
+            super(serviceId, announcedRecords, additionalRecords);
+        }
 
         @Override
         public long getDelayMs(int nextIndex) {
@@ -72,9 +88,26 @@
         }
     }
 
+    /** Service exit announcement request to send with {@link MdnsAnnouncer}. */
+    public static class ExitAnnouncementInfo extends BaseAnnouncementInfo {
+        ExitAnnouncementInfo(int serviceId, @NonNull List<MdnsRecord> announcedRecords) {
+            super(serviceId, announcedRecords, Collections.emptyList() /* additionalRecords */);
+        }
+
+        @Override
+        public long getDelayMs(int nextIndex) {
+            return EXIT_DELAY_MS; // Constant delay: nextIndex is not used
+        }
+
+        @Override
+        public int getNumSends() {
+            return EXIT_COUNT; // Matches the legacy GoodbyeCount, see EXIT_COUNT
+        }
+    }
+
     public MdnsAnnouncer(@NonNull String interfaceTag, @NonNull Looper looper,
             @NonNull MdnsReplySender replySender,
-            @Nullable PacketRepeaterCallback<AnnouncementInfo> cb) {
+            @Nullable PacketRepeaterCallback<BaseAnnouncementInfo> cb) {
         super(looper, replySender, cb);
         mLogTag = MdnsAnnouncer.class.getSimpleName() + "/" + interfaceTag;
     }
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
index 997dcbb..790e69a 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
@@ -25,6 +25,7 @@
 import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo;
 import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
 
 import java.io.IOException;
@@ -113,9 +114,22 @@
     /**
      * Callbacks from {@link MdnsAnnouncer}.
      */
-    private class AnnouncingCallback
-            implements PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> {
-        // TODO: implement
+    private class AnnouncingCallback implements PacketRepeaterCallback<BaseAnnouncementInfo> {
+        @Override
+        public void onSent(int index, @NonNull BaseAnnouncementInfo info) {
+            mRecordRepository.onAdvertisementSent(info.getServiceId()); // Records the send time
+        }
+
+        @Override
+        public void onFinished(@NonNull BaseAnnouncementInfo info) {
+            if (info instanceof MdnsAnnouncer.ExitAnnouncementInfo) {
+                mRecordRepository.removeService(info.getServiceId());
+                // Exit announcements are done: destroy the advertiser if no service is left.
+                if (mRecordRepository.getServicesCount() == 0) {
+                    destroyNow();
+                }
+            }
+        }
     }
 
     /**
@@ -139,7 +153,7 @@
         /** @see MdnsAnnouncer */
         public MdnsAnnouncer makeMdnsAnnouncer(@NonNull String interfaceTag, @NonNull Looper looper,
                 @NonNull MdnsReplySender replySender,
-                @Nullable PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> cb) {
+                @Nullable PacketRepeaterCallback<MdnsAnnouncer.BaseAnnouncementInfo> cb) {
             return new MdnsAnnouncer(interfaceTag, looper, replySender, cb);
         }
 
@@ -210,9 +224,10 @@
      * This will trigger exit announcements for the service.
      */
     public void removeService(int id) {
+        if (!mRecordRepository.hasActiveService(id)) return;
         mProber.stop(id);
         mAnnouncer.stop(id);
-        final MdnsAnnouncer.AnnouncementInfo exitInfo = mRecordRepository.exitService(id);
+        final MdnsAnnouncer.ExitAnnouncementInfo exitInfo = mRecordRepository.exitService(id);
         if (exitInfo != null) {
             // This effectively schedules destroyNow(), as it is to be called when the exit
             // announcement finishes if there is no service left.
diff --git a/service-t/src/com/android/server/mdns/MdnsNsecRecord.java b/service-t/src/com/android/server/mdns/MdnsNsecRecord.java
index 06fdd5e..6ec2f99 100644
--- a/service-t/src/com/android/server/mdns/MdnsNsecRecord.java
+++ b/service-t/src/com/android/server/mdns/MdnsNsecRecord.java
@@ -17,12 +17,14 @@
 package com.android.server.connectivity.mdns;
 
 import android.net.DnsResolver;
+import android.text.TextUtils;
 
 import com.android.net.module.util.CollectionUtils;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Objects;
 
 /**
  * A mDNS "NSEC" record, used in particular for negative responses (RFC6762 6.1).
@@ -140,4 +142,30 @@
         writer.writeUInt8(bytesInBlock);
         writer.writeBytes(bytes);
     }
+
+    @Override
+    public String toString() { // e.g. "NSEC: NextDomain: host.local Types [1, 28]"
+        return "NSEC: NextDomain: " + TextUtils.join(".", mNextDomain)
+                + " Types " + Arrays.toString(mTypes);
+    }
+
+    @Override
+    public int hashCode() { // Hashes the same three parts that equals() compares
+        return Objects.hash(super.hashCode(),
+                Arrays.hashCode(mNextDomain), Arrays.hashCode(mTypes));
+    }
+
+    @Override
+    public boolean equals(Object other) {
+        if (this == other) {
+            return true;
+        }
+        if (!(other instanceof MdnsNsecRecord)) {
+            return false;
+        }
+        // Beyond what super.equals checks, the next domain and the types must match too.
+        return super.equals(other)
+                && Arrays.equals(mNextDomain, ((MdnsNsecRecord) other).mNextDomain)
+                && Arrays.equals(mTypes, ((MdnsNsecRecord) other).mTypes);
+    }
 }
diff --git a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
index bb9c751..dd00212 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
@@ -18,21 +18,32 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.TargetApi;
 import android.net.LinkAddress;
 import android.net.nsd.NsdServiceInfo;
+import android.os.Build;
 import android.os.Looper;
+import android.os.SystemClock;
+import android.util.ArraySet;
 import android.util.SparseArray;
 
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.CollectionUtils;
+import com.android.net.module.util.HexDump;
 
 import java.io.IOException;
+import java.net.Inet4Address;
 import java.net.InetAddress;
 import java.net.NetworkInterface;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Enumeration;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
 import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 
@@ -41,6 +52,7 @@
  *
  * Must be used on a consistent looper thread.
  */
+@TargetApi(Build.VERSION_CODES.TIRAMISU) // Allow calling T+ APIs; this is only loaded on T+
 public class MdnsRecordRepository {
     // TTLs as per RFC6762 10.
     // TTL for records with a host name as the resource record's name (e.g., A, AAAA, HINFO) or a
@@ -61,6 +73,8 @@
     @NonNull
     private final SparseArray<ServiceRegistration> mServices = new SparseArray<>();
     @NonNull
+    private final List<RecordInfo<?>> mGeneralRecords = new ArrayList<>();
+    @NonNull
     private final Looper mLooper;
     @NonNull
     private String[] mDeviceHostname;
@@ -124,6 +138,12 @@
          */
         public boolean isProbing;
 
+        /**
+         * Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast,
+         * 0 if never
+         */
+        public long lastSentTimeMs;
+
         RecordInfo(NsdServiceInfo serviceInfo, T record, boolean sharedName,
                  boolean probing) {
             this.serviceInfo = serviceInfo;
@@ -221,7 +241,29 @@
      * Inform the repository of the latest interface addresses.
      */
     public void updateAddresses(@NonNull List<LinkAddress> newAddresses) {
-        // TODO: implement to update addresses in records
+        mGeneralRecords.clear(); // The list is rebuilt from newAddresses on every call
+        for (LinkAddress addr : newAddresses) {
+            final String[] revDnsAddr = getReverseDnsAddress(addr.getAddress());
+            mGeneralRecords.add(new RecordInfo<>(
+                    null /* serviceInfo */,
+                    new MdnsPointerRecord( // Reverse-DNS name of the address -> device hostname
+                            revDnsAddr,
+                            0L /* receiptTimeMillis */,
+                            true /* cacheFlush */,
+                            NAME_RECORDS_TTL_MILLIS,
+                            mDeviceHostname),
+                    false /* sharedName */, false /* probing */));
+            // Address record: device hostname -> the address itself
+            mGeneralRecords.add(new RecordInfo<>(
+                    null /* serviceInfo */,
+                    new MdnsInetAddressRecord(
+                            mDeviceHostname,
+                            0L /* receiptTimeMillis */,
+                            true /* cacheFlush */,
+                            NAME_RECORDS_TTL_MILLIS,
+                            addr.getAddress()),
+                    false /* sharedName */, false /* probing */));
+        }
     }
 
     /**
@@ -298,20 +340,31 @@
      * @return The exit announcement to indicate the service was removed, or null if not necessary.
      */
     @Nullable
-    public MdnsAnnouncer.AnnouncementInfo exitService(int id) {
+    public MdnsAnnouncer.ExitAnnouncementInfo exitService(int id) {
         final ServiceRegistration registration = mServices.get(id);
         if (registration == null) return null;
         if (registration.exiting) return null;
 
-        registration.exiting = true;
+        // Send exit (TTL 0) for the PTR record, if the record was sent (in particular don't send
+        // if still probing). When nothing is sent, the service is not marked as exiting either.
+        if (registration.ptrRecord.lastSentTimeMs == 0L) {
+            return null;
+        }
 
-        // TODO: implement
-        return null;
+        registration.exiting = true; // hasActiveService(id) returns false from here on
+        final MdnsPointerRecord expiredRecord = new MdnsPointerRecord(
+                registration.ptrRecord.record.getName(),
+                0L /* receiptTimeMillis */,
+                true /* cacheFlush */,
+                0L /* ttlMillis */,
+                registration.ptrRecord.record.getPointer());
+
+        // Exit should be skipped if the record is still advertised by another service, but that
+        // would be a conflict (2 service registrations with the same service name), so it would
+        // not have been allowed by the repository.
+        return new MdnsAnnouncer.ExitAnnouncementInfo(id, Collections.singletonList(expiredRecord));
     }
 
-    /**
-     * Remove a service from the repository
-     */
     public void removeService(int id) {
         mServices.remove(id);
     }
@@ -338,14 +391,110 @@
     }
 
     /**
+     * Add NSEC records indicating that the response records are unique.
+     *
+     * Following RFC6762 6.1:
+     * "On receipt of a question for a particular name, rrtype, and rrclass, for which a responder
+     * does have one or more unique answers, the responder MAY also include an NSEC record in the
+     * Additional Record Section indicating the nonexistence of other rrtypes for that name and
+     * rrclass."
+     * @param destinationList List to add the NSEC records to.
+     * @param answerRecords Lists of answered records based on which to add NSEC records (typically
+     *                      answer and additionalAnswer sections)
+     */
+    @SafeVarargs
+    private static void addNsecRecordsForUniqueNames(
+            List<MdnsRecord> destinationList,
+            Iterator<RecordInfo<?>>... answerRecords) {
+        // Group unique records by name. Use a TreeMap with comparator as arrays don't implement
+        // equals / hashCode.
+        final Map<String[], List<MdnsRecord>> nsecByName = new TreeMap<>(Arrays::compare);
+        // But keep the list of names in added order, otherwise records would be sorted in
+        // alphabetical order instead of the order of the original records, which would look like
+        // inconsistent behavior depending on service name.
+        final List<String[]> namesInAddedOrder = new ArrayList<>();
+        for (Iterator<RecordInfo<?>> answers : answerRecords) {
+            addNonSharedRecordsToMap(answers, nsecByName, namesInAddedOrder);
+        }
+        // One NSEC record per name, listing the types of the non-shared records with that name.
+        for (String[] nsecName : namesInAddedOrder) {
+            final List<MdnsRecord> entryRecords = nsecByName.get(nsecName);
+            long minTtl = Long.MAX_VALUE;
+            final Set<Integer> types = new ArraySet<>(entryRecords.size());
+            for (MdnsRecord record : entryRecords) {
+                if (minTtl > record.getTtl()) minTtl = record.getTtl();
+                types.add(record.getType());
+            }
+            // The name is also the next domain; the TTL is the smallest one found above.
+            destinationList.add(new MdnsNsecRecord(
+                    nsecName,
+                    0L /* receiptTimeMillis */,
+                    true /* cacheFlush */,
+                    minTtl,
+                    nsecName,
+                    CollectionUtils.toIntArray(types)));
+        }
+    }
+
+    /**
+     * Add non-shared records to a map listing them by record name, and to a list of names that
+     * remembers the adding order.
+     *
+     * In the destination map records are grouped by name; so the map has one key per record name,
+     * and the values are the lists of different records that share the same name.
+     * @param records Records to scan.
+     * @param dest Map to add the records to.
+     * @param namesInAddedOrder List of names to add the names in order, keeping the first
+     *                          occurrence of each name.
+     */
+    private static void addNonSharedRecordsToMap(
+            Iterator<RecordInfo<?>> records,
+            Map<String[], List<MdnsRecord>> dest,
+            List<String[]> namesInAddedOrder) {
+        while (records.hasNext()) {
+            final RecordInfo<?> record = records.next();
+            if (record.isSharedName) continue; // Shared names are left out entirely
+            final List<MdnsRecord> recordsForName = dest.computeIfAbsent(record.record.name,
+                    key -> {
+                        namesInAddedOrder.add(key); // Only reached the first time a name is seen
+                        return new ArrayList<>();
+                    });
+            recordsForName.add(record.record);
+        }
+    }
+
+    /**
      * Called to indicate that probing succeeded for a service.
      * @param probeSuccessInfo The successful probing info.
      * @return The {@link MdnsAnnouncer.AnnouncementInfo} to send, now that probing has succeeded.
      */
     public MdnsAnnouncer.AnnouncementInfo onProbingSucceeded(
-            MdnsProber.ProbingInfo probeSuccessInfo) throws IOException {
-        // TODO: implement: set service as not probing anymore and generate announcements
-        throw new IOException("Announcements not implemented");
+            MdnsProber.ProbingInfo probeSuccessInfo)
+            throws IOException {
+        // Fails with IOException if the service is unknown; otherwise probing is marked as done.
+        final ServiceRegistration registration = mServices.get(probeSuccessInfo.getServiceId());
+        if (registration == null) throw new IOException(
+                "Service is not registered: " + probeSuccessInfo.getServiceId());
+        registration.setProbing(false);
+
+        final ArrayList<MdnsRecord> answers = new ArrayList<>();
+        final ArrayList<MdnsRecord> additionalAnswers = new ArrayList<>();
+
+        // Interface address records in general records
+        for (RecordInfo<?> record : mGeneralRecords) {
+            answers.add(record.record);
+        }
+
+        // All service records
+        for (RecordInfo<?> info : registration.allRecords) {
+            answers.add(info.record);
+        }
+        // Additional section: NSEC records for the non-shared names among the records above
+        addNsecRecordsForUniqueNames(additionalAnswers,
+                mGeneralRecords.iterator(), registration.allRecords.iterator());
+
+        return new MdnsAnnouncer.AnnouncementInfo(probeSuccessInfo.getServiceId(),
+                answers, additionalAnswers);
     }
 
     /**
@@ -371,6 +520,60 @@
         return registration.srvRecord.isProbing;
     }
 
+    /**
+     * Check whether a service with the given ID is registered and has not started exiting.
+     */
+    public boolean hasActiveService(int serviceId) {
+        final ServiceRegistration service = mServices.get(serviceId);
+        // Unknown IDs and services that are already exiting both count as inactive.
+        final boolean registered = service != null;
+        return registered && !service.exiting;
+    }
+
+    /**
+     * Called when {@link MdnsInterfaceAdvertiser} sent an advertisement for the given service.
+     */
+    public void onAdvertisementSent(int serviceId) {
+        final ServiceRegistration registration = mServices.get(serviceId);
+        if (registration == null) return;
+        // Stamp all records of the service with the same send time (0 means never sent).
+        final long now = SystemClock.elapsedRealtime();
+        for (RecordInfo<?> record : registration.allRecords) {
+            record.lastSentTimeMs = now;
+        }
+    }
+
+    /**
+     * Compute the labels of the reverse DNS name of an address, for example:
+     * 2001:db8::1 --> 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa
+     *
+     * Or:
+     * 192.0.2.123 --> 123.2.0.192.in-addr.arpa
+     */
+    @VisibleForTesting
+    public static String[] getReverseDnsAddress(@NonNull InetAddress addr) {
+        // Labels of xxx.xxx.xxx.xxx.in-addr.arpa (up to 28 characters as a dotted name)
+        // or of 32 hex characters separated by dots + .ip6.arpa; one array entry per label
+        final byte[] addrBytes = addr.getAddress();
+        final List<String> out = new ArrayList<>();
+        if (addr instanceof Inet4Address) {
+            for (int i = addrBytes.length - 1; i >= 0; i--) {
+                out.add(String.valueOf(Byte.toUnsignedInt(addrBytes[i])));
+            }
+            out.add("in-addr");
+        } else {
+            final String hexAddr = HexDump.toHexString(addrBytes);
+            // One label per hex digit, walking the hex string from its last character
+            for (int i = hexAddr.length() - 1; i >= 0; i--) {
+                out.add(String.valueOf(hexAddr.charAt(i)));
+            }
+            out.add("ip6");
+        }
+        out.add("arpa");
+
+        return out.toArray(new String[0]);
+    }
+
     private static String[] splitFullyQualifiedName(
             @NonNull NsdServiceInfo info, @NonNull String[] serviceType) {
         final String[] split = new String[serviceType.length + 1];
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 961f0f0..650607d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -22,11 +22,11 @@
 import android.os.SystemClock
 import com.android.internal.util.HexDump
 import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
 import java.net.DatagramPacket
-import java.net.Inet6Address
-import java.net.InetAddress
 import kotlin.test.assertEquals
 import kotlin.test.assertTrue
 import org.junit.After
@@ -68,7 +68,7 @@
     private class TestAnnouncementInfo(
         announcedRecords: List<MdnsRecord>,
         additionalRecords: List<MdnsRecord>
-    ) : AnnouncementInfo(announcedRecords, additionalRecords) {
+    ) : AnnouncementInfo(1 /* serviceId */, announcedRecords, additionalRecords) {
         override fun getDelayMs(nextIndex: Int) =
                 if (nextIndex < FIRST_ANNOUNCES_COUNT) {
                     FIRST_ANNOUNCES_DELAY
@@ -82,7 +82,7 @@
         val replySender = MdnsReplySender(thread.looper, socket, buffer)
         @Suppress("UNCHECKED_CAST")
         val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
-                as MdnsPacketRepeater.PacketRepeaterCallback<AnnouncementInfo>
+                as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>
         val announcer = MdnsAnnouncer("testiface", thread.looper, replySender, cb)
         /*
         The expected packet replicates records announced when registering a service, as observed in
@@ -150,8 +150,8 @@
         val v6Addr1 = parseNumericAddress("2001:DB8::123")
         val v6Addr2 = parseNumericAddress("2001:DB8::456")
         val v4AddrRev = arrayOf("123", "0", "2", "192", "in-addr", "arpa")
-        val v6Addr1Rev = getReverseV6AddressName(v6Addr1)
-        val v6Addr2Rev = getReverseV6AddressName(v6Addr2)
+        val v6Addr1Rev = getReverseDnsAddress(v6Addr1)
+        val v6Addr2Rev = getReverseDnsAddress(v6Addr2)
 
         val announcedRecords = listOf(
                 // Reverse address records
@@ -267,13 +267,3 @@
         }
     }
 }
-
-/**
- * Compute 2001:db8::1 --> 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.1.0.0.2.ip6.arpa
- */
-private fun getReverseV6AddressName(addr: InetAddress): Array<String> {
-    assertTrue(addr is Inet6Address)
-    return addr.address.flatMapTo(mutableListOf("arpa", "ip6")) {
-        HexDump.toHexString(it).toCharArray().map(Char::toString)
-    }.reversed().toTypedArray()
-}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index ad22305..2cb0850 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -22,6 +22,8 @@
 import android.os.Build
 import android.os.HandlerThread
 import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo
 import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.EXIT_ANNOUNCEMENT_DELAY_MS
 import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback
 import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
@@ -35,8 +37,10 @@
 import org.mockito.ArgumentCaptor
 import org.mockito.Mockito.any
 import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.doAnswer
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
+import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
 
 private const val LOG_TAG = "testlogtag"
@@ -66,7 +70,7 @@
     private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
             as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
     private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
-            as ArgumentCaptor<PacketRepeaterCallback<AnnouncementInfo>>
+            as ArgumentCaptor<PacketRepeaterCallback<BaseAnnouncementInfo>>
 
     private val probeCb get() = probeCbCaptor.value
     private val announceCb get() = announceCbCaptor.value
@@ -82,7 +86,21 @@
         doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
         doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
 
-        doReturn(-1).`when`(repository).addService(anyInt(), any())
+        val knownServices = mutableSetOf<Int>()
+        doAnswer { inv ->
+            knownServices.add(inv.getArgument(0))
+            -1
+        }.`when`(repository).addService(anyInt(), any())
+        doAnswer { inv ->
+            knownServices.remove(inv.getArgument(0))
+            null
+        }.`when`(repository).removeService(anyInt())
+        doAnswer {
+            knownServices.toIntArray().also { knownServices.clear() }
+        }.`when`(repository).clearServices()
+        doAnswer { inv ->
+            knownServices.contains(inv.getArgument(0))
+        }.`when`(repository).hasActiveService(anyInt())
         thread.start()
         advertiser.start()
 
@@ -97,18 +115,7 @@
 
     @Test
     fun testAddRemoveService() {
-        val testProbingInfo = mock(ProbingInfo::class.java)
-        doReturn(TEST_SERVICE_ID_1).`when`(testProbingInfo).serviceId
-        doReturn(testProbingInfo).`when`(repository).setServiceProbing(TEST_SERVICE_ID_1)
-
-        advertiser.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        verify(repository).addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        verify(prober).startProbing(testProbingInfo)
-
-        // Simulate probing success: continues to announcing
-        val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
-        doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
-        probeCb.onFinished(testProbingInfo)
+        val testAnnouncementInfo = addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
 
         verify(announcer).startSending(TEST_SERVICE_ID_1, testAnnouncementInfo,
                 0L /* initialDelayMs */)
@@ -117,13 +124,53 @@
         verify(cb).onRegisterServiceSucceeded(advertiser, TEST_SERVICE_ID_1)
 
         // Remove the service: expect exit announcements
-        val testExitInfo = mock(AnnouncementInfo::class.java)
+        val testExitInfo = mock(ExitAnnouncementInfo::class.java)
         doReturn(testExitInfo).`when`(repository).exitService(TEST_SERVICE_ID_1)
         advertiser.removeService(TEST_SERVICE_ID_1)
 
+        verify(prober).stop(TEST_SERVICE_ID_1)
+        verify(announcer).stop(TEST_SERVICE_ID_1)
         verify(announcer).startSending(TEST_SERVICE_ID_1, testExitInfo, EXIT_ANNOUNCEMENT_DELAY_MS)
 
-        // TODO: after exit announcements are implemented, verify that announceCb.onFinished causes
-        // cb.onDestroyed to be called.
+        // Exit announcements finish: the advertiser has no left service and destroys itself
+        announceCb.onFinished(testExitInfo)
+        thread.waitForIdle(TIMEOUT_MS)
+        verify(cb).onDestroyed(socket)
+    }
+
+    @Test
+    fun testDoubleRemove() {
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        // First removal: probing and announcing stop, exit announcements start
+        val testExitInfo = mock(ExitAnnouncementInfo::class.java)
+        doReturn(testExitInfo).`when`(repository).exitService(TEST_SERVICE_ID_1)
+        advertiser.removeService(TEST_SERVICE_ID_1)
+
+        verify(prober).stop(TEST_SERVICE_ID_1)
+        verify(announcer).stop(TEST_SERVICE_ID_1)
+        verify(announcer).startSending(TEST_SERVICE_ID_1, testExitInfo, EXIT_ANNOUNCEMENT_DELAY_MS)
+        // Second removal, with the service reported as inactive: nothing must be stopped again
+        doReturn(false).`when`(repository).hasActiveService(TEST_SERVICE_ID_1)
+        advertiser.removeService(TEST_SERVICE_ID_1)
+        // Prober, announcer were still stopped only one time
+        verify(prober, times(1)).stop(TEST_SERVICE_ID_1)
+        verify(announcer, times(1)).stop(TEST_SERVICE_ID_1)
+    }
+
+    private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
+            AnnouncementInfo {
+        val testProbingInfo = mock(ProbingInfo::class.java)
+        doReturn(serviceId).`when`(testProbingInfo).serviceId
+        doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId)
+        // Adding the service must register it in the repository and start probing
+        advertiser.addService(serviceId, serviceInfo)
+        verify(repository).addService(serviceId, serviceInfo)
+        verify(prober).startProbing(testProbingInfo)
+
+        // Simulate probing success: continues to announcing
+        val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
+        doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
+        probeCb.onFinished(testProbingInfo)
+        return testAnnouncementInfo // The mocked info returned by onProbingSucceeded
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 502a36a..29d0854 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -17,10 +17,12 @@
 package com.android.server.connectivity.mdns
 
 import android.net.InetAddresses.parseNumericAddress
+import android.net.LinkAddress
 import android.net.nsd.NsdServiceInfo
 import android.os.Build
 import android.os.HandlerThread
 import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
+import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRunner
 import java.net.NetworkInterface
@@ -28,6 +30,7 @@
 import kotlin.test.assertContentEquals
 import kotlin.test.assertEquals
 import kotlin.test.assertFailsWith
+import kotlin.test.assertFalse
 import kotlin.test.assertNotNull
 import kotlin.test.assertTrue
 import org.junit.After
@@ -39,10 +42,10 @@
 private const val TEST_SERVICE_ID_2 = 43
 private const val TEST_PORT = 12345
 private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local")
-private val TEST_ADDRESSES = arrayOf(
-        parseNumericAddress("192.0.2.111"),
-        parseNumericAddress("2001:db8::111"),
-        parseNumericAddress("2001:db8::222"))
+private val TEST_ADDRESSES = listOf(
+        LinkAddress(parseNumericAddress("192.0.2.111"), 24),
+        LinkAddress(parseNumericAddress("2001:db8::111"), 64),
+        LinkAddress(parseNumericAddress("2001:db8::222"), 64))
 
 private val TEST_SERVICE_1 = NsdServiceInfo().apply {
     serviceType = "_testservice._tcp"
@@ -50,6 +53,12 @@
     port = TEST_PORT
 }
 
+private val TEST_SERVICE_2 = NsdServiceInfo().apply { // same type and port as TEST_SERVICE_1
+    serviceType = "_testservice._tcp"
+    serviceName = "MyOtherTestService"
+    port = TEST_PORT
+}
+
 @RunWith(DevSdkIgnoreRunner::class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsRecordRepositoryTest {
@@ -57,7 +66,7 @@
     private val deps = object : Dependencies() {
         override fun getHostname() = TEST_HOSTNAME
         override fun getInterfaceInetAddresses(iface: NetworkInterface) =
-                Collections.enumeration(TEST_ADDRESSES.toList())
+                Collections.enumeration(TEST_ADDRESSES.map { it.address }) // address part only
     }
 
     @Before
@@ -113,9 +122,71 @@
     }
 
     @Test
+    fun testInvalidReuseOfServiceId() { // one service ID, two different services
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        assertFailsWith(IllegalArgumentException::class) { // second add must throw
+            repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2)
+        }
+    }
+
+    @Test
+    fun testHasActiveService() { // expected hasActiveService value at each step
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1)) // not added yet
+
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1)) // active once added
+
+        val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+        repository.onProbingSucceeded(probingInfo)
+        repository.onAdvertisementSent(TEST_SERVICE_ID_1)
+        assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1)) // still active here
+
+        repository.exitService(TEST_SERVICE_ID_1)
+        assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1)) // no longer active
+    }
+
+    @Test
+    fun testExitAnnouncements() { // checks the packet returned by exitService
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        repository.updateAddresses(TEST_ADDRESSES)
+
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+        repository.onProbingSucceeded(probingInfo)
+        repository.onAdvertisementSent(TEST_SERVICE_ID_1)
+
+        val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
+        assertNotNull(exitAnnouncement)
+        assertEquals(1, repository.servicesCount) // still counted until removeService below
+        val packet = exitAnnouncement.getPacket(0)
+
+        assertEquals(0x8400 /* response, authoritative */, packet.flags)
+        assertEquals(0, packet.questions.size)
+        assertEquals(0, packet.authorityRecords.size)
+        assertEquals(0, packet.additionalRecords.size)
+
+        assertContentEquals(listOf(
+                MdnsPointerRecord( // the only expected answer: a PTR record with TTL 0
+                        arrayOf("_testservice", "_tcp", "local"),
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        0L /* ttlMillis */,
+                        arrayOf("MyTestService", "_testservice", "_tcp", "local"))
+        ), packet.answers)
+
+        repository.removeService(TEST_SERVICE_ID_1)
+        assertEquals(0, repository.servicesCount)
+    }
+
+    @Test
     fun testExitingServiceReAdded() {
         val repository = MdnsRecordRepository(thread.looper, deps)
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+        repository.onProbingSucceeded(probingInfo)
+        repository.onAdvertisementSent(TEST_SERVICE_ID_1)
         repository.exitService(TEST_SERVICE_ID_1)
 
         assertEquals(TEST_SERVICE_ID_1, repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1))
@@ -124,4 +195,131 @@
         repository.removeService(TEST_SERVICE_ID_2)
         assertEquals(0, repository.servicesCount)
     }
+
+    @Test
+    fun testOnProbingSucceeded() { // checks the packet returned by onProbingSucceeded
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        repository.updateAddresses(TEST_ADDRESSES)
+
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+        val announcementInfo = repository.onProbingSucceeded(probingInfo)
+        val packet = announcementInfo.getPacket(0)
+
+        assertEquals(0x8400 /* response, authoritative */, packet.flags)
+        assertEquals(0, packet.questions.size)
+        assertEquals(0, packet.authorityRecords.size)
+
+        val serviceType = arrayOf("_testservice", "_tcp", "local")
+        val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
+        val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address)
+        val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address)
+        val v6Addr2Rev = getReverseDnsAddress(TEST_ADDRESSES[2].address)
+
+        assertContentEquals(listOf( // expected answer records, in this order
+                // Reverse address and address records for the hostname
+                MdnsPointerRecord(v4AddrRev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        TEST_HOSTNAME),
+                MdnsInetAddressRecord(TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        TEST_ADDRESSES[0].address),
+                MdnsPointerRecord(v6Addr1Rev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        TEST_HOSTNAME),
+                MdnsInetAddressRecord(TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        TEST_ADDRESSES[1].address),
+                MdnsPointerRecord(v6Addr2Rev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        TEST_HOSTNAME),
+                MdnsInetAddressRecord(TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        TEST_ADDRESSES[2].address),
+                // Service registration records (RFC6763)
+                MdnsPointerRecord(
+                        serviceType,
+                        0L /* receiptTimeMillis */,
+                        // Not a unique name owned by the announcer, so cacheFlush=false
+                        false /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceName),
+                MdnsServiceRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        0 /* servicePriority */,
+                        0 /* serviceWeight */,
+                        TEST_PORT /* servicePort */,
+                        TEST_HOSTNAME),
+                MdnsTextRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        emptyList() /* entries */),
+                // Service type enumeration record (RFC6763 9.)
+                MdnsPointerRecord(
+                        arrayOf("_services", "_dns-sd", "_udp", "local"),
+                        0L /* receiptTimeMillis */,
+                        false /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceType)
+        ), packet.answers)
+
+        assertContentEquals(listOf( // expected additional records: NSEC entries only
+                MdnsNsecRecord(v4AddrRev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v4AddrRev,
+                        intArrayOf(MdnsRecord.TYPE_PTR)),
+                MdnsNsecRecord(TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        TEST_HOSTNAME,
+                        intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
+                MdnsNsecRecord(v6Addr1Rev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v6Addr1Rev,
+                        intArrayOf(MdnsRecord.TYPE_PTR)),
+                MdnsNsecRecord(v6Addr2Rev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v6Addr2Rev,
+                        intArrayOf(MdnsRecord.TYPE_PTR)),
+                MdnsNsecRecord(serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceName,
+                        intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV))
+        ), packet.additionalRecords)
+    }
+
+    @Test
+    fun testGetReverseDnsAddress() { // expected labels for one IPv6 and one IPv4 input
+        val expectedV6 = "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa"
+                .split(".").toTypedArray()
+        assertContentEquals(expectedV6, getReverseDnsAddress(parseNumericAddress("2001:db8::1")))
+        val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray() // octets reversed
+        assertContentEquals(expectedV4, getReverseDnsAddress(parseNumericAddress("192.0.2.123")))
+    }
 }