Merge "Add flag definition for metered network firewall chains" into main
diff --git a/bpf_progs/netd.c b/bpf_progs/netd.c
index 2aff89c..c3acaad 100644
--- a/bpf_progs/netd.c
+++ b/bpf_progs/netd.c
@@ -407,6 +407,9 @@
 
     BpfConfig enabledRules = getConfig(UID_RULES_CONFIGURATION_KEY);
 
+    // BACKGROUND match does not apply to loopback traffic
+    if (skb->ifindex == 1) enabledRules &= ~BACKGROUND_MATCH;
+
     UidOwnerValue* uidEntry = bpf_uid_owner_map_lookup_elem(&uid);
     uint32_t uidRules = uidEntry ? uidEntry->rule : 0;
     uint32_t allowed_iif = uidEntry ? uidEntry->iif : 0;
diff --git a/bpf_progs/offload.c b/bpf_progs/offload.c
index 8e72747..4f152bf 100644
--- a/bpf_progs/offload.c
+++ b/bpf_progs/offload.c
@@ -876,5 +876,5 @@
 }
 
 LICENSE("Apache 2.0");
-//CRITICAL("Connectivity (Tethering)");
+CRITICAL("Connectivity (Tethering)");
 DISABLE_BTF_ON_USER_BUILDS();
diff --git a/framework-t/src/android/net/nsd/NsdManager.java b/framework-t/src/android/net/nsd/NsdManager.java
index 1001423..48d40e6 100644
--- a/framework-t/src/android/net/nsd/NsdManager.java
+++ b/framework-t/src/android/net/nsd/NsdManager.java
@@ -389,6 +389,7 @@
     }
 
     private static final int FIRST_LISTENER_KEY = 1;
+    private static final int DNSSEC_PROTOCOL = 3;
 
     private final INsdServiceConnector mService;
     private final Context mContext;
@@ -1754,45 +1755,132 @@
         }
     }
 
+    private enum ServiceValidationType { // Result of validateService()
+        NO_SERVICE, // No service name, no service type and a zero port
+        HAS_SERVICE, // Service name and type with a positive port
+        HAS_SERVICE_ZERO_PORT, // Service name and type with a zero port
+    }
+
+    private enum HostValidationType { // Result of validateHost()
+        DEFAULT_HOST, // No hostname is specified so the default host will be used
+        CUSTOM_HOST, // A custom hostname with addresses is specified
+        CUSTOM_HOST_NO_ADDRESS, // A custom hostname without any address is specified
+    }
+
+    private enum PublicKeyValidationType { // Result of validatePublicKey()
+        NO_KEY, // No public key is specified
+        HAS_KEY, // A key of at least 4 bytes whose protocol byte is DNSSEC (3)
+    }
+
+    /**
+     * Classify the service part of {@code serviceInfo} as one of {@link ServiceValidationType}.
+     * @throws IllegalArgumentException for a partially specified service or a negative port
+     */
+    private static ServiceValidationType validateService(NsdServiceInfo serviceInfo) {
+        final int port = serviceInfo.getPort();
+        final boolean nameSet = !TextUtils.isEmpty(serviceInfo.getServiceName());
+        final boolean typeSet = !TextUtils.isEmpty(serviceInfo.getServiceType());
+        if (!(nameSet || typeSet || port != 0)) {
+            return ServiceValidationType.NO_SERVICE;
+        }
+        if (!nameSet || !typeSet) {
+            throw new IllegalArgumentException(
+                    "The service name or the service type is missing");
+        }
+        if (port < 0) {
+            throw new IllegalArgumentException("Invalid port");
+        }
+        return port == 0 ? ServiceValidationType.HAS_SERVICE_ZERO_PORT
+                : ServiceValidationType.HAS_SERVICE;
+    }
+
+    /**
+     * Classify the host part of {@code serviceInfo} as one of {@link HostValidationType}.
+     *
+     * <p>A custom host requires a hostname; host addresses given without a hostname do not make
+     * the registration a custom host registration.
+     */
+    private static HostValidationType validateHost(NsdServiceInfo serviceInfo) {
+        if (TextUtils.isEmpty(serviceInfo.getHostname())) {
+            // Legacy compatibility: a service registration may carry host addresses without a
+            // hostname, but those addresses are not registered. To register addresses for a
+            // host, the hostname must be set.
+            return HostValidationType.DEFAULT_HOST;
+        }
+        final boolean hasAddresses = !CollectionUtils.isEmpty(serviceInfo.getHostAddresses());
+        if (hasAddresses) {
+            return HostValidationType.CUSTOM_HOST;
+        }
+        return HostValidationType.CUSTOM_HOST_NO_ADDRESS;
+    }
+
+    /**
+     * Check if the public key is valid for registration and classify it as one of {@link
+     * PublicKeyValidationType}.
+     *
+     * <p>For simplicity, it only checks if the protocol is DNSSEC and the RDATA is at least 4
+     * bytes long. See RFC 3445 Section 3.
+     */
+    private static PublicKeyValidationType validatePublicKey(NsdServiceInfo serviceInfo) {
+        final byte[] publicKey = serviceInfo.getPublicKey();
+        if (publicKey == null) {
+            return PublicKeyValidationType.NO_KEY;
+        }
+        if (publicKey.length < 4) {
+            throw new IllegalArgumentException("The public key should be at least 4 bytes long");
+        }
+        // RDATA: flags (2 bytes), protocol (1), algorithm (1). Read protocol unsigned (0-255).
+        final int protocol = publicKey[2] & 0xFF;
+        if (protocol == DNSSEC_PROTOCOL) {
+            return PublicKeyValidationType.HAS_KEY;
+        }
+        throw new IllegalArgumentException(
+                "The public key's protocol (" + protocol
+                        + ") is invalid. It should be DNSSEC_PROTOCOL (3)");
+    }
+
     /**
      * Check if the {@link NsdServiceInfo} is valid for registration.
      *
-     * The following can be registered:
-     * - A service with an optional host.
-     * - A hostname with addresses.
+     * <p>Firstly, check if service, host and public key are all valid respectively. Then check if
+     * the combination of service, host and public key is valid.
      *
-     * Note that:
-     * - When registering a service, the service name, service type and port must be specified. If
-     *   hostname is specified, the host addresses can optionally be specified.
-     * - When registering a host without a service, the addresses must be specified.
+     * <p>If the {@code serviceInfo} is invalid, throw an {@link IllegalArgumentException}
+     * describing the reason.
+     *
+     * <p>These are the invalid combinations of service, host and public key:
+     *
+     * <ul>
+     *   <li>Neither service nor host is specified.
+     *   <li>No public key is specified and the service has a zero port.
+     *   <li>No public key is specified and only a hostname, without addresses, is specified.
+     * </ul>
+     *
+     * <p>Keys are used to reserve hostnames or service names while the service/host is temporarily
+     * inactive, so registrations with a key and just a hostname or a service name are acceptable.
      *
      * @hide
      */
     public static void checkServiceInfoForRegistration(NsdServiceInfo serviceInfo) {
         Objects.requireNonNull(serviceInfo, "NsdServiceInfo cannot be null");
-        boolean hasServiceName = !TextUtils.isEmpty(serviceInfo.getServiceName());
-        boolean hasServiceType = !TextUtils.isEmpty(serviceInfo.getServiceType());
-        boolean hasHostname = !TextUtils.isEmpty(serviceInfo.getHostname());
-        boolean hasHostAddresses = !CollectionUtils.isEmpty(serviceInfo.getHostAddresses());
 
-        if (serviceInfo.getPort() < 0) {
-            throw new IllegalArgumentException("Invalid port");
+        final ServiceValidationType serviceValidation = validateService(serviceInfo);
+        final HostValidationType hostValidation = validateHost(serviceInfo);
+        final PublicKeyValidationType publicKeyValidation = validatePublicKey(serviceInfo);
+
+        if (serviceValidation == ServiceValidationType.NO_SERVICE
+                && hostValidation == HostValidationType.DEFAULT_HOST) {
+            throw new IllegalArgumentException("Nothing to register");
         }
-
-        if (hasServiceType || hasServiceName || (serviceInfo.getPort() > 0)) {
-            if (!(hasServiceType && hasServiceName && (serviceInfo.getPort() > 0))) {
-                throw new IllegalArgumentException(
-                        "The service type, service name or port is missing");
+        if (publicKeyValidation == PublicKeyValidationType.NO_KEY) {
+            if (serviceValidation == ServiceValidationType.HAS_SERVICE_ZERO_PORT) {
+                throw new IllegalArgumentException("The port is missing");
             }
-        }
-
-        if (!hasServiceType && !hasHostname) {
-            throw new IllegalArgumentException("No service or host specified in NsdServiceInfo");
-        }
-
-        if (!hasServiceType && hasHostname && !hasHostAddresses) {
-            // TODO: b/317946010 - This may be allowed when it supports registering KEY RR.
-            throw new IllegalArgumentException("No host addresses specified in NsdServiceInfo");
+            if (serviceValidation == ServiceValidationType.NO_SERVICE
+                    && hostValidation == HostValidationType.CUSTOM_HOST_NO_ADDRESS) {
+                throw new IllegalArgumentException(
+                        "The host addresses must be specified unless there is a service");
+            }
         }
     }
 }
diff --git a/framework-t/src/android/net/nsd/NsdServiceInfo.java b/framework-t/src/android/net/nsd/NsdServiceInfo.java
index 9491a9c..2f675a9 100644
--- a/framework-t/src/android/net/nsd/NsdServiceInfo.java
+++ b/framework-t/src/android/net/nsd/NsdServiceInfo.java
@@ -37,6 +37,7 @@
 import java.nio.charset.StandardCharsets;
 import java.time.Instant;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -69,6 +70,9 @@
     private int mPort;
 
     @Nullable
+    private byte[] mPublicKey;
+
+    @Nullable
     private Network mNetwork;
 
     private int mInterfaceIndex;
@@ -220,6 +224,40 @@
     }
 
     /**
+     * Sets the public key RDATA to advertise in a KEY RR (RFC 2535).
+     *
+     * <p>This is the public half of the key pair used to sign DNS messages (e.g. SRP). Clients
+     * usually don't need it; the KEY RR is typically published to claim the DNS name so that
+     * another mDNS advertiser cannot take ownership of it while the original host device is
+     * temporarily powered down.
+     *
+     * <p>With a non-null public key, exactly one KEY RR is advertised for each of the service
+     * name and the host name, for whichever of them is not null.
+     *
+     * <p>The given array is copied, so later modifications by the caller have no effect on this
+     * object. Passing {@code null} clears any previously set key.
+     *
+     * @param publicKey the KEY RDATA bytes, or {@code null} to advertise no KEY RR
+     * @hide // For Thread only
+     */
+    public void setPublicKey(@Nullable byte[] publicKey) {
+        mPublicKey = (publicKey == null) ? null : Arrays.copyOf(publicKey, publicKey.length);
+    }
+
+    /**
+     * Get the public key RDATA in the KEY RR (RFC 2535) or {@code null} if no KEY RR exists.
+     *
+     * <p>A fresh copy is returned on every call, so callers may modify the returned array without
+     * affecting this object.
+     *
+     * @hide // For Thread only
+     */
+    @Nullable
+    public byte[] getPublicKey() {
+        return (mPublicKey == null) ? null : Arrays.copyOf(mPublicKey, mPublicKey.length);
+    }
+
+    /**
      * Unpack txt information from a base-64 encoded byte array.
      *
      * @param txtRecordsRawBytes The raw base64 encoded byte array.
@@ -622,6 +660,7 @@
         }
         dest.writeString(mHostname);
         dest.writeLong(mExpirationTime != null ? mExpirationTime.getEpochSecond() : -1);
+        dest.writeByteArray(mPublicKey);
     }
 
     /** Implement the Parcelable interface */
@@ -654,6 +693,7 @@
                 info.mHostname = in.readString();
                 final long seconds = in.readLong();
                 info.setExpirationTime(seconds < 0 ? null : Instant.ofEpochSecond(seconds));
+                info.mPublicKey = in.createByteArray();
                 return info;
             }
 
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 7c1ca30..0a8adf0 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -201,6 +201,7 @@
     private static final int NO_SENT_QUERY_COUNT = 0;
     private static final int DISCOVERY_QUERY_SENT_CALLBACK = 1000;
     private static final int MAX_SUBTYPE_COUNT = 100;
