[mdns] add hidden API for public key

This commit adds support of registering a public key for a host/service.
This is required to enable Advertising Proxy feature for Thread devices.

For example:
```
NsdServiceInfo info = new NsdServiceInfo();

info.setServiceName("My Service");
info.setServiceType("_test._tcp");
info.setHostname("MyHost");
info.setHostAddresses(List.of(address1, address2));
info.setPublicKey(/* KEY RDATA */);

nsdManager.registerService(info, PROTOCOL_DNS_SD, listener);
```

Bug: 317946010

Change-Id: I367ebff8119d5c1dff0410c85e6fb86dca6c66b8
diff --git a/framework-t/src/android/net/nsd/NsdManager.java b/framework-t/src/android/net/nsd/NsdManager.java
index 1001423..48d40e6 100644
--- a/framework-t/src/android/net/nsd/NsdManager.java
+++ b/framework-t/src/android/net/nsd/NsdManager.java
@@ -389,6 +389,7 @@
     }
 
     private static final int FIRST_LISTENER_KEY = 1;
+    private static final int DNSSEC_PROTOCOL = 3;
 
     private final INsdServiceConnector mService;
     private final Context mContext;
@@ -1754,45 +1755,132 @@
         }
     }
 
+    private enum ServiceValidationType {
+        NO_SERVICE,
+        HAS_SERVICE, // A service with a positive port
+        HAS_SERVICE_ZERO_PORT, // A service with a zero port
+    }
+
+    private enum HostValidationType {
+        DEFAULT_HOST, // No host is specified so the default host will be used
+        CUSTOM_HOST, // A custom host with addresses is specified
+        CUSTOM_HOST_NO_ADDRESS, // A custom host without address is specified
+    }
+
+    private enum PublicKeyValidationType {
+        NO_KEY,
+        HAS_KEY,
+    }
+
+    /**
+     * Check if the service is valid for registration and classify it as one of {@link
+     * ServiceValidationType}.
+     */
+    private static ServiceValidationType validateService(NsdServiceInfo serviceInfo) {
+        final boolean hasServiceName = !TextUtils.isEmpty(serviceInfo.getServiceName());
+        final boolean hasServiceType = !TextUtils.isEmpty(serviceInfo.getServiceType());
+        if (!hasServiceName && !hasServiceType && serviceInfo.getPort() == 0) {
+            return ServiceValidationType.NO_SERVICE;
+        }
+        if (hasServiceName && hasServiceType) {
+            if (serviceInfo.getPort() < 0) {
+                throw new IllegalArgumentException("Invalid port");
+            }
+            if (serviceInfo.getPort() == 0) {
+                return ServiceValidationType.HAS_SERVICE_ZERO_PORT;
+            }
+            return ServiceValidationType.HAS_SERVICE;
+        }
+        throw new IllegalArgumentException("The service name or the service type is missing");
+    }
+
+    /**
+     * Check if the host is valid for registration and classify it as one of {@link
+     * HostValidationType}.
+     */
+    private static HostValidationType validateHost(NsdServiceInfo serviceInfo) {
+        final boolean hasHostname = !TextUtils.isEmpty(serviceInfo.getHostname());
+        final boolean hasHostAddresses = !CollectionUtils.isEmpty(serviceInfo.getHostAddresses());
+        if (!hasHostname) {
+            // Keep compatible with the legacy behavior: It's allowed to set host
+            // addresses for a service registration although the host addresses
+            // won't be registered. To register the addresses for a host, the
+            // hostname must be specified.
+            return HostValidationType.DEFAULT_HOST;
+        }
+        if (!hasHostAddresses) {
+            return HostValidationType.CUSTOM_HOST_NO_ADDRESS;
+        }
+        return HostValidationType.CUSTOM_HOST;
+    }
+
+    /**
+     * Check if the public key is valid for registration and classify it as one of {@link
+     * PublicKeyValidationType}.
+     *
+     * <p>For simplicity, it only checks if the protocol is DNSSEC and the RDATA is not fewer than 4
+     * bytes. See RFC 3445 Section 3.
+     */
+    private static PublicKeyValidationType validatePublicKey(NsdServiceInfo serviceInfo) {
+        byte[] publicKey = serviceInfo.getPublicKey();
+        if (publicKey == null) {
+            return PublicKeyValidationType.NO_KEY;
+        }
+        if (publicKey.length < 4) {
+            throw new IllegalArgumentException("The public key should be at least 4 bytes long");
+        }
+        int protocol = publicKey[2];
+        if (protocol == DNSSEC_PROTOCOL) {
+            return PublicKeyValidationType.HAS_KEY;
+        }
+        throw new IllegalArgumentException(
+                "The public key's protocol ("
+                        + protocol
+                        + ") is invalid. It should be DNSSEC_PROTOCOL (3)");
+    }
+
     /**
      * Check if the {@link NsdServiceInfo} is valid for registration.
      *
-     * The following can be registered:
-     * - A service with an optional host.
-     * - A hostname with addresses.
+     * <p>Firstly, check if service, host and public key are all valid respectively. Then check if
+     * the combination of service, host and public key is valid.
      *
-     * Note that:
-     * - When registering a service, the service name, service type and port must be specified. If
-     *   hostname is specified, the host addresses can optionally be specified.
-     * - When registering a host without a service, the addresses must be specified.
+     * <p>If the {@code serviceInfo} is invalid, throw an {@link IllegalArgumentException}
+     * describing the reason.
+     *
+     * <p>There are the invalid combinations of service, host and public key:
+     *
+     * <ul>
+     *   <li>Neither service nor host is specified.
+     *   <li>No public key is specified and the service has a zero port.
+     *   <li>The registration only contains the hostname but addresses are missing.
+     * </ul>
+     *
+     * <p>Keys are used to reserve hostnames or service names while the service/host is temporarily
+     * inactive, so registrations with a key and just a hostname or a service name are acceptable.
      *
      * @hide
      */
     public static void checkServiceInfoForRegistration(NsdServiceInfo serviceInfo) {
         Objects.requireNonNull(serviceInfo, "NsdServiceInfo cannot be null");
-        boolean hasServiceName = !TextUtils.isEmpty(serviceInfo.getServiceName());
-        boolean hasServiceType = !TextUtils.isEmpty(serviceInfo.getServiceType());
-        boolean hasHostname = !TextUtils.isEmpty(serviceInfo.getHostname());
-        boolean hasHostAddresses = !CollectionUtils.isEmpty(serviceInfo.getHostAddresses());
 
-        if (serviceInfo.getPort() < 0) {
-            throw new IllegalArgumentException("Invalid port");
+        final ServiceValidationType serviceValidation = validateService(serviceInfo);
+        final HostValidationType hostValidation = validateHost(serviceInfo);
+        final PublicKeyValidationType publicKeyValidation = validatePublicKey(serviceInfo);
+
+        if (serviceValidation == ServiceValidationType.NO_SERVICE
+                && hostValidation == HostValidationType.DEFAULT_HOST) {
+            throw new IllegalArgumentException("Nothing to register");
         }
-
-        if (hasServiceType || hasServiceName || (serviceInfo.getPort() > 0)) {
-            if (!(hasServiceType && hasServiceName && (serviceInfo.getPort() > 0))) {
-                throw new IllegalArgumentException(
-                        "The service type, service name or port is missing");
+        if (publicKeyValidation == PublicKeyValidationType.NO_KEY) {
+            if (serviceValidation == ServiceValidationType.HAS_SERVICE_ZERO_PORT) {
+                throw new IllegalArgumentException("The port is missing");
             }
-        }
-
-        if (!hasServiceType && !hasHostname) {
-            throw new IllegalArgumentException("No service or host specified in NsdServiceInfo");
-        }
-
-        if (!hasServiceType && hasHostname && !hasHostAddresses) {
-            // TODO: b/317946010 - This may be allowed when it supports registering KEY RR.
-            throw new IllegalArgumentException("No host addresses specified in NsdServiceInfo");
+            if (serviceValidation == ServiceValidationType.NO_SERVICE
+                    && hostValidation == HostValidationType.CUSTOM_HOST_NO_ADDRESS) {
+                throw new IllegalArgumentException(
+                        "The host addresses must be specified unless there is a service");
+            }
         }
     }
 }
diff --git a/framework-t/src/android/net/nsd/NsdServiceInfo.java b/framework-t/src/android/net/nsd/NsdServiceInfo.java
index 9491a9c..2f675a9 100644
--- a/framework-t/src/android/net/nsd/NsdServiceInfo.java
+++ b/framework-t/src/android/net/nsd/NsdServiceInfo.java
@@ -37,6 +37,7 @@
 import java.nio.charset.StandardCharsets;
 import java.time.Instant;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -69,6 +70,9 @@
     private int mPort;
 
     @Nullable
+    private byte[] mPublicKey;
+
+    @Nullable
     private Network mNetwork;
 
     private int mInterfaceIndex;
@@ -220,6 +224,40 @@
     }
 
     /**
+     * Set the public key RDATA to be advertised in a KEY RR (RFC 2535).
+     *
+     * <p>This is the public key of the key pair used for signing a DNS message (e.g. SRP). Clients
+     * typically don't need this information, but the KEY RR is usually published to claim the use
+     * of the DNS name so that another mDNS advertiser can't take over the ownership during a
+     * temporary power down of the original host device.
+     *
+     * <p>When the public key is set to non-null, exactly one KEY RR will be advertised for each of
+     * the service and host name if they are not null.
+     *
+     * @hide // For Thread only
+     */
+    public void setPublicKey(@Nullable byte[] publicKey) {
+        if (publicKey == null) {
+            mPublicKey = null;
+            return;
+        }
+        mPublicKey = Arrays.copyOf(publicKey, publicKey.length);
+    }
+
+    /**
+     * Get the public key RDATA in the KEY RR (RFC 2535) or {@code null} if no KEY RR exists.
+     *
+     * @hide // For Thread only
+     */
+    @Nullable
+    public byte[] getPublicKey() {
+        if (mPublicKey == null) {
+            return null;
+        }
+        return Arrays.copyOf(mPublicKey, mPublicKey.length);
+    }
+
+    /**
      * Unpack txt information from a base-64 encoded byte array.
      *
      * @param txtRecordsRawBytes The raw base64 encoded byte array.
@@ -622,6 +660,7 @@
         }
         dest.writeString(mHostname);
         dest.writeLong(mExpirationTime != null ? mExpirationTime.getEpochSecond() : -1);
+        dest.writeByteArray(mPublicKey);
     }
 
     /** Implement the Parcelable interface */
@@ -654,6 +693,7 @@
                 info.mHostname = in.readString();
                 final long seconds = in.readLong();
                 info.setExpirationTime(seconds < 0 ? null : Instant.ofEpochSecond(seconds));
+                info.mPublicKey = in.createByteArray();
                 return info;
             }
 
