Merge changes from topic "cdm-secure-channel-udc-dev" into udc-dev

* changes:
  Introduce hidden API to disable secure channel for back-compatibility
  Integrate secure channel into CDM
  Implement secure channel
  Add Ukey2 dependency
diff --git a/core/java/android/companion/CompanionDeviceManager.java b/core/java/android/companion/CompanionDeviceManager.java
index 5df2d5e..de4f619 100644
--- a/core/java/android/companion/CompanionDeviceManager.java
+++ b/core/java/android/companion/CompanionDeviceManager.java
@@ -1194,6 +1194,20 @@
         }
     }
 
+    /**
+     * Enable or disable secure transport for testing. Defaults to enabled.
+     *
+     * <p>This is a no-op when the companion device service is not available on this device.
+     *
+     * @param enabled {@code true} to enable secure transport, {@code false} to disable it.
+     * @hide
+     */
+    public void enableSecureTransport(boolean enabled) {
+        if (!checkFeaturePresent()) {
+            return;
+        }
+
+        try {
+            mService.enableSecureTransport(enabled);
+        } catch (RemoteException e) {
+            throw e.rethrowFromSystemServer();
+        }
+    }
+
     private boolean checkFeaturePresent() {
         boolean featurePresent = mService != null;
         if (!featurePresent && DEBUG) {
diff --git a/core/java/android/companion/ICompanionDeviceManager.aidl b/core/java/android/companion/ICompanionDeviceManager.aidl
index 010aa8f..cb4baca 100644
--- a/core/java/android/companion/ICompanionDeviceManager.aidl
+++ b/core/java/android/companion/ICompanionDeviceManager.aidl
@@ -88,4 +88,6 @@
     void enableSystemDataSync(int associationId, int flags);
 
     void disableSystemDataSync(int associationId, int flags);
+
+    /** Enables or disables secure transport for CDM system data transports; used by tests. */
+    void enableSecureTransport(boolean enabled);
 }
diff --git a/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java b/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java
index d633843..2b4123a 100644
--- a/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java
+++ b/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java
@@ -60,6 +60,7 @@
         mContext = getInstrumentation().getTargetContext();
         mCdm = mContext.getSystemService(CompanionDeviceManager.class);
         mAssociationId = createAssociation();
+        mCdm.enableSecureTransport(false);
     }
 
     @Override
@@ -67,6 +68,7 @@
         super.tearDown();
 
         mCdm.disassociate(mAssociationId);
+        mCdm.enableSecureTransport(true);
     }
 
     public void testPingHandRolled() {
diff --git a/services/Android.bp b/services/Android.bp
index f8097ec..6e6c553 100644
--- a/services/Android.bp
+++ b/services/Android.bp
@@ -195,6 +195,10 @@
         "manifest_services.xml",
     ],
 
+    required: [
+        "libukey2_jni_shared",
+    ],
+
     // Uncomment to enable output of certain warnings (deprecated, unchecked)
     //javacflags: ["-Xlint"],
 }
diff --git a/services/companion/Android.bp b/services/companion/Android.bp
index cdeb2dc..a248d9e5 100644
--- a/services/companion/Android.bp
+++ b/services/companion/Android.bp
@@ -24,4 +24,7 @@
         "app-compat-annotations",
         "services.core",
     ],
+    static_libs: [
+        "ukey2_jni",
+    ],
 }
diff --git a/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java b/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java
index 0f2ba35..a35cae9 100644
--- a/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java
+++ b/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java
@@ -726,6 +726,11 @@
         }
 
         @Override