+    private static final int DNSSEC_PROTOCOL = 3;
     private static final SharedLog LOGGER = new SharedLog("serviceDiscovery");
 
     private final Context mContext;
@@ -1009,6 +1010,17 @@
                                 break;
                             }
 
+                            if (!checkPublicKey(serviceInfo.getPublicKey())) {
+                                Log.e(TAG,
+                                        "Invalid public key: "
+                                                + Arrays.toString(serviceInfo.getPublicKey()));
+                                clientInfo.onRegisterServiceFailedImmediately(
+                                        clientRequestId,
+                                        NsdManager.FAILURE_BAD_PARAMETERS,
+                                        false /* isLegacy */);
+                                break;
+                            }
+
                             Set<String> subtypes = new ArraySet<>(serviceInfo.getSubtypes());
                             if (typeSubtype != null && typeSubtype.second != null) {
                                 for (String subType : typeSubtype.second) {
@@ -1842,6 +1854,25 @@
         return Pattern.compile(HOSTNAME_REGEX).matcher(hostname).matches();
     }
 
+    /**
+     * Returns whether {@code publicKey} is acceptable as KEY RDATA: {@code null} (no key) is
+     * acceptable; otherwise it must be at least 4 bytes long and carry the DNSSEC protocol.
+     *
+     * <p>Only these minimal checks are done (see RFC 3445 Section 3). RDATA layout: flags
+     * (2 bytes), protocol (1 byte), algorithm (1 byte), then the public key itself.
+     */
+    private static boolean checkPublicKey(@Nullable byte[] publicKey) {
+        if (publicKey == null) {
+            // No KEY RR requested.
+            return true;
+        }
+        // The protocol octet sits right after the 2-byte flags field.
+        final int protocolOffset = 2;
+        final int minimumLength = 4;
+        return publicKey.length >= minimumLength
+                && publicKey[protocolOffset] == DNSSEC_PROTOCOL;
+    }
+
     /** Returns {@code true} if {@code subtype} is a valid DNS-SD subtype label. */
     private static boolean checkSubtypeLabel(String subtype) {
         return Pattern.compile("^" + SUBTYPE_LABEL_REGEX + "$").matcher(subtype).matches();
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
index 98c2d86..42efcac 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
@@ -448,10 +448,11 @@
         /**
          * Get the ID of a conflicting registration due to host, or -1 if none.
          *
-         * <p>It's valid that multiple registrations from the same user are using the same hostname.
-         *
          * <p>If there's already another registration with the same hostname requested by another
-         * user, this is considered a conflict.
+         * user, this is a conflict.
+         *
+         * <p>If two registrations with the same hostname both contain address records, this is
+         * a conflict, even when both are requested by the same user.
          */
         int getConflictingRegistrationDueToHost(@NonNull NsdServiceInfo info, int clientUid) {
             if (TextUtils.isEmpty(info.getHostname())) {
@@ -460,10 +461,17 @@
             for (int i = 0; i < mPendingRegistrations.size(); i++) {
                 final Registration otherRegistration = mPendingRegistrations.valueAt(i);
                 final NsdServiceInfo otherInfo = otherRegistration.getServiceInfo();
+                final int otherServiceId = mPendingRegistrations.keyAt(i);
                 if (clientUid != otherRegistration.mClientUid
                         && MdnsUtils.equalsIgnoreDnsCase(
                                 info.getHostname(), otherInfo.getHostname())) {
-                    return mPendingRegistrations.keyAt(i);
+                    return otherServiceId;
+                }
+                if (!info.getHostAddresses().isEmpty()
+                        && !otherInfo.getHostAddresses().isEmpty()
+                        && MdnsUtils.equalsIgnoreDnsCase(
+                                info.getHostname(), otherInfo.getHostname())) {
+                    return otherServiceId;
                 }
             }
             return -1;
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsKeyRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsKeyRecord.java
new file mode 100644
index 0000000..ba8a56e
--- /dev/null
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsKeyRecord.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import static com.android.net.module.util.HexDump.toHexString;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+
+import androidx.annotation.VisibleForTesting;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/** An mDNS "KEY" record, which contains a public key for a name. See RFC 2535. */
+@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
+public class MdnsKeyRecord extends MdnsRecord {
+    // Raw KEY RDATA or null. No initializer on purpose: readData() may run during super().
+    @Nullable private byte[] rData;
+
+    /** Parses a non-question KEY record named {@code name} from {@code reader}. */
+    public MdnsKeyRecord(@NonNull String[] name, @NonNull MdnsPacketReader reader)
+            throws IOException {
+        this(name, reader, false);
+    }
+
+    public MdnsKeyRecord(@NonNull String[] name, @NonNull MdnsPacketReader reader,
+            boolean isQuestion) throws IOException {
+        super(name, TYPE_KEY, reader, isQuestion);
+    }
+
+    /** Creates a KEY record without RDATA; {@code isUnicast} sets the QCLASS_UNICAST bit. */
+    public MdnsKeyRecord(@NonNull String[] name, boolean isUnicast) {
+        super(name, TYPE_KEY,
+                MdnsConstants.QCLASS_INTERNET | (isUnicast ? MdnsConstants.QCLASS_UNICAST : 0),
+                0L /* receiptTimeMillis */, false /* cacheFlush */, 0L /* ttlMillis */);
+    }
+
+    public MdnsKeyRecord(@NonNull String[] name, long receiptTimeMillis, boolean cacheFlush,
+            long ttlMillis, @Nullable byte[] rData) {
+        super(name, TYPE_KEY, MdnsConstants.QCLASS_INTERNET, receiptTimeMillis, cacheFlush,
+                ttlMillis);
+        this.rData = rData == null ? null : Arrays.copyOf(rData, rData.length);
+    }
+
+    /** Returns a copy of the KEY RDATA, or {@code null} if there is none. */
+    @Nullable public byte[] getRData() {
+        return rData == null ? null : Arrays.copyOf(rData, rData.length);
+    }
+
+    @Override
+    protected void readData(MdnsPacketReader reader) throws IOException {
+        rData = new byte[reader.getRemaining()];
+        reader.readBytes(rData);
+    }
+
+    @Override
+    protected void writeData(MdnsPacketWriter writer) throws IOException {
+        if (rData != null) {
+            writer.writeBytes(rData);
+        }
+    }
+
+    @Override
+    public String toString() {
+        // rData can be null (e.g. from the isUnicast constructor); never pass null to toHexString.
+        return "KEY: " + (rData == null ? "null" : toHexString(rData));
+    }
+
+    @Override
+    public int hashCode() {
+        return (super.hashCode() * 31) + Arrays.hashCode(rData);
+    }
+
+    @Override
+    public boolean equals(@Nullable Object other) {
+        if (this == other) {
+            return true;
+        }
+        if (!(other instanceof MdnsKeyRecord)) {
+            return false;
+        }
+
+        return super.equals(other) && Arrays.equals(rData, ((MdnsKeyRecord) other).rData);
+    }
+}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java b/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
index 83ecabc..aef8211 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
@@ -196,6 +196,15 @@
                 }
             }
 
+            case MdnsRecord.TYPE_KEY: {
+                try {
+                    return new MdnsKeyRecord(name, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_KEY_RDATA,
+                            "Failed to read KEY record from mDNS response.", e);
+                }
+            }
+
             case MdnsRecord.TYPE_NSEC: {
                 try {
                     return new MdnsNsecRecord(name, reader, isQuestion);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
index 1f9f42b..b865319 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
@@ -41,6 +41,7 @@
     public static final int TYPE_PTR = 0x000C;
     public static final int TYPE_SRV = 0x0021;
     public static final int TYPE_TXT = 0x0010;
+    public static final int TYPE_KEY = 0x0019;
     public static final int TYPE_NSEC = 0x002f;
     public static final int TYPE_ANY = 0x00ff;
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index eb85110..39e8bcc 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -184,6 +184,10 @@
         public final RecordInfo<MdnsServiceRecord> srvRecord;
         @Nullable
         public final RecordInfo<MdnsTextRecord> txtRecord;
+        @Nullable
+        public final RecordInfo<MdnsKeyRecord> serviceKeyRecord;
+        @Nullable
+        public final RecordInfo<MdnsKeyRecord> hostKeyRecord;
         @NonNull
         public final List<RecordInfo<MdnsInetAddressRecord>> addressRecords;
         @NonNull
@@ -245,7 +249,6 @@
                 nameRecordsTtlMillis = DEFAULT_NAME_RECORDS_TTL_MILLIS;
             }
 
-            final boolean hasService = !TextUtils.isEmpty(serviceInfo.getServiceType());
             final boolean hasCustomHost = !TextUtils.isEmpty(serviceInfo.getHostname());
             final String[] hostname =
                     hasCustomHost
@@ -253,9 +256,11 @@
                             : deviceHostname;
             final ArrayList<RecordInfo<?>> allRecords = new ArrayList<>(5);
 
-            if (hasService) {
-                final String[] serviceType = splitServiceType(serviceInfo);
-                final String[] serviceName = splitFullyQualifiedName(serviceInfo, serviceType);
+            final boolean hasService = !TextUtils.isEmpty(serviceInfo.getServiceType());
+            final String[] serviceType = hasService ? splitServiceType(serviceInfo) : null;
+            final String[] serviceName =
+                    hasService ? splitFullyQualifiedName(serviceInfo, serviceType) : null;
+            if (hasService && hasSrvRecord(serviceInfo)) {
                 // Service PTR records
                 ptrRecords = new ArrayList<>(serviceInfo.getSubtypes().size() + 1);
                 ptrRecords.add(new RecordInfo<>(
@@ -336,6 +341,36 @@
                 addressRecords = Collections.emptyList();
             }
 
+            final boolean hasKey = hasKeyRecord(serviceInfo);
+            if (hasKey && hasService) {
+                this.serviceKeyRecord = new RecordInfo<>(
+                        serviceInfo,
+                        new MdnsKeyRecord(
+                                serviceName,
+                                0L /*receiptTimeMillis */,
+                                true /* cacheFlush */,
+                                nameRecordsTtlMillis,
+                                serviceInfo.getPublicKey()),
+                        false /* sharedName */);
+                allRecords.add(this.serviceKeyRecord);
+            } else {
+                this.serviceKeyRecord = null;
+            }
+            if (hasKey && hasCustomHost) {
+                this.hostKeyRecord = new RecordInfo<>(
+                        serviceInfo,
+                        new MdnsKeyRecord(
+                                hostname,
+                                0L /*receiptTimeMillis */,
+                                true /* cacheFlush */,
+                                nameRecordsTtlMillis,
+                                serviceInfo.getPublicKey()),
+                        false /* sharedName */);
+                allRecords.add(this.hostKeyRecord);
+            } else {
+                this.hostKeyRecord = null;
+            }
+
             this.allRecords = Collections.unmodifiableList(allRecords);
             this.repliedServiceCount = repliedServiceCount;
             this.sentPacketCount = sentPacketCount;
@@ -486,6 +521,22 @@
                             ? inetAddressRecord.getInet6Address()
                             : inetAddressRecord.getInet4Address()));
         }
+
+        List<MdnsKeyRecord> keyRecords = new ArrayList<>();
+        if (registration.serviceKeyRecord != null) {
+            keyRecords.add(registration.serviceKeyRecord.record);
+        }
+        if (registration.hostKeyRecord != null) {
+            keyRecords.add(registration.hostKeyRecord.record);
+        }
+        for (MdnsKeyRecord keyRecord : keyRecords) {
+            probingRecords.add(new MdnsKeyRecord(
+                            keyRecord.getName(),
+                            0L /* receiptTimeMillis */,
+                            false /* cacheFlush */,
+                            keyRecord.getTtl(),
+                            keyRecord.getRData()));
+        }
         return new MdnsProber.ProbingInfo(serviceId, probingRecords);
     }
 
@@ -1101,18 +1152,15 @@
                 Collections.emptyList() /* additionalRecords */);
     }
 