diff --git a/tests/common/java/android/net/nsd/NsdServiceInfoTest.java b/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
index 8e89037..21e34ab 100644
--- a/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
+++ b/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
@@ -16,6 +16,7 @@
 
 package android.net.nsd;
 
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThrows;
@@ -51,6 +52,23 @@
 
     private static final InetAddress IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.1");
     private static final InetAddress IPV6_ADDRESS = InetAddresses.parseNumericAddress("2001:db8::");
+    private static final byte[] PUBLIC_KEY_RDATA = new byte[] {
+            (byte) 0x02, (byte)0x01,  // flag
+            (byte) 0x03, // protocol
+            (byte) 0x0d, // algorithm
+            // 64-byte public key below
+            (byte) 0xC1, (byte) 0x41, (byte) 0xD0, (byte) 0x63, (byte) 0x79, (byte) 0x60,
+            (byte) 0xB9, (byte) 0x8C, (byte) 0xBC, (byte) 0x12, (byte) 0xCF, (byte) 0xCA,
+            (byte) 0x22, (byte) 0x1D, (byte) 0x28, (byte) 0x79, (byte) 0xDA, (byte) 0xC2,
+            (byte) 0x6E, (byte) 0xE5, (byte) 0xB4, (byte) 0x60, (byte) 0xE9, (byte) 0x00,
+            (byte) 0x7C, (byte) 0x99, (byte) 0x2E, (byte) 0x19, (byte) 0x02, (byte) 0xD8,
+            (byte) 0x97, (byte) 0xC3, (byte) 0x91, (byte) 0xB0, (byte) 0x37, (byte) 0x64,
+            (byte) 0xD4, (byte) 0x48, (byte) 0xF7, (byte) 0xD0, (byte) 0xC7, (byte) 0x72,
+            (byte) 0xFD, (byte) 0xB0, (byte) 0x3B, (byte) 0x1D, (byte) 0x9D, (byte) 0x6D,
+            (byte) 0x52, (byte) 0xFF, (byte) 0x88, (byte) 0x86, (byte) 0x76, (byte) 0x9E,
+            (byte) 0x8E, (byte) 0x23, (byte) 0x62, (byte) 0x51, (byte) 0x35, (byte) 0x65,
+            (byte) 0x27, (byte) 0x09, (byte) 0x62, (byte) 0xD3
+    };
 
     @Test
     public void testLimits() throws Exception {
@@ -120,6 +138,7 @@
         fullInfo.setPort(4242);
         fullInfo.setHostAddresses(List.of(IPV4_ADDRESS));
         fullInfo.setHostname("home");
+        fullInfo.setPublicKey(PUBLIC_KEY_RDATA);
         fullInfo.setNetwork(new Network(123));
         fullInfo.setInterfaceIndex(456);
         checkParcelable(fullInfo);
@@ -136,6 +155,7 @@
         attributedInfo.setPort(4242);
         attributedInfo.setHostAddresses(List.of(IPV6_ADDRESS, IPV4_ADDRESS));
         attributedInfo.setHostname("home");
+        attributedInfo.setPublicKey(PUBLIC_KEY_RDATA);
         attributedInfo.setAttribute("color", "pink");
         attributedInfo.setAttribute("sound", (new String("にゃあ")).getBytes("UTF-8"));
         attributedInfo.setAttribute("adorable", (String) null);
@@ -172,6 +192,7 @@
         assertEquals(original.getServiceType(), result.getServiceType());
         assertEquals(original.getHost(), result.getHost());
         assertEquals(original.getHostname(), result.getHostname());
+        assertArrayEquals(original.getPublicKey(), result.getPublicKey());
         assertTrue(original.getPort() == result.getPort());
         assertEquals(original.getNetwork(), result.getNetwork());
         assertEquals(original.getInterfaceIndex(), result.getInterfaceIndex());
diff --git a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
index 5ba6c4c..93cec9c 100644
--- a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
+++ b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
@@ -287,6 +287,12 @@
 ): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, *requiredTypes) }
 
 fun TapPacketReader.pollForReply(
+    recordName: String,
+    type: Int,
+    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
+): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { it.isReplyFor(recordName, type) }
+
+fun TapPacketReader.pollForReply(
     serviceName: String,
     serviceType: String,
     timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 6dd4857..6c6f6a3 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -81,7 +81,9 @@
 import com.android.compatibility.common.util.SystemUtil
 import com.android.modules.utils.build.SdkLevel.isAtLeastU
 import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.DnsPacket.ANSECTION
 import com.android.net.module.util.HexDump
+import com.android.net.module.util.HexDump.hexStringToByteArray
 import com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN
 import com.android.net.module.util.PacketBuilder
 import com.android.testutils.ConnectivityModuleTest
@@ -96,6 +98,7 @@
 import com.android.testutils.TestableNetworkAgent
 import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
 import com.android.testutils.TestableNetworkCallback
+import com.android.testutils.assertContainsExactly
 import com.android.testutils.assertEmpty
 import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk30
 import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk33
@@ -138,6 +141,9 @@
 private const val DBG = false
 private const val TEST_PORT = 12345
 private const val MDNS_PORT = 5353.toShort()
+private const val TYPE_KEY = 25
+private const val QCLASS_INTERNET = 0x0001
+private const val NAME_RECORDS_TTL_MILLIS: Long = 120
 private val multicastIpv6Addr = parseNumericAddress("ff02::fb") as Inet6Address
 private val testSrcAddr = parseNumericAddress("2001:db8::123") as Inet6Address
 
@@ -167,6 +173,12 @@
     private val serviceType2 = "_nmt%09d._tcp".format(Random().nextInt(1_000_000_000))
     private val customHostname = "NsdTestHost%09d".format(Random().nextInt(1_000_000_000))
     private val customHostname2 = "NsdTestHost%09d".format(Random().nextInt(1_000_000_000))
+    private val publicKey = hexStringToByteArray(
+            "0201030dc141d0637960b98cbc12cfca"
+                    + "221d2879dac26ee5b460e9007c992e19"
+                    + "02d897c391b03764d448f7d0c772fdb0"
+                    + "3b1d9d6d52ff8886769e8e2362513565"
+                    + "270962d3")
     private val handlerThread = HandlerThread(NsdManagerTest::class.java.simpleName)
     private val ctsNetUtils by lazy{ CtsNetUtils(context) }
 
@@ -2266,6 +2278,165 @@
     }
 
     @Test