+        public void enableSecureTransport(boolean enabled) {
+            // Turning secure transport off is a testing hook; restrict it to privileged callers
+            // so that an arbitrary app cannot disable it for CDM transports.
+            getContext().enforceCallingOrSelfPermission(
+                    android.Manifest.permission.MANAGE_COMPANION_DEVICES,
+                    "enableSecureTransport");
+
+            mTransportManager.enableSecureTransport(enabled);
+        }
+
+        @Override
         public void notifyDeviceAppeared(int associationId) {
             if (DEBUG) Log.i(TAG, "notifyDevice_Appeared() id=" + associationId);
 
diff --git a/services/companion/java/com/android/server/companion/securechannel/AttestationVerifier.java b/services/companion/java/com/android/server/companion/securechannel/AttestationVerifier.java
new file mode 100644
index 0000000..adaee75
--- /dev/null
+++ b/services/companion/java/com/android/server/companion/securechannel/AttestationVerifier.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.companion.securechannel;
+
+import static android.security.attestationverification.AttestationVerificationManager.PARAM_CHALLENGE;
+import static android.security.attestationverification.AttestationVerificationManager.PROFILE_PEER_DEVICE;
+import static android.security.attestationverification.AttestationVerificationManager.TYPE_CHALLENGE;
+
+import android.annotation.NonNull;
+import android.content.Context;
+import android.os.Bundle;
+import android.security.attestationverification.AttestationProfile;
+import android.security.attestationverification.AttestationVerificationManager;
+import android.security.attestationverification.VerificationToken;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
+
+/**
+ * Helper class to perform attestation verification synchronously.
+ */
+class AttestationVerifier {
+    private static final long ATTESTATION_VERIFICATION_TIMEOUT_SECONDS = 10; // 10 seconds
+    private static final String PARAM_OWNED_BY_SYSTEM = "android.key_owned_by_system";
+
+    private final Context mContext;
+
+    AttestationVerifier(Context context) {
+        this.mContext = context;
+    }
+
+    /**
+     * Synchronously verify remote attestation as a suitable peer device on current thread.
+     *
+     * <p>The peer device must be owned by the Android system and be protected with appropriate
+     * public key that this device can verify as attestation challenge. The calling thread blocks
+     * for up to {@code ATTESTATION_VERIFICATION_TIMEOUT_SECONDS} seconds waiting for the result.
+     *
+     * @param remoteAttestation the full certificate chain containing attestation extension.
+     * @param attestationChallenge attestation challenge for authentication.
+     * @return the result code delivered to the verification callback by
+     *         {@link AttestationVerificationManager}; callers should compare it against
+     *         {@link AttestationVerificationManager#RESULT_SUCCESS}.
+     * @throws SecureChannelException if the verification service is unavailable, the wait is
+     *         interrupted (the thread's interrupt flag is restored), or the verification does not
+     *         finish before the timeout.
+     */
+    public int verifyAttestation(
+            @NonNull byte[] remoteAttestation,
+            @NonNull byte[] attestationChallenge
+    ) throws SecureChannelException {
+        AttestationVerificationManager verificationManager =
+                mContext.getSystemService(AttestationVerificationManager.class);
+        if (verificationManager == null) {
+            throw new SecureChannelException("Attestation verification service is unavailable.");
+        }
+
+        Bundle requirements = new Bundle();
+        requirements.putByteArray(PARAM_CHALLENGE, attestationChallenge);
+        requirements.putBoolean(PARAM_OWNED_BY_SYSTEM, true); // Custom parameter for CDM
+
+        // Synchronously execute attestation verification.
+        AtomicInteger verificationResult = new AtomicInteger(0);
+        CountDownLatch verificationFinished = new CountDownLatch(1);
+        BiConsumer<Integer, VerificationToken> onVerificationResult = (result, token) -> {
+            verificationResult.set(result);
+            verificationFinished.countDown();
+        };
+
+        verificationManager.verifyAttestation(
+                new AttestationProfile(PROFILE_PEER_DEVICE),
+                /* localBindingType */ TYPE_CHALLENGE,
+                requirements,
+                remoteAttestation,
+                Runnable::run,
+                onVerificationResult
+        );
+
+        boolean finished;
+        try {
+            finished = verificationFinished.await(
+                    ATTESTATION_VERIFICATION_TIMEOUT_SECONDS,
+                    TimeUnit.SECONDS
+            );
+        } catch (InterruptedException e) {
+            // Preserve the interrupt so that the owning thread can observe it.
+            Thread.currentThread().interrupt();
+            throw new SecureChannelException("Attestation verification was interrupted", e);
+        }
+
+        if (!finished) {
+            throw new SecureChannelException("Attestation verification timed out.");
+        }
+
+        return verificationResult.get();
+    }
+}
diff --git a/services/companion/java/com/android/server/companion/securechannel/KeyStoreUtils.java b/services/companion/java/com/android/server/companion/securechannel/KeyStoreUtils.java
new file mode 100644
index 0000000..18ebec4
--- /dev/null
+++ b/services/companion/java/com/android/server/companion/securechannel/KeyStoreUtils.java
@@ -0,0 +1,130 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.companion.securechannel;
+
+import static android.security.keystore.KeyProperties.DIGEST_SHA256;
+import static android.security.keystore.KeyProperties.KEY_ALGORITHM_EC;
+import static android.security.keystore.KeyProperties.PURPOSE_SIGN;
+import static android.security.keystore.KeyProperties.PURPOSE_VERIFY;
+
+import android.security.keystore.KeyGenParameterSpec;
+import android.security.keystore2.AndroidKeyStoreSpi;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.security.GeneralSecurityException;
+import java.security.KeyPairGenerator;
+import java.security.KeyStore;
+import java.security.KeyStoreException;
+import java.security.cert.Certificate;
+
+/**
+ * Utility class to help generate, store, and access key-pair for the secure channel. Uses
+ * Android Keystore.
+ */
+final class KeyStoreUtils {
+    private static final String TAG = "CDM_SecureChannelKeyStore";
+    private static final String ANDROID_KEYSTORE = AndroidKeyStoreSpi.NAME;
+
+    private KeyStoreUtils() {}
+
+    /**
+     * Load Android keystore to be used by the secure channel.
+     *
+     * @return loaded keystore instance
+     * @throws GeneralSecurityException if the keystore cannot be obtained or loaded
+     */
+    static KeyStore loadKeyStore() throws GeneralSecurityException {
+        KeyStore androidKeyStore = KeyStore.getInstance(ANDROID_KEYSTORE);
+
+        try {
+            androidKeyStore.load(null);
+        } catch (IOException e) {
+            // Should not happen
+            throw new KeyStoreException("Failed to load Android Keystore.", e);
+        }
+
+        return androidKeyStore;
+    }
+
+    /**
+     * Fetch the certificate chain encoded as byte array in the form of concatenated
+     * X509 certificates.
+     *
+     * @param alias unique alias for the key-pair entry
+     * @return a single byte-array containing the entire certificate chain
+     * @throws GeneralSecurityException if the keystore cannot be loaded, no certificate chain is
+     *         stored under {@code alias}, or a certificate cannot be encoded
+     */
+    static byte[] getEncodedCertificateChain(String alias) throws GeneralSecurityException {
+        KeyStore ks = loadKeyStore();
+
+        Certificate[] certificateChain = ks.getCertificateChain(alias);
+        if (certificateChain == null) {
+            // KeyStore reports a missing alias/chain by returning null rather than throwing.
+            throw new KeyStoreException("No certificate chain found for alias: " + alias);
+        }
+
+        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
+        for (Certificate certificate : certificateChain) {
+            buffer.writeBytes(certificate.getEncoded());
+        }
+        return buffer.toByteArray();
+    }
+
+    /**
+     * Generate a new attestation key-pair.
+     *
+     * @param alias unique alias for the key-pair entry
+     * @param attestationChallenge challenge value to check against for authentication
+     * @throws GeneralSecurityException if the key-pair cannot be generated
+     */
+    static void generateAttestationKeyPair(String alias, byte[] attestationChallenge)
+            throws GeneralSecurityException {
+        KeyGenParameterSpec parameterSpec =
+                new KeyGenParameterSpec.Builder(alias, PURPOSE_SIGN | PURPOSE_VERIFY)
+                        .setAttestationChallenge(attestationChallenge)
+                        .setDigests(DIGEST_SHA256)
+                        .build();
+
+        KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(
+                /* algorithm */ KEY_ALGORITHM_EC,
+                /* provider */ ANDROID_KEYSTORE);
+        keyPairGenerator.initialize(parameterSpec);
+        keyPairGenerator.generateKeyPair();
+    }
+
+    /**
+     * Check if alias exists.
+     *
+     * @param alias unique alias for the key-pair entry
+     * @return true if given alias already exists in the keystore; false if it does not or if the
+     *         keystore could not be loaded
+     */
+    static boolean aliasExists(String alias) {
+        try {
+            KeyStore ks = loadKeyStore();
+            return ks.containsAlias(alias);
+        } catch (GeneralSecurityException e) {
+            return false;
+        }
+    }
+
+    /**
+     * Delete the key-pair entry stored under the given alias, if present. Clean-up is
+     * best-effort: any failure is ignored.
+     *
+     * @param alias unique alias for the key-pair entry; {@code null} is a no-op
+     */
+    static void cleanUp(String alias) {
+        if (alias == null) {
+            return;
+        }
+
+        try {
+            KeyStore ks = loadKeyStore();
+
+            if (ks.containsAlias(alias)) {
+                ks.deleteEntry(alias);
+            }
+        } catch (Exception ignored) {
+            // Best-effort clean-up; a leftover entry is not fatal.
+        }
+    }
+}
diff --git a/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java b/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java
new file mode 100644
index 0000000..13dba84
--- /dev/null
+++ b/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java
@@ -0,0 +1,543 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.companion.securechannel;
+
+import static android.security.attestationverification.AttestationVerificationManager.RESULT_SUCCESS;
+
+import android.annotation.NonNull;
+import android.content.Context;
+import android.os.Build;
+import android.util.Slog;
+
+import com.google.security.cryptauth.lib.securegcm.BadHandleException;
+import com.google.security.cryptauth.lib.securegcm.CryptoException;
+import com.google.security.cryptauth.lib.securegcm.D2DConnectionContextV1;
+import com.google.security.cryptauth.lib.securegcm.D2DHandshakeContext;
+import com.google.security.cryptauth.lib.securegcm.D2DHandshakeContext.Role;
+import com.google.security.cryptauth.lib.securegcm.DefaultUkey2Logger;
+import com.google.security.cryptauth.lib.securegcm.HandshakeException;
+
+import libcore.io.IoUtils;
+import libcore.io.Streams;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.security.GeneralSecurityException;
+import java.security.MessageDigest;
+import java.util.Arrays;
+import java.util.UUID;
+
+/**
+ * Data stream channel that establishes secure connection between two peer devices.
+ *
+ * <p>Every message is framed as a plaintext header (4-byte version followed by a 2-byte
+ * {@link MessageType}), a 4-byte payload length, and the payload. Every payload except the Ukey2
+ * handshake messages is encrypted with the negotiated connection context, and the header is
+ * covered by the message MAC.
+ */
+public class SecureChannel {
+    private static final String TAG = "CDM_SecureChannel";
+    private static final boolean DEBUG = Build.IS_DEBUGGABLE;
+
+    private static final int VERSION = 1;
+    // 4-byte version + 2-byte message type.
+    private static final int HEADER_LENGTH = 6;
+
+    private static final String HANDSHAKE_PROTOCOL = "AES_256_CBC-HMAC_SHA256";
+
+    private final InputStream mInput;
+    private final OutputStream mOutput;
+    private final Callback mCallback;
+    private final byte[] mPreSharedKey;
+    private final AttestationVerifier mVerifier;
+
+    private volatile boolean mStopped;
+    // Written both by the caller of establishSecureConnection() and by the listener thread.
+    private volatile boolean mInProgress;
+
+    private Role mRole;
+    private D2DHandshakeContext mHandshakeContext;
+    private D2DConnectionContextV1 mConnectionContext;
+
+    // Keystore alias of the local attestation key; null until attestation is attempted.
+    private String mAlias;
+    private int mVerificationResult;
+
+    /**
+     * Create a new secure channel object. This secure channel allows secure messages to be
+     * exchanged with unattested devices. The pre-shared key must have been distributed to both
+     * participants of the channel in a secure way previously.
+     *
+     * @param in input stream from which data is received
+     * @param out output stream to which data is sent
+     * @param callback subscription to received messages from the channel
+     * @param preSharedKey pre-shared key to authenticate unattested participant
+     */
+    public SecureChannel(
+            @NonNull final InputStream in,
+            @NonNull final OutputStream out,
+            @NonNull Callback callback,
+            @NonNull byte[] preSharedKey
+    ) {
+        this(in, out, callback, preSharedKey, null);
+    }
+
+    /**
+     * Create a new secure channel object. This secure channel allows secure messages to be
+     * exchanged with Android devices that were authenticated and verified with an attestation key.
+     *
+     * @param in input stream from which data is received
+     * @param out output stream to which data is sent
+     * @param callback subscription to received messages from the channel
+     * @param context context for fetching the Attestation Verifier Framework system service
+     */
+    public SecureChannel(
+            @NonNull final InputStream in,
+            @NonNull final OutputStream out,
+            @NonNull Callback callback,
+            @NonNull Context context
+    ) {
+        this(in, out, callback, null, new AttestationVerifier(context));
+    }
+
+    // A null verifier selects pre-shared key authentication; otherwise attestation is used.
+    private SecureChannel(
+            final InputStream in,
+            final OutputStream out,
+            Callback callback,
+            byte[] preSharedKey,
+            AttestationVerifier verifier
+    ) {
+        this.mInput = in;
+        this.mOutput = out;
+        this.mCallback = callback;
+        this.mPreSharedKey = preSharedKey;
+        this.mVerifier = verifier;
+    }
+
+    /**
+     * Start listening for incoming messages on a new background thread.
+     *
+     * <p>The thread first completes the handshake and authentication, then forwards every
+     * received secure message to {@link Callback#onSecureMessageReceived(byte[])} until the
+     * channel is stopped. Unless the channel was already stopped, any failure stops the channel
+     * and is reported through {@link Callback#onError(Throwable)}.
+     */
+    public void start() {
+        new Thread(() -> {
+            try {
+                // 1. Wait for the next handshake message and process it.
+                exchangeHandshake();
+
+                // 2. Authenticate remote actor via attestation or pre-shared key.
+                exchangeAuthentication();
+
+                // 3. Notify secure channel is ready.
+                mInProgress = false;
+                mCallback.onSecureConnection();
+
+                // Listen for secure messages.
+                while (!mStopped) {
+                    receiveSecureMessage();
+                }
+            } catch (Exception e) {
+                if (mStopped) {
+                    return;
+                }
+                // TODO: Handle different types errors.
+
+                Slog.e(TAG, "Secure channel encountered an error.", e);
+                stop();
+                mCallback.onError(e);
+            }
+        }).start();
+    }
+
+    /**
+     * Stop listening to incoming messages and close the channel. Also deletes the local
+     * attestation key, if one was generated.
+     */
+    public void stop() {
+        if (DEBUG) {
+            Slog.d(TAG, "Stopping secure channel.");
+        }
+        mStopped = true;
+        mInProgress = false;
+
+        IoUtils.closeQuietly(mInput);
+        IoUtils.closeQuietly(mOutput);
+        // Only the attestation path allocates a keystore alias.
+        if (mAlias != null) {
+            KeyStoreUtils.cleanUp(mAlias);
+        }
+    }
+
+    /**
+     * Start exchanging handshakes to create a secure layer asynchronously. When the handshake is
+     * completed successfully, then the {@link Callback#onSecureConnection()} will trigger. Any
+     * error that occurs during the handshake will be passed by {@link Callback#onError(Throwable)}.
+     *
+     * This method must only be called from one of the two participants.
+     *
+     * @throws IOException if the initial handshake message cannot be written.
+     * @throws SecureChannelException if the handshake protocol cannot be initiated.
+     */
+    public void establishSecureConnection() throws IOException, SecureChannelException {
+        if (isSecured()) {
+            Slog.d(TAG, "Channel is already secure.");
+            return;
+        }
+        if (mInProgress) {
+            Slog.w(TAG, "Channel has already started establishing secure connection.");
+            return;
+        }
+
+        try {
+            initiateHandshake();
+            mInProgress = true;
+        } catch (BadHandleException e) {
+            throw new SecureChannelException("Failed to initiate handshake protocol.", e);
+        }
+    }
+
+    /**
+     * Send an encrypted, authenticated message via this channel.
+     *
+     * @param data data to be sent to the other side.
+     * @throws IOException if the output stream fails to write given data.
+     * @throws IllegalStateException if the channel is not secured yet.
+     * @throws SecureChannelException if the data cannot be encrypted.
+     */
+    public void sendSecureMessage(byte[] data) throws IOException {
+        if (!isSecured()) {
+            Slog.d(TAG, "Cannot send a message without a secure connection.");
+            throw new IllegalStateException("Channel is not secured yet.");
+        }
+
+        // Encrypt constructed message
+        try {
+            sendMessage(MessageType.SECURE_MESSAGE, data);
+        } catch (BadHandleException e) {
+            throw new SecureChannelException("Failed to encrypt data.", e);
+        }
+    }
+
+    // Reads one secure message and forwards it to the callback. Framing or decryption problems
+    // reported as SecureChannelException are logged and the message is dropped.
+    private void receiveSecureMessage() throws IOException, CryptoException {
+        // Check if channel is secured. Trigger error callback. Let user handle it.
+        if (!isSecured()) {
+            Slog.d(TAG, "Received a message without a secure connection. "
+                    + "Message will be ignored.");
+            mCallback.onError(new IllegalStateException("Connection is not secure."));
+            return;
+        }
+
+        try {
+            byte[] receivedMessage = readMessage(MessageType.SECURE_MESSAGE);
+            mCallback.onSecureMessageReceived(receivedMessage);
+        } catch (SecureChannelException e) {
+            Slog.w(TAG, "Ignoring received message.", e);
+        }
+    }
+
+    // Blocks until one complete message of the expected type is read and returns its payload,
+    // decrypted unless it is a handshake message. On any framing error the rest of the stream is
+    // drained before throwing, because the framing can no longer be trusted.
+    private byte[] readMessage(MessageType expected)
+            throws IOException, SecureChannelException, CryptoException {
+        if (DEBUG) {
+            if (isSecured()) {
+                Slog.d(TAG, "Waiting to receive next secure message.");
+            } else {
+                Slog.d(TAG, "Waiting to receive next message.");
+            }
+        }
+
+        // TODO: Handle message timeout
+
+        // Header is _not_ encrypted, but will be covered by MAC
+        final byte[] headerBytes = new byte[HEADER_LENGTH];
+        Streams.readFully(mInput, headerBytes);
+        final ByteBuffer header = ByteBuffer.wrap(headerBytes);
+        final int version = header.getInt();
+        final short type = header.getShort();
+
+        if (version != VERSION) {
+            Streams.skipByReading(mInput, Long.MAX_VALUE);
+            throw new SecureChannelException("Secure channel version mismatch. "
+                    + "Currently on version " + VERSION + ". Skipping rest of data.");
+        }
+
+        if (type != expected.mValue) {
+            Streams.skipByReading(mInput, Long.MAX_VALUE);
+            throw new SecureChannelException("Unexpected message type. Expected " + expected.name()
+                    + "; Found " + MessageType.from(type).name() + ". Skipping rest of data.");
+        }
+
+        // Length of attached data is prepended as plaintext
+        final byte[] lengthBytes = new byte[4];
+        Streams.readFully(mInput, lengthBytes);
+        final int length = ByteBuffer.wrap(lengthBytes).getInt();
+
+        // The length comes from the peer; reject values that cannot be a valid array size.
+        if (length < 0) {
+            Streams.skipByReading(mInput, Long.MAX_VALUE);
+            throw new SecureChannelException("Invalid payload length " + length
+                    + ". Skipping rest of data.");
+        }
+
+        // Read data based on the length
+        final byte[] data;
+        try {
+            data = new byte[length];
+        } catch (OutOfMemoryError error) {
+            throw new SecureChannelException("Payload is too large.", error);
+        }
+
+        Streams.readFully(mInput, data);
+        if (!MessageType.shouldEncrypt(expected)) {
+            return data;
+        }
+
+        return mConnectionContext.decodeMessageFromPeer(data, headerBytes);
+    }
+
+    // Writes one framed message; the payload is encrypted unless it is a handshake message.
+    // Synchronized on the output stream so frames from different threads never interleave.
+    private void sendMessage(MessageType messageType, byte[] payload)
+            throws IOException, BadHandleException {
+        synchronized (mOutput) {
+            byte[] header = ByteBuffer.allocate(HEADER_LENGTH)
+                    .putInt(VERSION)
+                    .putShort(messageType.mValue)
+                    .array();
+            byte[] data = MessageType.shouldEncrypt(messageType)
+                    ? mConnectionContext.encodeMessageToPeer(payload, header)
+                    : payload;
+            mOutput.write(header);
+            mOutput.write(ByteBuffer.allocate(4)
+                    .putInt(data.length)
+                    .array());
+            mOutput.write(data);
+            mOutput.flush();
+        }
+    }
+
+    // Takes the initiator role and sends the Ukey2 Client Init message.
+    private void initiateHandshake() throws IOException, BadHandleException {
+        if (mConnectionContext != null) {
+            Slog.d(TAG, "Ukey2 handshake is already completed.");
+            return;
+        }
+
+        mRole = Role.Initiator;
+        mHandshakeContext = D2DHandshakeContext.forInitiator(DefaultUkey2Logger.INSTANCE);
+
+        // Send Client Init
+        if (DEBUG) {
+            Slog.d(TAG, "Sending Ukey2 Client Init message");
+        }
+        sendMessage(MessageType.HANDSHAKE_INIT, mHandshakeContext.getNextHandshakeMessage());
+    }
+
+    // Completes the Ukey2 handshake. If no handshake context exists yet, the first
+    // HANDSHAKE_INIT received is treated as Client Init and this side becomes the responder.
+    private void exchangeHandshake()
+            throws IOException, HandshakeException, BadHandleException, CryptoException {
+        if (mConnectionContext != null) {
+            Slog.d(TAG, "Ukey2 handshake is already completed.");
+            return;
+        }
+
+        // Waiting for message
+        byte[] handshakeMessage = readMessage(MessageType.HANDSHAKE_INIT);
+
+        if (mHandshakeContext == null) { // Server-side logic
+            mRole = Role.Responder;
+            mHandshakeContext = D2DHandshakeContext.forResponder(DefaultUkey2Logger.INSTANCE);
+
+            // Receive Client Init
+            if (DEBUG) {
+                Slog.d(TAG, "Receiving Ukey2 Client Init message");
+            }
+            mHandshakeContext.parseHandshakeMessage(handshakeMessage);
+
+            // Send Server Init
+            if (DEBUG) {
+                Slog.d(TAG, "Sending Ukey2 Server Init message");
+            }
+            sendMessage(MessageType.HANDSHAKE_INIT, mHandshakeContext.getNextHandshakeMessage());
+
+            // Receive Client Finish
+            if (DEBUG) {
+                Slog.d(TAG, "Receiving Ukey2 Client Finish message");
+            }
+            mHandshakeContext.parseHandshakeMessage(readMessage(MessageType.HANDSHAKE_FINISH));
+        } else { // Client-side logic
+
+            // Receive Server Init
+            if (DEBUG) {
+                Slog.d(TAG, "Receiving Ukey2 Server Init message");
+            }
+            mHandshakeContext.parseHandshakeMessage(handshakeMessage);
+
+            // Send Client Finish
+            if (DEBUG) {
+                Slog.d(TAG, "Sending Ukey2 Client Finish message");
+            }
+            sendMessage(MessageType.HANDSHAKE_FINISH, mHandshakeContext.getNextHandshakeMessage());
+        }
+
+        // Convert secrets to connection context
+        if (mHandshakeContext.isHandshakeComplete()) {
+            if (DEBUG) {
+                Slog.d(TAG, "Ukey2 Handshake completed successfully");
+            }
+            mConnectionContext = mHandshakeContext.toConnectionContext();
+        } else {
+            Slog.e(TAG, "Failed to complete Ukey2 Handshake");
+            throw new IllegalStateException("Ukey2 Handshake did not complete as expected.");
+        }
+    }
+
+    // Authenticates the peer with the pre-shared key when no verifier was supplied, otherwise
+    // with device attestation.
+    private void exchangeAuthentication()
+            throws IOException, GeneralSecurityException, BadHandleException, CryptoException {
+        if (mVerifier == null) {
+            exchangePreSharedKey();
+        } else {
+            exchangeAttestation();
+        }
+    }
+
+    // Each side sends SHA-256(own role || PSK) and checks the peer's token against
+    // SHA-256(peer role || PSK).
+    private void exchangePreSharedKey()
+            throws IOException, GeneralSecurityException, BadHandleException, CryptoException {
+
+        // Exchange hashed pre-shared keys
+        if (DEBUG) {
+            Slog.d(TAG, "Exchanging pre-shared keys.");
+        }
+        sendMessage(MessageType.PRE_SHARED_KEY, constructToken(mRole, mPreSharedKey));
+        byte[] receivedAuthToken = readMessage(MessageType.PRE_SHARED_KEY);
+        byte[] expectedAuthToken = constructToken(mRole == Role.Initiator
+                ? Role.Responder
+                : Role.Initiator,
+                mPreSharedKey);
+        // Constant-time comparison so that timing does not leak how much of the token matched.
+        boolean authenticated = MessageDigest.isEqual(receivedAuthToken, expectedAuthToken);
+
+        if (!authenticated) {
+            throw new SecureChannelException("Failed to verify the hash of pre-shared key.");
+        }
+
+        if (DEBUG) {
+            Slog.d(TAG, "The pre-shared key was successfully authenticated.");
+        }
+    }
+
+    // Exchanges attestation certificate chains bound to this session, verifies the peer's chain
+    // and exchanges the verification results. Succeeds only if both sides verified each other.
+    private void exchangeAttestation()
+            throws IOException, GeneralSecurityException, BadHandleException, CryptoException {
+        if (mVerificationResult == RESULT_SUCCESS) {
+            Slog.d(TAG, "Remote attestation was already verified.");
+            return;
+        }
+
+        // Send local attestation
+        if (DEBUG) {
+            Slog.d(TAG, "Exchanging device attestation.");
+        }
+        if (mAlias == null) {
+            mAlias = generateAlias();
+        }
+        byte[] localChallenge = constructToken(mRole, mConnectionContext.getSessionUnique());
+        KeyStoreUtils.generateAttestationKeyPair(mAlias, localChallenge);
+        byte[] localAttestation = KeyStoreUtils.getEncodedCertificateChain(mAlias);
+        sendMessage(MessageType.ATTESTATION, localAttestation);
+        byte[] remoteAttestation = readMessage(MessageType.ATTESTATION);
+
+        // Verifying remote attestation with public key local binding param
+        byte[] expectedChallenge = constructToken(mRole == Role.Initiator
+                ? Role.Responder
+                : Role.Initiator,
+                mConnectionContext.getSessionUnique());
+        mVerificationResult = mVerifier.verifyAttestation(remoteAttestation, expectedChallenge);
+
+        // Exchange attestation verification result and finish
+        byte[] verificationResult = ByteBuffer.allocate(4)
+                .putInt(mVerificationResult)
+                .array();
+        sendMessage(MessageType.AVF_RESULT, verificationResult);
+        byte[] remoteVerificationResult = readMessage(MessageType.AVF_RESULT);
+
+        // The peer must send exactly one int; anything else would underflow the buffer below.
+        if (remoteVerificationResult.length != Integer.BYTES) {
+            throw new SecureChannelException("Received malformed attestation verification result.");
+        }
+
+        if (ByteBuffer.wrap(remoteVerificationResult).getInt() != RESULT_SUCCESS) {
+            throw new SecureChannelException("Remote device failed to verify local attestation.");
+        }
+
+        if (mVerificationResult != RESULT_SUCCESS) {
+            throw new SecureChannelException("Failed to verify remote attestation.");
+        }
+
+        if (DEBUG) {
+            Slog.d(TAG, "Remote attestation was successfully verified.");
+        }
+    }
+
+    // Secured once the handshake produced a connection context and, on the attestation path,
+    // the remote attestation was verified.
+    private boolean isSecured() {
+        if (mConnectionContext == null) {
+            return false;
+        }
+        return mVerifier == null || mVerificationResult == RESULT_SUCCESS;
+    }
+
+    // Returns SHA-256(UTF-8 role name || authValue).
+    private byte[] constructToken(D2DHandshakeContext.Role role, byte[] authValue)
+            throws GeneralSecurityException {
+        MessageDigest hash = MessageDigest.getInstance("SHA-256");
+        byte[] roleUtf8 = role.name().getBytes(StandardCharsets.UTF_8);
+        int tokenLength = roleUtf8.length + authValue.length;
+        return hash.digest(ByteBuffer.allocate(tokenLength)
+                .put(roleUtf8)
+                .put(authValue)
+                .array());
+    }
+
+    // Picks a random keystore alias that is not already in use.
+    private String generateAlias() {
+        String alias;
+        do {
+            alias = "secure-channel-" + UUID.randomUUID();
+        } while (KeyStoreUtils.aliasExists(alias));
+        return alias;
+    }
+
+    private enum MessageType {
+        HANDSHAKE_INIT(0x4849),   // HI
+        HANDSHAKE_FINISH(0x4846), // HF
+        PRE_SHARED_KEY(0x504b),   // PK
+        ATTESTATION(0x4154),      // AT
+        AVF_RESULT(0x5652),       // VR
+        SECURE_MESSAGE(0x534d),   // SM
+        UNKNOWN(0);               // X
+
+        private final short mValue;
+
+        MessageType(int value) {
+            this.mValue = (short) value;
+        }
+
+        static MessageType from(short value) {
+            for (MessageType messageType : values()) {
+                if (value == messageType.mValue) {
+                    return messageType;
+                }
+            }
+            return UNKNOWN;
+        }
+
+        // Encrypt every message besides Ukey2 handshake messages
+        private static boolean shouldEncrypt(MessageType type) {
+            return type != HANDSHAKE_INIT && type != HANDSHAKE_FINISH;
+        }
+    }
+
+    /**
+     * Callback that passes securely received message to the subscribed user.
+     */
+    public interface Callback {
+        /**
+         * Triggered after {@link SecureChannel#establishSecureConnection()} finishes exchanging
+         * every required handshakes to fully establish a secure connection.
+         */
+        void onSecureConnection();
+
+        /**
+         * Callback that passes securely received and decrypted data to the subscribed user.
+         *
+         * @param data securely received plaintext data.
+         */
+        void onSecureMessageReceived(byte[] data);
+
+        /**
+         * Callback that passes error that occurred during handshakes or while listening to
+         * messages in the secure channel.
+         *
+         * @param error the failure that occurred; when it is raised on the listener thread, the
+         *              channel has already been stopped before this is invoked.
+         */
+        void onError(Throwable error);
+    }
+}
diff --git a/services/companion/java/com/android/server/companion/securechannel/SecureChannelException.java b/services/companion/java/com/android/server/companion/securechannel/SecureChannelException.java
new file mode 100644
index 0000000..68db97e
--- /dev/null
+++ b/services/companion/java/com/android/server/companion/securechannel/SecureChannelException.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.companion.securechannel;
+
+/**
+ * Unchecked catch-all exception for any error in the secure channel.
+ */
+public class SecureChannelException extends RuntimeException {
+    /**
+     * Creates an exception with the given detail {@code message}.
+     */
+    public SecureChannelException(String message) {
+        super(message);
+    }
+
+    /** Creates an exception with a detail {@code message} and underlying cause {@code t}. */
+    public SecureChannelException(String message, Throwable t) {
+        super(message, t);
+    }
+}
diff --git a/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java b/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java
index 6db99a0..494c5a6 100644
--- a/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java
+++ b/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java
@@ -34,6 +34,7 @@
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.server.LocalServices;
+import com.android.server.companion.securechannel.SecureChannel;
 
 import libcore.io.IoUtils;
 import libcore.io.Streams;
@@ -43,6 +44,8 @@
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.nio.ByteBuffer;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -54,8 +57,6 @@
     private static final boolean DEBUG = true;
 
     private static final int HEADER_LENGTH = 12;
-    // TODO: refactor message processing to use streams to remove this limit
-    private static final int MAX_PAYLOAD_LENGTH = 1_000_000;
 
     private static final int MESSAGE_REQUEST_PING = 0x63807378; // ?PIN
     private static final int MESSAGE_REQUEST_PERMISSION_RESTORE = 0x63826983; // ?RES
@@ -63,6 +64,8 @@
     private static final int MESSAGE_RESPONSE_SUCCESS = 0x33838567; // !SUC
     private static final int MESSAGE_RESPONSE_FAILURE = 0x33706573; // !FAI
 
+    private volatile boolean mSecureTransportEnabled = true; // set via enableSecureTransport()
+
     private static boolean isRequest(int message) {
         return (message & 0xFF000000) == 0x63000000;
     }
@@ -122,7 +125,13 @@
                 detachSystemDataTransport(packageName, userId, associationId);
             }
 