-    /** Check if the record is in any service registration */
-    private boolean hasInetAddressRecord(@NonNull MdnsInetAddressRecord record) {
-        for (int i = 0; i < mServices.size(); i++) {
-            final ServiceRegistration registration = mServices.valueAt(i);
-            if (registration.exiting) continue;
-
-            for (RecordInfo<MdnsInetAddressRecord> localRecord : registration.addressRecords) {
-                if (Objects.equals(localRecord.record, record)) {
-                    return true;
-                }
+    /** Returns whether {@code registration} holds an address record equal to {@code record}. */
+    private static boolean hasInetAddressRecord(
+            @NonNull ServiceRegistration registration, @NonNull MdnsInetAddressRecord record) {
+        for (RecordInfo<MdnsInetAddressRecord> localRecord : registration.addressRecords) {
+            if (Objects.equals(localRecord.record, record)) {
+                return true;
             }
         }
+
         return false;
     }
 
@@ -1155,36 +1203,33 @@
         return conflicting;
     }
 
-
     private static boolean conflictForService(
             @NonNull MdnsRecord record, @NonNull ServiceRegistration registration) {
-        if (registration.srvRecord == null) {
+        String[] fullServiceName; // SRV record name if any, else the service KEY record name
+        if (registration.srvRecord != null) {
+            fullServiceName = registration.srvRecord.record.getName();
+        } else if (registration.serviceKeyRecord != null) {
+            fullServiceName = registration.serviceKeyRecord.record.getName();
+        } else {
             return false;
         }
 
-        final RecordInfo<MdnsServiceRecord> srvRecord = registration.srvRecord;
-        if (!MdnsUtils.equalsDnsLabelIgnoreDnsCase(record.getName(), srvRecord.record.getName())) {
+        if (!MdnsUtils.equalsDnsLabelIgnoreDnsCase(record.getName(), fullServiceName)) {
             return false;
         }
 
         // As per RFC6762 9., it's fine if the "conflict" is an identical record with same
         // data.
-        if (record instanceof MdnsServiceRecord) {
-            final MdnsServiceRecord local = srvRecord.record;
-            final MdnsServiceRecord other = (MdnsServiceRecord) record;
-            // Note "equals" does not consider TTL or receipt time, as intended here
-            if (Objects.equals(local, other)) {
-                return false;
-            }
+        if (record instanceof MdnsServiceRecord && equals(record, registration.srvRecord)) {
+            return false;
+        }
+        if (record instanceof MdnsTextRecord && equals(record, registration.txtRecord)) {
+            return false;
+        }
+        if (record instanceof MdnsKeyRecord && equals(record, registration.serviceKeyRecord)) {
+            return false;
         }
 
-        if (record instanceof MdnsTextRecord) {
-            final MdnsTextRecord local = registration.txtRecord.record;
-            final MdnsTextRecord other = (MdnsTextRecord) record;
-            if (Objects.equals(local, other)) {
-                return false;
-            }
-        }
         return true;
     }
 
@@ -1196,6 +1241,11 @@
             return false;
         }
 
+        // It cannot be a hostname conflict because no record is registered with the hostname.
+        if (registration.addressRecords.isEmpty() && registration.hostKeyRecord == null) {
+            return false;
+        }
+
         // The record's name cannot be registered by NsdManager so it's not a conflict.
         if (record.getName().length != 2 || !record.getName()[1].equals(LOCAL_TLD)) {
             return false;
@@ -1207,13 +1257,26 @@
             return false;
         }
 
-        // If this registration has any address record and there's no identical record in the
-        // repository, it's a conflict. There will be no conflict if no registration has addresses
-        // for that hostname.
-        if (record instanceof MdnsInetAddressRecord) {
-            if (!registration.addressRecords.isEmpty()) {
-                return !hasInetAddressRecord((MdnsInetAddressRecord) record);
-            }
+        // As per RFC6762 9., it's fine if the "conflict" is an identical record with same
+        // data.
+        if (record instanceof MdnsInetAddressRecord
+                && hasInetAddressRecord(registration, (MdnsInetAddressRecord) record)) {
+            return false;
+        }
+        if (record instanceof MdnsKeyRecord && equals(record, registration.hostKeyRecord)) {
+            return false;
+        }
+
+        // Per RFC 6762 8.1, when a record is being probed, any answer containing a record with that
+        // name, of any type, MUST be considered a conflicting response.
+        if (registration.isProbing) {
+            return true;
+        }
+        if (record instanceof MdnsInetAddressRecord && !registration.addressRecords.isEmpty()) {
+            return true;
+        }
+        if (record instanceof MdnsKeyRecord && registration.hostKeyRecord != null) {
+            return true;
         }
 
         return false;
@@ -1402,4 +1465,21 @@
 
         return type;
     }
+
+    /** Returns whether registering {@code info} yields an SRV record, which needs a port > 0. */
+    private static boolean hasSrvRecord(@NonNull NsdServiceInfo info) {
+        return 0 < info.getPort();
+    }
+
+    /** Returns whether registering {@code info} yields KEY record(s), i.e. it has a public key. */
+    private static boolean hasKeyRecord(@NonNull NsdServiceInfo info) {
+        return Objects.nonNull(info.getPublicKey());
+    }
+
+    /**
+     * Returns whether {@code recordInfo} is non-null and wraps a record equal to {@code record}.
+     */
+    private static boolean equals(@NonNull MdnsRecord record, @Nullable RecordInfo<?> recordInfo) {
+        return recordInfo != null && Objects.equals(record, recordInfo.record);
+    }
 }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java b/service-t/src/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java
index 73a7e3a..f509da2 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java
@@ -37,4 +37,5 @@
     public static final int ERROR_END_OF_FILE = 12;
     public static final int ERROR_READING_NSEC_RDATA = 13;
     public static final int ERROR_READING_ANY_RDATA = 14;
+    public static final int ERROR_READING_KEY_RDATA = 15;
 }
\ No newline at end of file
diff --git a/tests/common/java/android/net/nsd/NsdServiceInfoTest.java b/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
index 8e89037..21e34ab 100644
--- a/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
+++ b/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
@@ -16,6 +16,7 @@
 
 package android.net.nsd;
 
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThrows;
@@ -51,6 +52,23 @@
 
     private static final InetAddress IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.1");
     private static final InetAddress IPV6_ADDRESS = InetAddresses.parseNumericAddress("2001:db8::");
+    private static final byte[] PUBLIC_KEY_RDATA = new byte[] {
+            (byte) 0x02, (byte) 0x01, // flags (2 bytes)
+            (byte) 0x03, // protocol (3 = DNSSEC)
+            (byte) 0x0d, // algorithm (13 = ECDSAP256SHA256)
+            // 64-byte ECDSA P-256 public key below
+            (byte) 0xC1, (byte) 0x41, (byte) 0xD0, (byte) 0x63, (byte) 0x79, (byte) 0x60,
+            (byte) 0xB9, (byte) 0x8C, (byte) 0xBC, (byte) 0x12, (byte) 0xCF, (byte) 0xCA,
+            (byte) 0x22, (byte) 0x1D, (byte) 0x28, (byte) 0x79, (byte) 0xDA, (byte) 0xC2,
+            (byte) 0x6E, (byte) 0xE5, (byte) 0xB4, (byte) 0x60, (byte) 0xE9, (byte) 0x00,
+            (byte) 0x7C, (byte) 0x99, (byte) 0x2E, (byte) 0x19, (byte) 0x02, (byte) 0xD8,
+            (byte) 0x97, (byte) 0xC3, (byte) 0x91, (byte) 0xB0, (byte) 0x37, (byte) 0x64,
+            (byte) 0xD4, (byte) 0x48, (byte) 0xF7, (byte) 0xD0, (byte) 0xC7, (byte) 0x72,
+            (byte) 0xFD, (byte) 0xB0, (byte) 0x3B, (byte) 0x1D, (byte) 0x9D, (byte) 0x6D,
+            (byte) 0x52, (byte) 0xFF, (byte) 0x88, (byte) 0x86, (byte) 0x76, (byte) 0x9E,
+            (byte) 0x8E, (byte) 0x23, (byte) 0x62, (byte) 0x51, (byte) 0x35, (byte) 0x65,
+            (byte) 0x27, (byte) 0x09, (byte) 0x62, (byte) 0xD3
+    };
 
     @Test
     public void testLimits() throws Exception {
@@ -120,6 +138,7 @@
         fullInfo.setPort(4242);
         fullInfo.setHostAddresses(List.of(IPV4_ADDRESS));
         fullInfo.setHostname("home");
+        fullInfo.setPublicKey(PUBLIC_KEY_RDATA);
         fullInfo.setNetwork(new Network(123));
         fullInfo.setInterfaceIndex(456);
         checkParcelable(fullInfo);
@@ -136,6 +155,7 @@
         attributedInfo.setPort(4242);
         attributedInfo.setHostAddresses(List.of(IPV6_ADDRESS, IPV4_ADDRESS));
         attributedInfo.setHostname("home");
+        attributedInfo.setPublicKey(PUBLIC_KEY_RDATA);
         attributedInfo.setAttribute("color", "pink");
         attributedInfo.setAttribute("sound", (new String("にゃあ")).getBytes("UTF-8"));
         attributedInfo.setAttribute("adorable", (String) null);
@@ -172,6 +192,7 @@
         assertEquals(original.getServiceType(), result.getServiceType());
         assertEquals(original.getHost(), result.getHost());
         assertEquals(original.getHostname(), result.getHostname());
+        assertArrayEquals(original.getPublicKey(), result.getPublicKey());
         assertTrue(original.getPort() == result.getPort());
         assertEquals(original.getNetwork(), result.getNetwork());
         assertEquals(original.getInterfaceIndex(), result.getInterfaceIndex());
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index 0f86d78..8e7b3d4 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -257,7 +257,6 @@
 
     @Before
     public void setUp() throws Exception {
-        assumeTrue(supportedHardware());
         mNetwork = null;
         mTestContext = getInstrumentation().getContext();
         mTargetContext = getInstrumentation().getTargetContext();
@@ -272,6 +271,7 @@
         mDevice.waitForIdle();
         mCtsNetUtils = new CtsNetUtils(mTestContext);
         mPackageManager = mTestContext.getPackageManager();
+        assumeTrue(supportedHardware());
     }
 
     @After
diff --git a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
index 6ce8b7c..1a535b4 100644
--- a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
+++ b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
@@ -31,7 +31,9 @@
 import android.net.apf.ApfConstant.IPV6_NEXT_HEADER_OFFSET
 import android.net.apf.ApfV4Generator
 import android.net.apf.BaseApfGenerator
+import android.net.apf.BaseApfGenerator.MemorySlot
 import android.net.apf.BaseApfGenerator.Register.R0
+import android.net.apf.BaseApfGenerator.Register.R1
 import android.os.Build
 import android.os.Handler
 import android.os.HandlerThread
@@ -64,12 +66,14 @@
 import com.android.testutils.TestableNetworkCallback
 import com.android.testutils.runAsShell
 import com.android.testutils.waitForIdle
+import com.google.common.truth.Expect
 import com.google.common.truth.Truth.assertThat
 import com.google.common.truth.Truth.assertWithMessage
 import com.google.common.truth.TruthJUnit.assume
 import java.io.FileDescriptor
 import java.lang.Thread
 import java.net.InetSocketAddress
+import java.nio.ByteBuffer
 import java.util.concurrent.CompletableFuture
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.TimeoutException
@@ -230,8 +234,8 @@
         }
     }
 
-    @get:Rule
-    val ignoreRule = DevSdkIgnoreRule()
+    @get:Rule val ignoreRule = DevSdkIgnoreRule()
+    @get:Rule val expect = Expect.create()
 
     private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! }
     private val pm by lazy { context.packageManager }