+    fun testAdvertising_registerServiceAndPublicKey_keyAnnounced() {
+        val si = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceType = serviceType
+            it.serviceName = serviceName
+            it.port = TEST_PORT
+            it.publicKey = publicKey
+        }
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        val registrationRecord = NsdRegistrationRecord()
+        val discoveryRecord = NsdDiscoveryRecord()
+        tryTest {
+            registerService(registrationRecord, si)
+
+            val announcement = packetReader.pollForReply(
+                "$serviceName.$serviceType.local",
+                TYPE_KEY
+            )
+            assertNotNull(announcement)
+            val keyRecords = announcement.records[ANSECTION].filter { it.nsType == TYPE_KEY }
+            assertEquals(1, keyRecords.size)
+            val actualRecord = keyRecords.get(0)
+            assertEquals(TYPE_KEY, actualRecord.nsType)
+            assertEquals("$serviceName.$serviceType.local", actualRecord.dName)
+            assertEquals(NAME_RECORDS_TTL_MILLIS, actualRecord.ttl)
+            assertArrayEquals(publicKey, actualRecord.rr)
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    testNetwork1.network, Executor { it.run() }, discoveryRecord)
+
+            val discoveredInfo1 = discoveryRecord.waitForServiceDiscovered(
+                    serviceName, serviceType, testNetwork1.network)
+            val resolvedInfo1 = resolveService(discoveredInfo1)
+
+            assertEquals(serviceName, discoveredInfo1.serviceName)
+            assertEquals(TEST_PORT, resolvedInfo1.port)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord)
+        }
+    }
+
+    @Test
+    fun testAdvertising_registerCustomHostAndPublicKey_keyAnnounced() {
+        val si = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.hostname = customHostname
+            it.hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"),
+                    parseNumericAddress("2001:db8::2"))
+            it.publicKey = publicKey
+        }
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        val registrationRecord = NsdRegistrationRecord()
+        tryTest {
+            registerService(registrationRecord, si)
+
+            val announcement = packetReader.pollForReply("$customHostname.local", TYPE_KEY)
+            assertNotNull(announcement)
+            val keyRecords = announcement.records[ANSECTION].filter { it.nsType == TYPE_KEY }
+            assertEquals(1, keyRecords.size)
+            val actualRecord = keyRecords.get(0)
+            assertEquals(TYPE_KEY, actualRecord.nsType)
+            assertEquals("$customHostname.local", actualRecord.dName)
+            assertEquals(NAME_RECORDS_TTL_MILLIS, actualRecord.ttl)
+            assertArrayEquals(publicKey, actualRecord.rr)
+
+            // This test case focuses on key announcement so we don't check the details of the
+            // announcement of the custom host addresses.
+            val addressRecords = announcement.records[ANSECTION].filter {
+                it.nsType == DnsResolver.TYPE_AAAA ||
+                        it.nsType == DnsResolver.TYPE_A
+            }
+            assertEquals(3, addressRecords.size)
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord)
+        }
+    }
+
+    @Test
+    fun testAdvertising_registerTwoServicesWithSameCustomHostAndPublicKey_keyAnnounced() {
+        val si1 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceType = serviceType
+            it.serviceName = serviceName
+            it.port = TEST_PORT
+            it.hostname = customHostname
+            it.hostAddresses = listOf(
+                parseNumericAddress("192.0.2.23"),
+                parseNumericAddress("2001:db8::1"),
+                parseNumericAddress("2001:db8::2"))
+            it.publicKey = publicKey
+        }
+        val si2 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceType = serviceType2
+            it.serviceName = serviceName2
+            it.port = TEST_PORT + 1
+            it.hostname = customHostname
+            it.hostAddresses = listOf()
+            it.publicKey = publicKey
+        }
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+            testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        val registrationRecord1 = NsdRegistrationRecord()
+        val registrationRecord2 = NsdRegistrationRecord()
+        tryTest {
+            registerService(registrationRecord1, si1)
+
+            var announcement =
+                packetReader.pollForReply("$serviceName.$serviceType.local", TYPE_KEY)
+            assertNotNull(announcement)
+            var keyRecords = announcement.records[ANSECTION].filter { it.nsType == TYPE_KEY }
+            assertEquals(2, keyRecords.size)
+            assertTrue(keyRecords.any { it.dName == "$serviceName.$serviceType.local" })
+            assertTrue(keyRecords.any { it.dName == "$customHostname.local" })
+            assertTrue(keyRecords.all { it.ttl == NAME_RECORDS_TTL_MILLIS })
+            assertTrue(keyRecords.all { it.rr.contentEquals(publicKey) })
+
+            // This test case focuses on key announcement so we don't check the details of the
+            // announcement of the custom host addresses.
+            val addressRecords = announcement.records[ANSECTION].filter {
+                it.nsType == DnsResolver.TYPE_AAAA ||
+                        it.nsType == DnsResolver.TYPE_A
+            }
+            assertEquals(3, addressRecords.size)
+
+            registerService(registrationRecord2, si2)
+
+            announcement = packetReader.pollForReply("$serviceName2.$serviceType2.local", TYPE_KEY)
+            assertNotNull(announcement)
+            keyRecords = announcement.records[ANSECTION].filter { it.nsType == TYPE_KEY }
+            assertEquals(2, keyRecords.size)
+            assertTrue(keyRecords.any { it.dName == "$serviceName2.$serviceType2.local" })
+            assertTrue(keyRecords.any { it.dName == "$customHostname.local" })
+            assertTrue(keyRecords.all { it.ttl == NAME_RECORDS_TTL_MILLIS })
+            assertTrue(keyRecords.all { it.rr.contentEquals(publicKey) })
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord1)
+            nsdManager.unregisterService(registrationRecord2)
+        }
+    }
+
+    @Test
     fun testServiceTypeClientRemovedAfterSocketDestroyed() {
         val si = makeTestServiceInfo(testNetwork1.network)
         // Register service on testNetwork1
diff --git a/tests/unit/java/android/net/nsd/NsdManagerTest.java b/tests/unit/java/android/net/nsd/NsdManagerTest.java
index 27c4561..9c812a1 100644
--- a/tests/unit/java/android/net/nsd/NsdManagerTest.java
+++ b/tests/unit/java/android/net/nsd/NsdManagerTest.java
@@ -16,6 +16,11 @@
 
 package android.net.nsd;
 
+import static android.net.InetAddresses.parseNumericAddress;
+import static android.net.nsd.NsdManager.checkServiceInfoForRegistration;
+
+import static com.android.net.module.util.HexDump.hexStringToByteArray;
+
 import static libcore.junit.util.compat.CoreCompatChangeRule.DisableCompatChanges;
 import static libcore.junit.util.compat.CoreCompatChangeRule.EnableCompatChanges;
 
@@ -54,6 +59,7 @@
 import org.mockito.MockitoAnnotations;
 
 import java.net.InetAddress;
+import java.util.Collections;
 import java.util.List;
 import java.time.Duration;
 
@@ -395,6 +401,7 @@
         NsdManager.RegistrationListener listener4 = mock(NsdManager.RegistrationListener.class);
         NsdManager.RegistrationListener listener5 = mock(NsdManager.RegistrationListener.class);
         NsdManager.RegistrationListener listener6 = mock(NsdManager.RegistrationListener.class);
+        NsdManager.RegistrationListener listener7 = mock(NsdManager.RegistrationListener.class);
 
         NsdServiceInfo invalidService = new NsdServiceInfo(null, null);
         NsdServiceInfo validService = new NsdServiceInfo("a_name", "_a_type._tcp");
@@ -439,6 +446,19 @@
         validServiceWithCustomHostNoAddresses.setPort(2222);
         validServiceWithCustomHostNoAddresses.setHostname("a_host");
 
+        NsdServiceInfo validServiceWithPublicKey = new NsdServiceInfo("a_name", "_a_type._tcp");
+        validServiceWithPublicKey.setPublicKey(
+                hexStringToByteArray(
+                        "0201030dc141d0637960b98cbc12cfca"
+                                + "221d2879dac26ee5b460e9007c992e19"
+                                + "02d897c391b03764d448f7d0c772fdb0"
+                                + "3b1d9d6d52ff8886769e8e2362513565"
+                                + "270962d3"));
+
+        NsdServiceInfo invalidServiceWithTooShortPublicKey =
+                new NsdServiceInfo("a_name", "_a_type._tcp");
+        invalidServiceWithTooShortPublicKey.setPublicKey(hexStringToByteArray("0201"));
+
         // Service registration
         //  - invalid arguments
         mustFail(() -> { manager.unregisterService(null); });
@@ -449,6 +469,8 @@
         mustFail(() -> { manager.registerService(validService, PROTOCOL, null); });
         mustFail(() -> {
             manager.registerService(invalidMissingHostnameWithAddresses, PROTOCOL, listener1); });
+        mustFail(() -> {
+            manager.registerService(invalidServiceWithTooShortPublicKey, PROTOCOL, listener1); });
         manager.registerService(validService, PROTOCOL, listener1);
         //  - update without subtype is not allowed
         mustFail(() -> { manager.registerService(validServiceDuplicate, PROTOCOL, listener1); });
@@ -479,6 +501,9 @@
         //  - registering a service with a custom host with no addresses is valid
         manager.registerService(validServiceWithCustomHostNoAddresses, PROTOCOL, listener6);
         manager.unregisterService(listener6);
+        //  - registering a service with a public key is valid
+        manager.registerService(validServiceWithPublicKey, PROTOCOL, listener7);
+        manager.unregisterService(listener7);
 
         // Discover service
         //  - invalid arguments
@@ -506,6 +531,229 @@
         mustFail(() -> { manager.resolveService(validService, listener3); });
     }
 