-            final Transport transport = new Transport(associationId, fd);
+            final Transport transport;
+            if (isSecureTransportEnabled(associationId)) {
+                transport = new SecureTransport(associationId, fd);
+            } else {
+                transport = new RawTransport(associationId, fd);
+            }
+
             transport.start();
             mTransports.put(associationId, transport);
         }
@@ -142,61 +151,65 @@
     public Future<?> requestPermissionRestore(int associationId, byte[] data) {
         synchronized (mTransports) {
             final Transport transport = mTransports.get(associationId);
-            if (transport != null) {
-                return transport.requestForResponse(MESSAGE_REQUEST_PERMISSION_RESTORE, data);
-            } else {
+            if (transport == null) {
                 return CompletableFuture.failedFuture(new IOException("Missing transport"));
             }
+            // Delegate to the secure or raw transport attached for this association.
+            return transport.requestForResponse(MESSAGE_REQUEST_PERMISSION_RESTORE, data);
         }
     }
 
-    private class Transport {
-        private final int mAssociationId;
+    /**
+     * Sets whether new transports are secured; false only applies to debuggable builds. @hide
+     */
+    public void enableSecureTransport(boolean enabled) {
+        this.mSecureTransportEnabled = enabled;
+    }
 
-        private final InputStream mRemoteIn;
-        private final OutputStream mRemoteOut;
+    private boolean isSecureTransportEnabled(int associationId) {
+        boolean enabled = !Build.IS_DEBUGGABLE || mSecureTransportEnabled;
 
-        private final AtomicInteger mNextSequence = new AtomicInteger();
+        // TODO: version comparison logic (associationId is currently unused).
+        return enabled; // false only on debuggable builds after enableSecureTransport(false)
+    }
+
+    // TODO: Make Transport inner classes into standalone classes.
+    private abstract class Transport {
+        protected final int mAssociationId;
+        protected final InputStream mRemoteIn;
+        protected final OutputStream mRemoteOut;
 
         @GuardedBy("mPendingRequests")
-        private final SparseArray<CompletableFuture<byte[]>> mPendingRequests = new SparseArray<>();
+        protected final SparseArray<CompletableFuture<byte[]>> mPendingRequests =
+                new SparseArray<>();
+        protected final AtomicInteger mNextSequence = new AtomicInteger();
 
-        private volatile boolean mStopped;
-
-        public Transport(int associationId, ParcelFileDescriptor fd) {
-            mAssociationId = associationId;
-            mRemoteIn = new ParcelFileDescriptor.AutoCloseInputStream(fd);
-            mRemoteOut = new ParcelFileDescriptor.AutoCloseOutputStream(fd);
+        Transport(int associationId, ParcelFileDescriptor fd) {
+            this(associationId,
+                    new ParcelFileDescriptor.AutoCloseInputStream(fd),
+                    new ParcelFileDescriptor.AutoCloseOutputStream(fd));
         }
 
-        public void start() {
-            new Thread(() -> {
-                try {
-                    while (!mStopped) {
-                        receiveMessage();
-                    }
-                } catch (IOException e) {
-                    if (!mStopped) {
-                        Slog.w(TAG, "Trouble during transport", e);
-                        stop();
-                    }
-                }
-            }).start();
+        Transport(int associationId, InputStream in, OutputStream out) {
+            this.mAssociationId = associationId;
+            this.mRemoteIn = in;
+            this.mRemoteOut = out;
         }
 
-        public void stop() {
-            mStopped = true;
+        public abstract void start();
+        public abstract void stop();
 
-            IoUtils.closeQuietly(mRemoteIn);
-            IoUtils.closeQuietly(mRemoteOut);
-        }
+        protected abstract void sendMessage(int message, int sequence, @NonNull byte[] data)
+                throws IOException;
 
         public Future<byte[]> requestForResponse(int message, byte[] data) {
+            if (DEBUG) Slog.d(TAG, "Requesting for response");
             final int sequence = mNextSequence.incrementAndGet();
             final CompletableFuture<byte[]> pending = new CompletableFuture<>();
             synchronized (mPendingRequests) {
                 mPendingRequests.put(sequence, pending);
             }
+
             try {
                 sendMessage(message, sequence, data);
             } catch (IOException e) {
@@ -205,58 +218,24 @@
                 }
                 pending.completeExceptionally(e);
             }
+
             return pending;
         }
 
-        private void sendMessage(int message, int sequence, @NonNull byte[] data)
+        protected final void handleMessage(int message, int sequence, @NonNull byte[] data)
                 throws IOException {
             if (DEBUG) {
-                Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message)
-                        + " sequence " + sequence + " length " + data.length
-                        + " to association " + mAssociationId);
-            }
-
-            synchronized (mRemoteOut) {
-                final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH)
-                        .putInt(message)
-                        .putInt(sequence)
-                        .putInt(data.length);
-                mRemoteOut.write(header.array());
-                mRemoteOut.write(data);
-                mRemoteOut.flush();
-            }
-        }
-
-        private void receiveMessage() throws IOException {
-            if (DEBUG) {
-                Slog.d(TAG, "Waiting for next message...");
-            }
-
-            final byte[] headerBytes = new byte[HEADER_LENGTH];
-            Streams.readFully(mRemoteIn, headerBytes);
-            final ByteBuffer header = ByteBuffer.wrap(headerBytes);
-            final int message = header.getInt();
-            final int sequence = header.getInt();
-            final int length = header.getInt();
-
-            if (DEBUG) {
                 Slog.d(TAG, "Received message 0x" + Integer.toHexString(message)
-                        + " sequence " + sequence + " length " + length
+                        + " sequence " + sequence + " length " + data.length
                         + " from association " + mAssociationId);
             }
-            if (length > MAX_PAYLOAD_LENGTH) {
-                Slog.w(TAG, "Ignoring message 0x" + Integer.toHexString(message)
-                        + " sequence " + sequence + " length " + length
-                        + " from association " + mAssociationId + " beyond maximum length");
-                Streams.skipByReading(mRemoteIn, length);
-                return;
-            }
-
-            final byte[] data = new byte[length];
-            Streams.readFully(mRemoteIn, data);
 
             if (isRequest(message)) {
-                processRequest(message, sequence, data);
+                try {
+                    processRequest(message, sequence, data);
+                } catch (IOException e) {
+                    Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e);
+                }
             } else if (isResponse(message)) {
                 processResponse(message, sequence, data);
             } else {
@@ -320,4 +299,243 @@
             }
         }
     }