@@ -364,12 +368,28 @@
         }
     }
 
+    private fun ApfV4Generator.addPassIfNotIcmpv6EchoReply() {
+        // If not IPv6 -> PASS (otherwise fall through to the caller's next instructions)
+        addLoad16(R0, ETH_ETHERTYPE_OFFSET)
+        addJumpIfR0NotEquals(ETH_P_IPV6.toLong(), BaseApfGenerator.PASS_LABEL)
+
+        // If the IPv6 next header is not ICMPv6 -> PASS
+        addLoad8(R0, IPV6_NEXT_HEADER_OFFSET)
+        addJumpIfR0NotEquals(IPPROTO_ICMPV6.toLong(), BaseApfGenerator.PASS_LABEL)
+
+        // If not an Echo Reply (ICMPv6 type 129 == 0x81) -> PASS
+        addLoad8(R0, ICMP6_TYPE_OFFSET)
+        addJumpIfR0NotEquals(0x81, BaseApfGenerator.PASS_LABEL)
+    }
+
+    // APF integration is mostly broken before V
+    @IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
     @Test
     fun testDropPingReply() {
         assumeApfVersionSupportAtLeast(4)
 
         // clear any active APF filter
-        var gen = ApfV4Generator(caps.apfVersionSupported).addPass()
+        var gen = ApfV4Generator(4).addPass()
         installProgram(gen.generate())
         readProgram() // wait for install completion
 
@@ -379,19 +399,10 @@
         assertThat(packetReader.expectPingReply()).isEqualTo(data)
 
         // Generate an APF program that drops the next ping
-        gen = ApfV4Generator(caps.apfVersionSupported)
+        gen = ApfV4Generator(4)
 
-        // If not IPv6 -> PASS
-        gen.addLoad16(R0, ETH_ETHERTYPE_OFFSET)
-        gen.addJumpIfR0NotEquals(ETH_P_IPV6.toLong(), BaseApfGenerator.PASS_LABEL)
-
-        // If not ICMPv6 -> PASS
-        gen.addLoad8(R0, IPV6_NEXT_HEADER_OFFSET)
-        gen.addJumpIfR0NotEquals(IPPROTO_ICMPV6.toLong(), BaseApfGenerator.PASS_LABEL)
-
-        // If not echo reply -> PASS
-        gen.addLoad8(R0, ICMP6_TYPE_OFFSET)
-        gen.addJumpIfR0NotEquals(0x81, BaseApfGenerator.PASS_LABEL)
+        // If not ICMPv6 Echo Reply -> PASS
+        gen.addPassIfNotIcmpv6EchoReply()
 
         // if not data matches -> PASS
         gen.addLoadImmediate(R0, ICMP6_TYPE_OFFSET + PING_HEADER_LENGTH)
@@ -407,4 +418,52 @@
         packetReader.sendPing(data)
         packetReader.expectPingDropped()
     }
+
+    // APF integration is mostly broken before V
+    @IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+    @Test
+    fun testPrefilledMemorySlotsV4() {
+        // Test v4 memory slots on both v4 and v6 interpreters.
+        assumeApfVersionSupportAtLeast(4)
+        // Clear the entire memory before starting this test
+        installProgram(ByteArray(caps.maximumApfProgramSize))
+        val gen = ApfV4Generator(4)
+
+        // If not ICMPv6 Echo Reply -> PASS
+        gen.addPassIfNotIcmpv6EchoReply()
+
+        // Store all prefilled memory slots in counter region [500, 520)
+        val counterRegion = 500
+        gen.addLoadImmediate(R1, counterRegion)
+        gen.addLoadFromMemory(R0, MemorySlot.PROGRAM_SIZE)
+        gen.addStoreData(R0, 0)
+        gen.addLoadFromMemory(R0, MemorySlot.RAM_LEN)
+        gen.addStoreData(R0, 4)
+        gen.addLoadFromMemory(R0, MemorySlot.IPV4_HEADER_SIZE)
+        gen.addStoreData(R0, 8)
+        gen.addLoadFromMemory(R0, MemorySlot.PACKET_SIZE)
+        gen.addStoreData(R0, 12)
+        gen.addLoadFromMemory(R0, MemorySlot.FILTER_AGE_SECONDS)
+        gen.addStoreData(R0, 16)
+
+        val program = gen.generate()
+        assertThat(program.size).isLessThan(counterRegion)
+        installProgram(program)
+        readProgram() // wait for install completion
+
+        // Trigger the program by sending a ping and waiting on the reply.
+        val data = ByteArray(56).also { Random.nextBytes(it) }
+        packetReader.sendPing(data)
+        packetReader.expectPingReply()
+
+        val readResult = readProgram()
+        val buffer = ByteBuffer.wrap(readResult)
+        buffer.position(counterRegion)
+        expect.withMessage("PROGRAM_SIZE").that(buffer.getInt()).isEqualTo(program.size)
+        expect.withMessage("RAM_LEN").that(buffer.getInt()).isEqualTo(caps.maximumApfProgramSize)
+        expect.withMessage("IPV4_HEADER_SIZE").that(buffer.getInt()).isEqualTo(0)
+        // Ping packet (64) + IPv6 header (40) + ethernet header (14)
+        expect.withMessage("PACKET_SIZE").that(buffer.getInt()).isEqualTo(64 + 40 + 14)
+        expect.withMessage("FILTER_AGE_SECONDS").that(buffer.getInt()).isLessThan(5)
+    }
 }
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index c0f1080..5ed4696 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -213,6 +213,7 @@
 
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Ignore;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -3556,6 +3557,8 @@
         doTestFirewallBlocking(FIREWALL_CHAIN_DOZABLE, ALLOWLIST);
     }
 