+    private static final class NsdServiceInfoBuilder {
+        private static final String SERVICE_NAME = "TestService";
+        private static final String SERVICE_TYPE = "_testservice._tcp";
+        private static final int SERVICE_PORT = 12345;
+        private static final String HOSTNAME = "TestHost";
+        private static final List<InetAddress> HOST_ADDRESSES =
+                List.of(parseNumericAddress("192.168.2.23"), parseNumericAddress("2001:db8::3"));
+        private static final byte[] PUBLIC_KEY =
+                hexStringToByteArray(
+                        "0201030dc141d0637960b98cbc12cfca"
+                                + "221d2879dac26ee5b460e9007c992e19"
+                                + "02d897c391b03764d448f7d0c772fdb0"
+                                + "3b1d9d6d52ff8886769e8e2362513565"
+                                + "270962d3");
+
+        private final NsdServiceInfo mNsdServiceInfo = new NsdServiceInfo();
+
+        NsdServiceInfo build() {
+            return mNsdServiceInfo;
+        }
+
+        NsdServiceInfoBuilder setNoService() {
+            mNsdServiceInfo.setServiceName(null);
+            mNsdServiceInfo.setServiceType(null);
+            mNsdServiceInfo.setPort(0);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setService() {
+            mNsdServiceInfo.setServiceName(SERVICE_NAME);
+            mNsdServiceInfo.setServiceType(SERVICE_TYPE);
+            mNsdServiceInfo.setPort(SERVICE_PORT);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setZeroPortService() {
+            mNsdServiceInfo.setServiceName(SERVICE_NAME);
+            mNsdServiceInfo.setServiceType(SERVICE_TYPE);
+            mNsdServiceInfo.setPort(0);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setInvalidService() {
+            mNsdServiceInfo.setServiceName(SERVICE_NAME);
+            mNsdServiceInfo.setServiceType(null);
+            mNsdServiceInfo.setPort(SERVICE_PORT);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setDefaultHost() {
+            mNsdServiceInfo.setHostname(null);
+            mNsdServiceInfo.setHostAddresses(Collections.emptyList());
+            return this;
+        }
+
+        NsdServiceInfoBuilder setCustomHost() {
+            mNsdServiceInfo.setHostname(HOSTNAME);
+            mNsdServiceInfo.setHostAddresses(HOST_ADDRESSES);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setCustomHostNoAddress() {
+            mNsdServiceInfo.setHostname(HOSTNAME);
+            mNsdServiceInfo.setHostAddresses(Collections.emptyList());
+            return this;
+        }
+
+        NsdServiceInfoBuilder setHostAddressesNoHostname() {
+            mNsdServiceInfo.setHostname(null);
+            mNsdServiceInfo.setHostAddresses(HOST_ADDRESSES);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setNoPublicKey() {
+            mNsdServiceInfo.setPublicKey(null);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setPublicKey() {
+            mNsdServiceInfo.setPublicKey(PUBLIC_KEY);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setInvalidPublicKey() {
+            mNsdServiceInfo.setPublicKey(new byte[3]);
+            return this;
+        }
+    }
+
+    @Test
+    public void testCheckServiceInfoForRegistration() {
+        // The service is invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setInvalidService()
+                        .setCustomHost()
+                        .setPublicKey().build()));
+        // Keep compatible with the legacy behavior: It's allowed to set host
+        // addresses for a service registration although the host addresses
+        // won't be registered. To register the addresses for a host, the
+        // hostname must be specified.
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setHostAddressesNoHostname()
+                        .setPublicKey().build());
+        // The public key is invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setCustomHost()
+                        .setInvalidPublicKey().build()));
+        // Invalid combinations
+        // 1. (service, custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setCustomHost()
+                        .setPublicKey().build());
+        // 2. (service, custom host, no key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setCustomHost()
+                        .setNoPublicKey().build());
+        // 3. (service, no-address custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setCustomHostNoAddress()
+                        .setPublicKey().build());
+        // 4. (service, no-address custom host, no key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setCustomHostNoAddress()
+                        .setNoPublicKey().build());
+        // 5. (service, default host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setDefaultHost()
+                        .setPublicKey().build());
+        // 6. (service, default host, no key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setDefaultHost()
+                        .setNoPublicKey().build());
+        // 7. (0-port service, custom host, valid key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setCustomHost()
+                        .setPublicKey().build());
+        // 8. (0-port service, custom host, no key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setCustomHost()
+                        .setNoPublicKey().build()));
+        // 9. (0-port service, no-address custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setCustomHostNoAddress()
+                        .setPublicKey().build());
+        // 10. (0-port service, no-address custom host, no key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setCustomHostNoAddress()
+                        .setNoPublicKey().build()));
+        // 11. (0-port service, default host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setDefaultHost()
+                        .setPublicKey().build());
+        // 12. (0-port service, default host, no key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setDefaultHost()
+                        .setNoPublicKey().build()));
+        // 13. (no service, custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setCustomHost()
+                        .setPublicKey().build());
+        // 14. (no service, custom host, no key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setCustomHost()
+                        .setNoPublicKey().build());
+        // 15. (no service, no-address custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setCustomHostNoAddress()
+                        .setPublicKey().build());
+        // 16. (no service, no-address custom host, no key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setCustomHostNoAddress()
+                        .setNoPublicKey().build()));
+        // 17. (no service, default host, key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setDefaultHost()
+                        .setPublicKey().build()));
+        // 18. (no service, default host, no key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setDefaultHost()
+                        .setNoPublicKey().build()));
+    }
+
     public void mustFail(Runnable fn) {
         try {
             fn.run();