+
+    /** Sends and receives unencrypted framed messages directly over the remote streams. */
+    private class RawTransport extends Transport {
+        // Payloads larger than this are skipped instead of being buffered in memory.
+        // TODO: refactor message processing to use streams to remove this limit
+        private static final int MAX_PAYLOAD_LENGTH = 1_000_000;
+
+        private volatile boolean mStopped;
+
+        RawTransport(int associationId, ParcelFileDescriptor fd) {
+            super(associationId, fd);
+        }
+
+        @Override
+        public void start() {
+            new Thread(() -> {
+                try {
+                    while (!mStopped) {
+                        receiveMessage();
+                    }
+                } catch (IOException e) {
+                    if (!mStopped) {
+                        Slog.w(TAG, "Trouble during transport", e);
+                        stop();
+                    }
+                }
+            }).start();
+        }
+
+        @Override
+        public void stop() {
+            mStopped = true;
+
+            IoUtils.closeQuietly(mRemoteIn);
+            IoUtils.closeQuietly(mRemoteOut);
+        }
+
+        @Override
+        protected void sendMessage(int message, int sequence, @NonNull byte[] data)
+                throws IOException {
+            if (DEBUG) {
+                Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message)
+                        + " sequence " + sequence + " length " + data.length
+                        + " to association " + mAssociationId);
+            }
+
+            synchronized (mRemoteOut) {
+                final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH)
+                        .putInt(message)
+                        .putInt(sequence)
+                        .putInt(data.length);
+                mRemoteOut.write(header.array());
+                mRemoteOut.write(data);
+                mRemoteOut.flush();
+            }
+        }
+
+        /**
+         * Reads one message (a HEADER_LENGTH-byte header of message, sequence and length,
+         * followed by the payload) and dispatches it via handleMessage().
+         *
+         * @throws IOException if reading fails or the header declares a negative length.
+         */
+        private void receiveMessage() throws IOException {
+            final byte[] headerBytes = new byte[HEADER_LENGTH];
+            Streams.readFully(mRemoteIn, headerBytes);
+            final ByteBuffer header = ByteBuffer.wrap(headerBytes);
+            final int message = header.getInt();
+            final int sequence = header.getInt();
+            final int length = header.getInt();
+
+            // The length comes from the remote peer: reject negative values (they cannot be
+            // skipped, so the stream cannot be resynchronized) and skip oversized payloads
+            // rather than allocating an arbitrarily large buffer.
+            if (length < 0) {
+                throw new IOException("Invalid length " + length + " for message 0x"
+                        + Integer.toHexString(message) + " from association " + mAssociationId);
+            }
+            if (length > MAX_PAYLOAD_LENGTH) {
+                Slog.w(TAG, "Ignoring message 0x" + Integer.toHexString(message)
+                        + " sequence " + sequence + " length " + length
+                        + " from association " + mAssociationId + " beyond maximum length");
+                Streams.skipByReading(mRemoteIn, length);
+                return;
+            }
+
+            final byte[] data = new byte[length];
+            Streams.readFully(mRemoteIn, data);
+
+            handleMessage(message, sequence, data);
+        }
+    }
+
+    /** Transport that sends and receives framed messages through a {@link SecureChannel}. */
+    private class SecureTransport extends Transport implements SecureChannel.Callback {
+        private final SecureChannel mSecureChannel;
+
+        // True once the handshake has completed; cleared by stop() and onError().
+        private volatile boolean mShouldProcessRequests = false;
+
+        // Framed outgoing messages waiting to be sent over the secure channel.
+        private final BlockingQueue<byte[]> mRequestQueue = new ArrayBlockingQueue<>(100);
+
+        // Thread started by onSecureConnection() to drain mRequestQueue.
+        private volatile Thread mSenderThread;
+
+        SecureTransport(int associationId, ParcelFileDescriptor fd) {
+            super(associationId, fd);
+            mSecureChannel = new SecureChannel(mRemoteIn, mRemoteOut, this, mContext);
+        }
+
+        @Override
+        public void start() {
+            mSecureChannel.start();
+        }
+
+        @Override
+        public void stop() {
+            mSecureChannel.stop();
+            stopSender();
+        }
+
+        @Override
+        public Future<byte[]> requestForResponse(int message, byte[] data) {
+            // Check if channel is secured and start securing
+            if (!mShouldProcessRequests) {
+                Slog.d(TAG, "Establishing secure connection.");
+                try {
+                    mSecureChannel.establishSecureConnection();
+                } catch (Exception e) {
+                    Slog.w(TAG, "Failed to initiate secure channel handshake.", e);
+                    onError(e);
+                    // Fail now rather than queue behind a handshake that never started.
+                    return CompletableFuture.failedFuture(e);
+                }
+            }
+
+            return super.requestForResponse(message, data);
+        }
+
+        /**
+         * Frames the message and queues it for the sender thread.
+         *
+         * @throws IOException if the queue is full, so that requestForResponse() completes the
+         *         pending future exceptionally instead of throwing IllegalStateException.
+         */
+        @Override
+        protected void sendMessage(int message, int sequence, @NonNull byte[] data)
+                throws IOException {
+            if (DEBUG) {
+                Slog.d(TAG, "Queueing message 0x" + Integer.toHexString(message)
+                        + " sequence " + sequence + " length " + data.length
+                        + " to association " + mAssociationId);
+            }
+
+            final byte[] frame = ByteBuffer.allocate(HEADER_LENGTH + data.length)
+                    .putInt(message)
+                    .putInt(sequence)
+                    .putInt(data.length)
+                    .put(data)
+                    .array();
+            if (!mRequestQueue.offer(frame)) {
+                throw new IOException("Request queue full; dropping message 0x"
+                        + Integer.toHexString(message) + " for association " + mAssociationId);
+            }
+        }
+
+        @Override
+        public void onSecureConnection() {
+            mShouldProcessRequests = true;
+            Slog.d(TAG, "Secure connection established.");
+
+            // TODO: find a better way to handle incoming requests than a dedicated thread.
+            final Thread sender = new Thread(this::drainRequestQueue);
+            mSenderThread = sender;
+            sender.start();
+        }
+
+        /**
+         * Sends queued messages until stopped. Blocks in take() while the queue is empty
+         * instead of spinning on poll(), which returns immediately and kept a core busy.
+         */
+        private void drainRequestQueue() {
+            try {
+                while (mShouldProcessRequests) {
+                    mSecureChannel.sendSecureMessage(mRequestQueue.take());
+                }
+            } catch (InterruptedException e) {
+                // Interrupted by stopSender(); preserve the interrupt status and exit.
+                Thread.currentThread().interrupt();
+            } catch (IOException e) {
+                onError(e);
+            }
+        }
+
+        /** Clears mShouldProcessRequests and interrupts the sender thread so it exits. */
+        private void stopSender() {
+            mShouldProcessRequests = false;
+            final Thread sender = mSenderThread;
+            if (sender != null) {
+                sender.interrupt();
+            }
+        }
+
+        @Override
+        public void onSecureMessageReceived(byte[] data) {
+            // The header fields come from the peer; validate them so a malformed message is
+            // dropped instead of throwing BufferUnderflowException/NegativeArraySizeException.
+            final ByteBuffer payload = ByteBuffer.wrap(data);
+            if (payload.remaining() < HEADER_LENGTH) {
+                Slog.w(TAG, "Dropping truncated message of length " + data.length
+                        + " from association " + mAssociationId);
+                return;
+            }
+            final int message = payload.getInt();
+            final int sequence = payload.getInt();
+            final int length = payload.getInt();
+            if (length < 0 || length > payload.remaining()) {
+                Slog.w(TAG, "Dropping message 0x" + Integer.toHexString(message)
+                        + " with invalid length " + length
+                        + " from association " + mAssociationId);
+                return;
+            }
+            final byte[] content = new byte[length];
+            payload.get(content);
+
+            try {
+                handleMessage(message, sequence, content);
+            } catch (IOException error) {
+                onError(error);
+            }
+        }
+
+        @Override
+        public void onError(Throwable error) {
+            stopSender();
+            Slog.e(TAG, "Secure transport error for association " + mAssociationId, error);
+        }
+    }
 }