+    // Disable test - needs to be fixed
+    @Ignore
     @Test @IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE) @ConnectivityModuleTest
     @AppModeFull(reason = "Socket cannot bind in instant app mode")
     public void testFirewallBlockingBackground() {
diff --git a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
index 5ba6c4c..93cec9c 100644
--- a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
+++ b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
@@ -287,6 +287,12 @@
 ): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, *requiredTypes) }
 
 fun TapPacketReader.pollForReply(
+    recordName: String,
+    type: Int,
+    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
+): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { it.isReplyFor(recordName, type) }
+
+fun TapPacketReader.pollForReply(
     serviceName: String,
     serviceType: String,
     timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 6dd4857..6394599 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -81,7 +81,9 @@
 import com.android.compatibility.common.util.SystemUtil
 import com.android.modules.utils.build.SdkLevel.isAtLeastU
 import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.DnsPacket.ANSECTION
 import com.android.net.module.util.HexDump
+import com.android.net.module.util.HexDump.hexStringToByteArray
 import com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN
 import com.android.net.module.util.PacketBuilder
 import com.android.testutils.ConnectivityModuleTest
@@ -96,6 +98,7 @@
 import com.android.testutils.TestableNetworkAgent
 import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
 import com.android.testutils.TestableNetworkCallback
+import com.android.testutils.assertContainsExactly
 import com.android.testutils.assertEmpty
 import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk30
 import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk33
@@ -127,6 +130,7 @@
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
+import kotlin.test.assertNotEquals
 
 private const val TAG = "NsdManagerTest"
 private const val TIMEOUT_MS = 2000L
@@ -138,6 +142,9 @@
 private const val DBG = false
 private const val TEST_PORT = 12345
 private const val MDNS_PORT = 5353.toShort()
+private const val TYPE_KEY = 25 // DNS KEY resource record type
+// Expected TTL of announced KEY records. DNS TTLs are in SECONDS despite the _MILLIS suffix.
+private const val NAME_RECORDS_TTL_MILLIS: Long = 120
 private val multicastIpv6Addr = parseNumericAddress("ff02::fb") as Inet6Address
 private val testSrcAddr = parseNumericAddress("2001:db8::123") as Inet6Address
 
@@ -167,6 +174,12 @@
     private val serviceType2 = "_nmt%09d._tcp".format(Random().nextInt(1_000_000_000))
     private val customHostname = "NsdTestHost%09d".format(Random().nextInt(1_000_000_000))
     private val customHostname2 = "NsdTestHost%09d".format(Random().nextInt(1_000_000_000))
+    private val publicKey = hexStringToByteArray(
+            "0201030dc141d0637960b98cbc12cfca"
+                    + "221d2879dac26ee5b460e9007c992e19"
+                    + "02d897c391b03764d448f7d0c772fdb0"
+                    + "3b1d9d6d52ff8886769e8e2362513565"
+                    + "270962d3")
     private val handlerThread = HandlerThread(NsdManagerTest::class.java.simpleName)
     private val ctsNetUtils by lazy{ CtsNetUtils(context) }
 
@@ -1451,10 +1464,8 @@
         handlerThread.waitForIdle(TIMEOUT_MS)
 
         tryTest {
-            repeat(3) {
-                assertNotNull(packetReader.pollForAdvertisement(serviceName, serviceType),
-                        "Expect 3 announcements sent after initial probing")
-            }
+            assertNotNull(packetReader.pollForAdvertisement(serviceName, serviceType),
+                "No announcements sent after initial probing")
 
             assertEquals(si.serviceName, registeredService.serviceName)
             assertEquals(si.hostname, registeredService.hostname)
@@ -2027,7 +2038,7 @@
     }
 
     @Test
-    fun testAdvertisingAndDiscovery_multipleRegistrationsForSameCustomHost_unionOfAddressesFound() {
+    fun testAdvertisingAndDiscovery_multipleRegistrationsForSameCustomHost_hostRenamed() {
         val hostAddresses1 = listOf(
                 parseNumericAddress("192.0.2.23"),
                 parseNumericAddress("2001:db8::1"),
@@ -2035,9 +2046,6 @@
         val hostAddresses2 = listOf(
                 parseNumericAddress("192.0.2.24"),
                 parseNumericAddress("2001:db8::3"))
-        val hostAddresses3 = listOf(
-                parseNumericAddress("2001:db8::3"),
-                parseNumericAddress("2001:db8::5"))
         val si1 = NsdServiceInfo().also {
             it.network = testNetwork1.network
             it.hostname = customHostname
@@ -2051,18 +2059,9 @@
             it.hostname = customHostname
             it.hostAddresses = hostAddresses2
         }
-        val si3 = NsdServiceInfo().also {
-            it.network = testNetwork1.network
-            it.serviceName = serviceName3
-            it.serviceType = serviceType
-            it.port = TEST_PORT + 1
-            it.hostname = customHostname
-            it.hostAddresses = hostAddresses3
-        }
 
         val registrationRecord1 = NsdRegistrationRecord()
         val registrationRecord2 = NsdRegistrationRecord()
-        val registrationRecord3 = NsdRegistrationRecord()
 
         val discoveryRecord = NsdDiscoveryRecord()
         tryTest {
@@ -2072,27 +2071,13 @@
             nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
                     testNetwork1.network, Executor { it.run() }, discoveryRecord)
 
-            val discoveredInfo1 = discoveryRecord.waitForServiceDiscovered(
+            val discoveredInfo = discoveryRecord.waitForServiceDiscovered(
                     serviceName, serviceType, testNetwork1.network)
-            val resolvedInfo1 = resolveService(discoveredInfo1)
+            val resolvedInfo = resolveService(discoveredInfo)
 
-            assertEquals(TEST_PORT, resolvedInfo1.port)
-            assertEquals(si1.hostname, resolvedInfo1.hostname)
-            assertAddressEquals(
-                    hostAddresses1 + hostAddresses2,
-                    resolvedInfo1.hostAddresses)
-
-            registerService(registrationRecord3, si3)
-
-            val discoveredInfo2 = discoveryRecord.waitForServiceDiscovered(
-                    serviceName3, serviceType, testNetwork1.network)
-            val resolvedInfo2 = resolveService(discoveredInfo2)
-
-            assertEquals(TEST_PORT + 1, resolvedInfo2.port)
-            assertEquals(si2.hostname, resolvedInfo2.hostname)
-            assertAddressEquals(
-                    hostAddresses1 + hostAddresses2 + hostAddresses3,
-                    resolvedInfo2.hostAddresses)
+            assertEquals(TEST_PORT, resolvedInfo.port)
+            assertNotEquals(si1.hostname, resolvedInfo.hostname)
+            assertAddressEquals(hostAddresses2, resolvedInfo.hostAddresses)
         } cleanupStep {
             nsdManager.stopServiceDiscovery(discoveryRecord)
 
@@ -2100,7 +2085,6 @@
         } cleanup {
             nsdManager.unregisterService(registrationRecord1)
             nsdManager.unregisterService(registrationRecord2)
-            nsdManager.unregisterService(registrationRecord3)
         }
     }
 
@@ -2266,6 +2250,165 @@
     }
 
     @Test
+    fun testAdvertising_registerServiceAndPublicKey_keyAnnounced() {
+        val si = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceType = serviceType
+            it.serviceName = serviceName
+            it.port = TEST_PORT
+            it.publicKey = publicKey
+        }
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        val registrationRecord = NsdRegistrationRecord()
+        val discoveryRecord = NsdDiscoveryRecord()
+        tryTest {
+            registerService(registrationRecord, si)
+
+            val announcement = packetReader.pollForReply(
+                "$serviceName.$serviceType.local",
+                TYPE_KEY
+            )
+            assertNotNull(announcement)
+            val keyRecords = announcement.records[ANSECTION].filter { it.nsType == TYPE_KEY }
+            assertEquals(1, keyRecords.size)
+            val actualRecord = keyRecords.get(0)
+            assertEquals(TYPE_KEY, actualRecord.nsType)
+            assertEquals("$serviceName.$serviceType.local", actualRecord.dName)
+            assertEquals(NAME_RECORDS_TTL_MILLIS, actualRecord.ttl)
+            assertArrayEquals(publicKey, actualRecord.rr)
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    testNetwork1.network, Executor { it.run() }, discoveryRecord)
+
+            val discoveredInfo1 = discoveryRecord.waitForServiceDiscovered(
+                    serviceName, serviceType, testNetwork1.network)
+            val resolvedInfo1 = resolveService(discoveredInfo1)
+
+            assertEquals(serviceName, discoveredInfo1.serviceName)
+            assertEquals(TEST_PORT, resolvedInfo1.port)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord)
+        }
+    }
+
+    @Test
+    fun testAdvertising_registerCustomHostAndPublicKey_keyAnnounced() {
+        val si = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.hostname = customHostname
+            it.hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"),
+                    parseNumericAddress("2001:db8::2"))
+            it.publicKey = publicKey
+        }
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        val registrationRecord = NsdRegistrationRecord()
+        tryTest {
+            registerService(registrationRecord, si)
+
+            val announcement = packetReader.pollForReply("$customHostname.local", TYPE_KEY)
+            assertNotNull(announcement)
+            val keyRecords = announcement.records[ANSECTION].filter { it.nsType == TYPE_KEY }
+            assertEquals(1, keyRecords.size)
+            val actualRecord = keyRecords.get(0)
+            assertEquals(TYPE_KEY, actualRecord.nsType)
+            assertEquals("$customHostname.local", actualRecord.dName)
+            assertEquals(NAME_RECORDS_TTL_MILLIS, actualRecord.ttl)
+            assertArrayEquals(publicKey, actualRecord.rr)
+
+            // This test case focuses on key announcement so we don't check the details of the
+            // announcement of the custom host addresses.
+            val addressRecords = announcement.records[ANSECTION].filter {
+                it.nsType == DnsResolver.TYPE_AAAA ||
+                        it.nsType == DnsResolver.TYPE_A
+            }
+            assertEquals(3, addressRecords.size)
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord)
+        }
+    }
+
+    @Test
+    fun testAdvertising_registerTwoServicesWithSameCustomHostAndPublicKey_keyAnnounced() {
+        val si1 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceType = serviceType
+            it.serviceName = serviceName
+            it.port = TEST_PORT
+            it.hostname = customHostname
+            it.hostAddresses = listOf(
+                parseNumericAddress("192.0.2.23"),
+                parseNumericAddress("2001:db8::1"),
+                parseNumericAddress("2001:db8::2"))
+            it.publicKey = publicKey
+        }
+        val si2 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceType = serviceType2
+            it.serviceName = serviceName2
+            it.port = TEST_PORT + 1
+            it.hostname = customHostname
+            it.hostAddresses = listOf()
+            it.publicKey = publicKey
+        }
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+            testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        val registrationRecord1 = NsdRegistrationRecord()
+        val registrationRecord2 = NsdRegistrationRecord()
+        tryTest {
+            registerService(registrationRecord1, si1)
+
+            var announcement =
+                packetReader.pollForReply("$serviceName.$serviceType.local", TYPE_KEY)
+            assertNotNull(announcement)
+            var keyRecords = announcement.records[ANSECTION].filter { it.nsType == TYPE_KEY }
+            assertEquals(2, keyRecords.size)
+            assertTrue(keyRecords.any { it.dName == "$serviceName.$serviceType.local" })
+            assertTrue(keyRecords.any { it.dName == "$customHostname.local" })
+            assertTrue(keyRecords.all { it.ttl == NAME_RECORDS_TTL_MILLIS })
+            assertTrue(keyRecords.all { it.rr.contentEquals(publicKey) })
+
+            // This test case focuses on key announcement so we don't check the details of the
+            // announcement of the custom host addresses.
+            val addressRecords = announcement.records[ANSECTION].filter {
+                it.nsType == DnsResolver.TYPE_AAAA ||
+                        it.nsType == DnsResolver.TYPE_A
+            }
+            assertEquals(3, addressRecords.size)
+
+            registerService(registrationRecord2, si2)
+
+            announcement = packetReader.pollForReply("$serviceName2.$serviceType2.local", TYPE_KEY)
+            assertNotNull(announcement)
+            keyRecords = announcement.records[ANSECTION].filter { it.nsType == TYPE_KEY }
+            assertEquals(2, keyRecords.size)
+            assertTrue(keyRecords.any { it.dName == "$serviceName2.$serviceType2.local" })
+            assertTrue(keyRecords.any { it.dName == "$customHostname.local" })
+            assertTrue(keyRecords.all { it.ttl == NAME_RECORDS_TTL_MILLIS })
+            assertTrue(keyRecords.all { it.rr.contentEquals(publicKey) })
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord1)
+            nsdManager.unregisterService(registrationRecord2)
+        }
+    }
+
+    @Test
     fun testServiceTypeClientRemovedAfterSocketDestroyed() {
         val si = makeTestServiceInfo(testNetwork1.network)
         // Register service on testNetwork1
diff --git a/tests/unit/java/android/net/nsd/NsdManagerTest.java b/tests/unit/java/android/net/nsd/NsdManagerTest.java
index 27c4561..9c812a1 100644
--- a/tests/unit/java/android/net/nsd/NsdManagerTest.java
+++ b/tests/unit/java/android/net/nsd/NsdManagerTest.java
@@ -16,6 +16,11 @@
 
 package android.net.nsd;
 
+import static android.net.InetAddresses.parseNumericAddress;
+import static android.net.nsd.NsdManager.checkServiceInfoForRegistration;
+
+import static com.android.net.module.util.HexDump.hexStringToByteArray;
+
 import static libcore.junit.util.compat.CoreCompatChangeRule.DisableCompatChanges;
 import static libcore.junit.util.compat.CoreCompatChangeRule.EnableCompatChanges;
 
@@ -54,6 +59,7 @@
 import org.mockito.MockitoAnnotations;
 
 import java.net.InetAddress;
+import java.util.Collections;
 import java.util.List;
 import java.time.Duration;
 
@@ -395,6 +401,7 @@
         NsdManager.RegistrationListener listener4 = mock(NsdManager.RegistrationListener.class);
         NsdManager.RegistrationListener listener5 = mock(NsdManager.RegistrationListener.class);
         NsdManager.RegistrationListener listener6 = mock(NsdManager.RegistrationListener.class);
+        NsdManager.RegistrationListener listener7 = mock(NsdManager.RegistrationListener.class);
 
         NsdServiceInfo invalidService = new NsdServiceInfo(null, null);
         NsdServiceInfo validService = new NsdServiceInfo("a_name", "_a_type._tcp");
@@ -439,6 +446,19 @@
         validServiceWithCustomHostNoAddresses.setPort(2222);
         validServiceWithCustomHostNoAddresses.setHostname("a_host");
 
+        NsdServiceInfo validServiceWithPublicKey = new NsdServiceInfo("a_name", "_a_type._tcp");
+        validServiceWithPublicKey.setPublicKey(
+                hexStringToByteArray(
+                        "0201030dc141d0637960b98cbc12cfca"
+                                + "221d2879dac26ee5b460e9007c992e19"
+                                + "02d897c391b03764d448f7d0c772fdb0"
+                                + "3b1d9d6d52ff8886769e8e2362513565"
+                                + "270962d3"));
+
+        NsdServiceInfo invalidServiceWithTooShortPublicKey =
+                new NsdServiceInfo("a_name", "_a_type._tcp");
+        invalidServiceWithTooShortPublicKey.setPublicKey(hexStringToByteArray("0201"));
+
         // Service registration
         //  - invalid arguments
         mustFail(() -> { manager.unregisterService(null); });
@@ -449,6 +469,8 @@
         mustFail(() -> { manager.registerService(validService, PROTOCOL, null); });
         mustFail(() -> {
             manager.registerService(invalidMissingHostnameWithAddresses, PROTOCOL, listener1); });
+        mustFail(() -> {
+            manager.registerService(invalidServiceWithTooShortPublicKey, PROTOCOL, listener1); });
         manager.registerService(validService, PROTOCOL, listener1);
         //  - update without subtype is not allowed
         mustFail(() -> { manager.registerService(validServiceDuplicate, PROTOCOL, listener1); });
@@ -479,6 +501,9 @@
         //  - registering a service with a custom host with no addresses is valid
         manager.registerService(validServiceWithCustomHostNoAddresses, PROTOCOL, listener6);
         manager.unregisterService(listener6);
+        //  - registering a service with a public key is valid
+        manager.registerService(validServiceWithPublicKey, PROTOCOL, listener7);
+        manager.unregisterService(listener7);
 
         // Discover service
         //  - invalid arguments
@@ -506,6 +531,229 @@
         mustFail(() -> { manager.resolveService(validService, listener3); });
     }
 
+    private static final class NsdServiceInfoBuilder {
+        private static final String SERVICE_NAME = "TestService";
+        private static final String SERVICE_TYPE = "_testservice._tcp";
+        private static final int SERVICE_PORT = 12345;
+        private static final String HOSTNAME = "TestHost";
+        private static final List<InetAddress> HOST_ADDRESSES =
+                List.of(parseNumericAddress("192.168.2.23"), parseNumericAddress("2001:db8::3"));
+        private static final byte[] PUBLIC_KEY =
+                hexStringToByteArray(
+                        "0201030dc141d0637960b98cbc12cfca"
+                                + "221d2879dac26ee5b460e9007c992e19"
+                                + "02d897c391b03764d448f7d0c772fdb0"
+                                + "3b1d9d6d52ff8886769e8e2362513565"
+                                + "270962d3");
+
+        private final NsdServiceInfo mNsdServiceInfo = new NsdServiceInfo();
+
+        NsdServiceInfo build() {
+            return mNsdServiceInfo;
+        }
+
+        NsdServiceInfoBuilder setNoService() {
+            mNsdServiceInfo.setServiceName(null);
+            mNsdServiceInfo.setServiceType(null);
+            mNsdServiceInfo.setPort(0);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setService() {
+            mNsdServiceInfo.setServiceName(SERVICE_NAME);
+            mNsdServiceInfo.setServiceType(SERVICE_TYPE);
+            mNsdServiceInfo.setPort(SERVICE_PORT);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setZeroPortService() {
+            mNsdServiceInfo.setServiceName(SERVICE_NAME);
+            mNsdServiceInfo.setServiceType(SERVICE_TYPE);
+            mNsdServiceInfo.setPort(0);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setInvalidService() {
+            mNsdServiceInfo.setServiceName(SERVICE_NAME);
+            mNsdServiceInfo.setServiceType(null);
+            mNsdServiceInfo.setPort(SERVICE_PORT);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setDefaultHost() {
+            mNsdServiceInfo.setHostname(null);
+            mNsdServiceInfo.setHostAddresses(Collections.emptyList());
+            return this;
+        }
+
+        NsdServiceInfoBuilder setCustomHost() {
+            mNsdServiceInfo.setHostname(HOSTNAME);
+            mNsdServiceInfo.setHostAddresses(HOST_ADDRESSES);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setCustomHostNoAddress() {
+            mNsdServiceInfo.setHostname(HOSTNAME);
+            mNsdServiceInfo.setHostAddresses(Collections.emptyList());
+            return this;
+        }
+
+        NsdServiceInfoBuilder setHostAddressesNoHostname() {
+            mNsdServiceInfo.setHostname(null);
+            mNsdServiceInfo.setHostAddresses(HOST_ADDRESSES);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setNoPublicKey() {
+            mNsdServiceInfo.setPublicKey(null);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setPublicKey() {
+            mNsdServiceInfo.setPublicKey(PUBLIC_KEY);
+            return this;
+        }
+
+        NsdServiceInfoBuilder setInvalidPublicKey() {
+            mNsdServiceInfo.setPublicKey(new byte[3]);
+            return this;
+        }
+    }
+
+    @Test
+    public void testCheckServiceInfoForRegistration() {
+        // An invalid service (name and port set, but no type) is rejected
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setInvalidService()
+                        .setCustomHost()
+                        .setPublicKey().build()));
+        // Keep compatible with the legacy behavior: It's allowed to set host
+        // addresses for a service registration although the host addresses
+        // won't be registered. To register the addresses for a host, the
+        // hostname must be specified.
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setHostAddressesNoHostname()
+                        .setPublicKey().build());
+        // A malformed (3-byte) public key is rejected
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setCustomHost()
+                        .setInvalidPublicKey().build()));
+        // All (service, host, key) combinations, each marked valid or invalid
+        // 1. (service, custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setCustomHost()
+                        .setPublicKey().build());
+        // 2. (service, custom host, no key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setCustomHost()
+                        .setNoPublicKey().build());
+        // 3. (service, no-address custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setCustomHostNoAddress()
+                        .setPublicKey().build());
+        // 4. (service, no-address custom host, no key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setCustomHostNoAddress()
+                        .setNoPublicKey().build());
+        // 5. (service, default host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setDefaultHost()
+                        .setPublicKey().build());
+        // 6. (service, default host, no key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setService()
+                        .setDefaultHost()
+                        .setNoPublicKey().build());
+        // 7. (0-port service, custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setCustomHost()
+                        .setPublicKey().build());
+        // 8. (0-port service, custom host, no key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setCustomHost()
+                        .setNoPublicKey().build()));
+        // 9. (0-port service, no-address custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setCustomHostNoAddress()
+                        .setPublicKey().build());
+        // 10. (0-port service, no-address custom host, no key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setCustomHostNoAddress()
+                        .setNoPublicKey().build()));
+        // 11. (0-port service, default host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setDefaultHost()
+                        .setPublicKey().build());
+        // 12. (0-port service, default host, no key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setZeroPortService()
+                        .setDefaultHost()
+                        .setNoPublicKey().build()));
+        // 13. (no service, custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setCustomHost()
+                        .setPublicKey().build());
+        // 14. (no service, custom host, no key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setCustomHost()
+                        .setNoPublicKey().build());
+        // 15. (no service, no-address custom host, key): valid
+        checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setCustomHostNoAddress()
+                        .setPublicKey().build());
+        // 16. (no service, no-address custom host, no key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setCustomHostNoAddress()
+                        .setNoPublicKey().build()));
+        // 17. (no service, default host, key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setDefaultHost()
+                        .setPublicKey().build()));
+        // 18. (no service, default host, no key): invalid
+        mustFail(() -> checkServiceInfoForRegistration(
+                new NsdServiceInfoBuilder()
+                        .setNoService()
+                        .setDefaultHost()
+                        .setNoPublicKey().build()));
+    }
+
     public void mustFail(Runnable fn) {
         try {
             fn.run();
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index f7e0b0e..d735dc6 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -21,11 +21,13 @@
 import android.net.nsd.NsdServiceInfo
 import android.os.Build
 import android.os.HandlerThread
+import com.android.net.module.util.HexDump.hexStringToByteArray
 import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
 import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_HOST
 import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_A
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_AAAA
+import com.android.server.connectivity.mdns.MdnsRecord.TYPE_KEY
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_PTR
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_SRV
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_TXT
@@ -120,6 +122,20 @@
     port = TEST_PORT
 }
 
+// KEY RDATA (flags 0x0201, protocol 3, algorithm 0x0d) followed by 64 bytes of key data.
+private val TEST_PUBLIC_KEY = hexStringToByteArray(
+        "0201030dc141d0637960b98cbc12cfca"
+                + "221d2879dac26ee5b460e9007c992e19"
+                + "02d897c391b03764d448f7d0c772fdb0"
+                + "3b1d9d6d52ff8886769e8e2362513565"
+                + "270962d3")
+
+// Same as TEST_PUBLIC_KEY except for the last byte (0xd3 -> 0xd4): identical header and
+// length but different key material, used to provoke KEY record conflicts.
+private val TEST_PUBLIC_KEY_2 = TEST_PUBLIC_KEY.copyOf().also {
+    it[it.size - 1] = 0xd4.toByte()
+}
+
 @RunWith(DevSdkIgnoreRunner::class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsRecordRepositoryTest {
@@ -581,6 +597,7 @@
             TYPE_PTR -> return MdnsPointerRecord(name, false /* isUnicast */)
             TYPE_SRV -> return MdnsServiceRecord(name, false /* isUnicast */)
             TYPE_TXT -> return MdnsTextRecord(name, false /* isUnicast */)
+            TYPE_KEY -> return MdnsKeyRecord(name, false /* isUnicast */)
             TYPE_A, TYPE_AAAA -> return MdnsInetAddressRecord(name, type, false /* isUnicast */)
             else -> fail("Unexpected question type: $type")
         }
@@ -908,6 +925,159 @@
     }
 
     @Test
+    fun testGetReply_keyQuestionForServiceName_returnsKeyRecord() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        // Both services carry the same key; service 2 has port 0, hence no SRV record.
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_1, NsdServiceInfo().apply {
+            serviceType = "_testservice._tcp"
+            serviceName = "MyTestService1"
+            port = TEST_PORT
+            publicKey = TEST_PUBLIC_KEY
+        })
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_2, NsdServiceInfo().apply {
+            serviceType = "_testservice._tcp"
+            serviceName = "MyTestService2"
+            port = 0 // No SRV RR
+            publicKey = TEST_PUBLIC_KEY
+        })
+        val src = InetSocketAddress(parseNumericAddress("fe80::123"), 5353)
+        val serviceName1 = arrayOf("MyTestService1", "_testservice", "_tcp", "local")
+        val serviceName2 = arrayOf("MyTestService2", "_testservice", "_tcp", "local")
+
+        val query1 = makeQuery(TYPE_KEY to serviceName1)
+        val reply1 = repository.getReply(query1, src)
+
+        assertNotNull(reply1)
+        assertEquals(listOf(MdnsKeyRecord(serviceName1,
+                0, false, LONG_TTL, TEST_PUBLIC_KEY)),
+                reply1.answers)
+        assertEquals(listOf(),
+                reply1.additionalAnswers)
+        // The reply for service 2 is expected to add an NSEC whose type list is only TYPE_KEY.
+        val query2 = makeQuery(TYPE_KEY to serviceName2)
+        val reply2 = repository.getReply(query2, src)
+
+        assertNotNull(reply2)
+        assertEquals(listOf(MdnsKeyRecord(serviceName2,
+                0, false, LONG_TTL, TEST_PUBLIC_KEY)),
+                reply2.answers)
+        assertEquals(listOf(MdnsNsecRecord(serviceName2,
+                0L, true, SHORT_TTL,
+                serviceName2 /* nextDomain */,
+                intArrayOf(TYPE_KEY))),
+                reply2.additionalAnswers)
+    }
+
+    @Test
+    fun testGetReply_keyQuestionForHostname_returnsKeyRecord() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        // Both hosts carry the same key; host 2 is registered without any addresses.
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_1, NsdServiceInfo().apply {
+            hostname = "MyHost1"
+            hostAddresses = listOf(
+                    parseNumericAddress("2001:db8::1"),
+                    parseNumericAddress("2001:db8::2"))
+            publicKey = TEST_PUBLIC_KEY
+        })
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_2, NsdServiceInfo().apply {
+            hostname = "MyHost2"
+            hostAddresses = listOf() // No address records
+            publicKey = TEST_PUBLIC_KEY
+        })
+        val src = InetSocketAddress(parseNumericAddress("fe80::123"), 5353)
+        val hostname1 = arrayOf("MyHost1", "local")
+        val hostname2 = arrayOf("MyHost2", "local")
+
+        val query1 = makeQuery(TYPE_KEY to hostname1)
+        val reply1 = repository.getReply(query1, src)
+
+        assertNotNull(reply1)
+        assertEquals(listOf(MdnsKeyRecord(hostname1,
+                0, false, LONG_TTL, TEST_PUBLIC_KEY)),
+                reply1.answers)
+        assertEquals(listOf(),
+                reply1.additionalAnswers)
+        // The reply for host 2 is expected to add an NSEC whose type list is only TYPE_KEY.
+        val query2 = makeQuery(TYPE_KEY to hostname2)
+        val reply2 = repository.getReply(query2, src)
+
+        assertNotNull(reply2)
+        assertEquals(listOf(MdnsKeyRecord(hostname2,
+                0, false, LONG_TTL, TEST_PUBLIC_KEY)),
+                reply2.answers)
+        assertEquals(listOf(MdnsNsecRecord(hostname2, 0L, true, SHORT_TTL,
+                hostname2 /* nextDomain */,
+                intArrayOf(TYPE_KEY))),
+                reply2.additionalAnswers)
+    }
+
+    @Test
+    fun testGetReply_keyRecordForHostRemoved_noAnswerToKeyQuestion() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        // Both hosts are removed before querying, so neither KEY question should be answered.
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_1, NsdServiceInfo().apply {
+            hostname = "MyHost1"
+            hostAddresses = listOf(
+                    parseNumericAddress("2001:db8::1"),
+                    parseNumericAddress("2001:db8::2"))
+            publicKey = TEST_PUBLIC_KEY
+        })
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_2, NsdServiceInfo().apply {
+            hostname = "MyHost2"
+            hostAddresses = listOf() // No address records
+            publicKey = TEST_PUBLIC_KEY
+        })
+        repository.removeService(TEST_SERVICE_ID_1)
+        repository.removeService(TEST_SERVICE_ID_2)
+        val src = InetSocketAddress(parseNumericAddress("fe80::123"), 5353)
+        val hostname1 = arrayOf("MyHost1", "local")
+        val hostname2 = arrayOf("MyHost2", "local")
+
+        val query1 = makeQuery(TYPE_KEY to hostname1)
+        val reply1 = repository.getReply(query1, src)
+
+        assertNull(reply1)
+
+        val query2 = makeQuery(TYPE_KEY to hostname2)
+        val reply2 = repository.getReply(query2, src)
+
+        assertNull(reply2)
+    }
+
+    @Test
+    fun testGetReply_keyRecordForServiceRemoved_noAnswerToKeyQuestion() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        // Both services are removed before querying, so neither KEY question should be answered.
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_1, NsdServiceInfo().apply {
+            serviceType = "_testservice._tcp"
+            serviceName = "MyTestService1"
+            port = TEST_PORT
+            publicKey = TEST_PUBLIC_KEY
+        })
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_2, NsdServiceInfo().apply {
+            serviceType = "_testservice._tcp"
+            serviceName = "MyTestService2"
+            port = 0 // No SRV RR
+            publicKey = TEST_PUBLIC_KEY
+        })
+        repository.removeService(TEST_SERVICE_ID_1)
+        repository.removeService(TEST_SERVICE_ID_2)
+        val src = InetSocketAddress(parseNumericAddress("fe80::123"), 5353)
+        val serviceName1 = arrayOf("MyTestService1", "_testservice", "_tcp", "local")
+        val serviceName2 = arrayOf("MyTestService2", "_testservice", "_tcp", "local")
+
+        val query1 = makeQuery(TYPE_KEY to serviceName1)
+        val reply1 = repository.getReply(query1, src)
+
+        assertNull(reply1)
+
+        val query2 = makeQuery(TYPE_KEY to serviceName2)
+        val reply2 = repository.getReply(query2, src)
+
+        assertNull(reply2)
+    }
+
+    @Test
     fun testGetReply_customHostRemoved_noAnswerToAAAAQuestion() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(
@@ -1221,8 +1391,8 @@
     @Test
     fun testGetConflictingServices_customHostsReplyHasFewerAddressesThanUs_noConflict() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
-        repository.addService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1, null /* ttl */)
-        repository.addService(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2, null /* ttl */)
+        repository.addServiceAndFinishProbing(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1)
+        repository.addServiceAndFinishProbing(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2)
 
         val packet = MdnsPacket(
                 0, /* flags */
@@ -1240,10 +1410,30 @@
     }
 
     @Test
-    fun testGetConflictingServices_customHostsReplyHasIdenticalHosts_noConflict() {
+    fun testGetConflictingServices_customHostsReplyHasSameNameRecord_conflictDuringProbing() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.addService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1, null /* ttl */)
-        repository.addService(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2, null /* ttl */)
+        repository.addServiceAndFinishProbing(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2)
+        // Only host 1 is still probing, so only it is expected to conflict with this record.
+        val packet = MdnsPacket(
+            0, /* flags */
+            emptyList(), /* questions */
+            listOf(MdnsKeyRecord(arrayOf("TestHost", "local"),
+                    0L /* receiptTimeMillis */, true /* cacheFlush */,
+                    0L /* ttlMillis */, TEST_PUBLIC_KEY),
+            ) /* answers */,
+            emptyList() /* authorityRecords */,
+            emptyList() /* additionalRecords */)
+
+        assertEquals(mapOf(TEST_CUSTOM_HOST_ID_1 to CONFLICT_HOST),
+            repository.getConflictingServices(packet))
+    }
+
+    @Test
+    fun testGetConflictingServices_customHostsReplyHasIdenticalHosts_noConflict() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.addServiceAndFinishProbing(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1)
+        repository.addServiceAndFinishProbing(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2)
 
         val packet = MdnsPacket(
                 0, /* flags */
@@ -1267,8 +1457,8 @@
     @Test
     fun testGetConflictingServices_customHostsCaseInsensitiveReplyHasIdenticalHosts_noConflict() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
-        repository.addService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1, null /* ttl */)
-        repository.addService(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2, null /* ttl */)
+        repository.addServiceAndFinishProbing(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1)
+        repository.addServiceAndFinishProbing(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2)
 
         val packet = MdnsPacket(
                 0, /* flags */
@@ -1289,6 +1479,152 @@
     }
 
     @Test
+    fun testGetConflictingServices_identicalKeyRecordsForService_noConflict() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        // The reply below repeats our own key for this service, only with a different TTL.
+        repository.addService(TEST_SERVICE_ID_1, NsdServiceInfo().apply {
+            serviceType = "_testservice._tcp"
+            serviceName = "MyTestService"
+            port = TEST_PORT
+            publicKey = TEST_PUBLIC_KEY
+        }, null /* ttl */)
+
+        val otherTtlMillis = 1234L
+        val packet = MdnsPacket(
+                0 /* flags */,
+                emptyList() /* questions */,
+                listOf(
+                        MdnsKeyRecord(
+                                arrayOf("MyTestService", "_testservice", "_tcp", "local"),
+                                0L /* receiptTimeMillis */, true /* cacheFlush */,
+                                otherTtlMillis,
+                                TEST_PUBLIC_KEY)
+                ) /* answers */,
+                emptyList() /* authorityRecords */,
+                emptyList() /* additionalRecords */)
+
+        assertEquals(emptyMap(),
+                repository.getConflictingServices(packet))
+    }
+
+    @Test
+    fun testGetConflictingServices_differentKeyRecordsForService_conflict() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        // The reply below carries a different key (TEST_PUBLIC_KEY_2) for the same service name.
+        repository.addService(TEST_SERVICE_ID_1, NsdServiceInfo().apply {
+            serviceType = "_testservice._tcp"
+            serviceName = "MyTestService"
+            port = TEST_PORT
+            publicKey = TEST_PUBLIC_KEY
+        }, null /* ttl */)
+
+        val otherTtlMillis = 1234L
+        val packet = MdnsPacket(
+                0 /* flags */,
+                emptyList() /* questions */,
+                listOf(
+                        MdnsKeyRecord(
+                                arrayOf("MyTestService", "_testservice", "_tcp", "local"),
+                                0L /* receiptTimeMillis */, true /* cacheFlush */,
+                                otherTtlMillis,
+                                TEST_PUBLIC_KEY_2)
+                ) /* answers */,
+                emptyList() /* authorityRecords */,
+                emptyList() /* additionalRecords */)
+
+        assertEquals(mapOf(TEST_SERVICE_ID_1 to CONFLICT_SERVICE),
+                repository.getConflictingServices(packet))
+    }
+
+    @Test
+    fun testGetConflictingServices_identicalKeyRecordsForHost_noConflict() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        // The reply below repeats our own key for MyHost.local, only with a different TTL.
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_1, NsdServiceInfo().apply {
+            hostname = "MyHost"
+            hostAddresses = listOf(
+                parseNumericAddress("2001:db8::1"),
+                parseNumericAddress("2001:db8::2")
+            )
+            publicKey = TEST_PUBLIC_KEY
+        })
+
+        val otherTtlMillis = 1234L
+        val packet = MdnsPacket(
+                0 /* flags */,
+                emptyList() /* questions */,
+                listOf(
+                        MdnsKeyRecord(
+                                arrayOf("MyHost", "local"),
+                                0L /* receiptTimeMillis */, true /* cacheFlush */,
+                                otherTtlMillis,
+                                TEST_PUBLIC_KEY)
+                ) /* answers */,
+                emptyList() /* authorityRecords */,
+                emptyList() /* additionalRecords */)
+
+        assertEquals(emptyMap(),
+                repository.getConflictingServices(packet))
+    }
+
+    @Test
+    fun testGetConflictingServices_keyForCustomHostReplySameRecordName_conflictDuringProbing() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        // The host is still probing, so an A record for the same name is expected to conflict.
+        repository.addService(TEST_SERVICE_ID_1, NsdServiceInfo().apply {
+            hostname = "MyHost"
+            publicKey = TEST_PUBLIC_KEY
+        }, null /* ttl */)
+
+        val otherTtlMillis = 1234L
+        val packet = MdnsPacket(
+            0 /* flags */,
+            emptyList() /* questions */,
+            listOf(MdnsInetAddressRecord(arrayOf("MyHost", "local"),
+                    0L /* receiptTimeMillis */,
+                    true /* cacheFlush */,
+                    otherTtlMillis,
+                    parseNumericAddress("192.168.2.111"))
+            ) /* answers */,
+            emptyList() /* authorityRecords */,
+            emptyList() /* additionalRecords */
+        )
+
+        assertEquals(mapOf(TEST_SERVICE_ID_1 to CONFLICT_HOST),
+            repository.getConflictingServices(packet))
+    }
+
+    @Test
+    fun testGetConflictingServices_differentKeyRecordsForHost_conflict() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        // The reply below carries a different key (TEST_PUBLIC_KEY_2) for MyHost.local.
+        repository.addService(TEST_SERVICE_ID_1, NsdServiceInfo().apply {
+            hostname = "MyHost"
+            hostAddresses = listOf(
+                    parseNumericAddress("2001:db8::1"),
+                    parseNumericAddress("2001:db8::2"))
+            publicKey = TEST_PUBLIC_KEY
+        }, null /* ttl */)
+
+        val otherTtlMillis = 1234L
+        val packet = MdnsPacket(
+                0 /* flags */,
+                emptyList() /* questions */,
+                listOf(
+                        MdnsKeyRecord(
+                                arrayOf("MyHost", "local"),
+                                0L /* receiptTimeMillis */, true /* cacheFlush */,
+                                otherTtlMillis,
+                                TEST_PUBLIC_KEY_2)
+                ) /* answers */,
+                emptyList() /* authorityRecords */,
+                emptyList() /* additionalRecords */)
+
+        assertEquals(mapOf(TEST_SERVICE_ID_1 to CONFLICT_HOST),
+                repository.getConflictingServices(packet))
+    }
+
+    @Test
     fun testGetConflictingServices_IdenticalService() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* ttl */)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordTests.java
index 55c2846..63548c1 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordTests.java
@@ -16,11 +16,13 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsConstants.QCLASS_INTERNET;
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
@@ -424,4 +426,92 @@
         assertEquals(new TextEntry("xyz", HexDump.hexStringToByteArray("FFEFDFCF")),
                 entries.get(2));
     }
+
+    @Test
+    public void testKeyRecord() throws IOException {
+        final byte[] dataIn =
+                HexDump.hexStringToByteArray(
+                        "09746573742d686f7374056c6f63616c"
+                                + "00001980010000000a00440201030dc1"
+                                + "41d0637960b98cbc12cfca221d2879da"
+                                + "c26ee5b460e9007c992e1902d897c391"
+                                + "b03764d448f7d0c772fdb03b1d9d6d52"
+                                + "ff8886769e8e2362513565270962d3");
+        final byte[] rData =
+                HexDump.hexStringToByteArray(
+                        "0201030dc141d0637960b98cbc12cfca"
+                                + "221d2879dac26ee5b460e9007c992e19"
+                                + "02d897c391b03764d448f7d0c772fdb0"
+                                + "3b1d9d6d52ff8886769e8e2362513565"
+                                + "270962d3");
+        assertNotNull(dataIn);
+        String dataInText = HexDump.dumpHexString(dataIn, 0, dataIn.length);
+
+        // Decode
+        DatagramPacket packet = new DatagramPacket(dataIn, dataIn.length);
+        MdnsPacketReader reader = new MdnsPacketReader(packet);
+
+        String[] name = reader.readLabels();
+        assertNotNull(name);
+        assertEquals(2, name.length);
+        String fqdn = MdnsRecord.labelsToString(name);
+        assertEquals("test-host.local", fqdn);
+
+        int type = reader.readUInt16();
+        assertEquals(MdnsRecord.TYPE_KEY, type);
+
+        MdnsKeyRecord keyRecord;
+
+        // MdnsKeyRecord(String[] name, MdnsPacketReader reader)
+        reader = new MdnsPacketReader(packet);
+        reader.readLabels(); // Skip labels
+        reader.readUInt16(); // Skip type
+        keyRecord = new MdnsKeyRecord(name, reader);
+        assertEquals(MdnsRecord.TYPE_KEY, keyRecord.getType());
+        assertTrue(keyRecord.getTtl() > 0); // Not a question so the TTL is greater than 0
+        assertTrue(keyRecord.getCacheFlush());
+        assertArrayEquals(new String[] {"test-host", "local"}, keyRecord.getName());
+        assertArrayEquals(rData, keyRecord.getRData());
+        assertNotEquals(rData, keyRecord.getRData()); // Reference check; arrays use identity
+        assertEquals(dataInText, toHex(keyRecord));
+
+        // MdnsKeyRecord(String[] name, MdnsPacketReader reader, boolean isQuestion)
+        reader = new MdnsPacketReader(packet);
+        reader.readLabels(); // Skip labels
+        reader.readUInt16(); // Skip type
+        keyRecord = new MdnsKeyRecord(name, reader, false /* isQuestion */);
+        assertEquals(MdnsRecord.TYPE_KEY, keyRecord.getType());
+        assertTrue(keyRecord.getTtl() > 0); // Not a question, so the TTL is greater than 0
+        assertTrue(keyRecord.getCacheFlush());
+        assertArrayEquals(new String[] {"test-host", "local"}, keyRecord.getName());
+        assertArrayEquals(rData, keyRecord.getRData());
+        assertNotEquals(rData, keyRecord.getRData()); // Reference check; arrays use identity
+
+        // MdnsKeyRecord(String[] name, boolean isUnicast)
+        keyRecord = new MdnsKeyRecord(name, false /* isUnicast */);
+        assertEquals(MdnsRecord.TYPE_KEY, keyRecord.getType());
+        assertEquals(0, keyRecord.getTtl());
+        assertEquals(QCLASS_INTERNET, keyRecord.getRecordClass());
+        assertFalse(keyRecord.getCacheFlush());
+        assertArrayEquals(new String[] {"test-host", "local"}, keyRecord.getName());
+        assertNull(keyRecord.getRData()); // This constructor sets no RDATA
+
+        // MdnsKeyRecord(String[] name, long receiptTimeMillis, boolean cacheFlush, long ttlMillis,
+        // byte[] rData)
+        keyRecord =
+                new MdnsKeyRecord(
+                        name,
+                        10 /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        20_000 /* ttlMillis */,
+                        rData);
+        assertEquals(MdnsRecord.TYPE_KEY, keyRecord.getType());
+        assertEquals(10, keyRecord.getReceiptTime());
+        assertTrue(keyRecord.getCacheFlush());
+        assertEquals(20_000, keyRecord.getTtl());
+        assertEquals(QCLASS_INTERNET, keyRecord.getRecordClass());
+        assertArrayEquals(new String[] {"test-host", "local"}, keyRecord.getName());
+        assertArrayEquals(rData, keyRecord.getRData());
+        assertNotEquals(rData, keyRecord.getRData()); // Uses a copy of the original RDATA
+    }
 }
diff --git a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
index 0e7f3be..dea4279 100644
--- a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
+++ b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
@@ -67,6 +67,7 @@
 import android.net.thread.utils.TapTestNetworkTracker;
 import android.net.thread.utils.ThreadFeatureCheckerRule;
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresThreadFeature;
+import android.os.Build;
 import android.os.HandlerThread;
 import android.os.OutcomeReceiver;
 
@@ -782,10 +783,14 @@
     public void threadNetworkCallback_deviceAttached_threadNetworkIsAvailable() throws Exception {
         CompletableFuture<Network> networkFuture = new CompletableFuture<>();
         ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
-        NetworkRequest networkRequest =
-                new NetworkRequest.Builder()
-                        .addTransportType(NetworkCapabilities.TRANSPORT_THREAD)
-                        .build();
+        NetworkRequest.Builder networkRequestBuilder =
+                new NetworkRequest.Builder().addTransportType(NetworkCapabilities.TRANSPORT_THREAD);
+        // Before V, we need to explicitly set `NET_CAPABILITY_LOCAL_NETWORK` capability to request
+        // a Thread network.
+        if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
+            networkRequestBuilder.addCapability(NET_CAPABILITY_LOCAL_NETWORK);
+        }
+        NetworkRequest networkRequest = networkRequestBuilder.build();
         ConnectivityManager.NetworkCallback networkCallback =
                 new ConnectivityManager.NetworkCallback() {
                     @Override