Merge changes from topic "cdm-secure-channel-udc-dev" into udc-dev
* changes:
Introduce hidden API to disable secure channel for back-compatibility
Integrate secure channel into CDM
Implement secure channel
Add Ukey2 dependency
diff --git a/core/java/android/companion/CompanionDeviceManager.java b/core/java/android/companion/CompanionDeviceManager.java
index 5df2d5e..de4f619 100644
--- a/core/java/android/companion/CompanionDeviceManager.java
+++ b/core/java/android/companion/CompanionDeviceManager.java
@@ -1194,6 +1194,26 @@
}
}
+ /**
+ * Enable or disable secure transport for testing. Defaults to enabled.
+ *
+ * <p>Does nothing if the companion device service is unavailable.
+ *
+ * @param enabled true to enable, false to disable.
+ * @hide
+ */
+ public void enableSecureTransport(boolean enabled) {
+ if (!checkFeaturePresent()) {
+ return;
+ }
+
+ try {
+ mService.enableSecureTransport(enabled);
+ } catch (RemoteException e) {
+ throw e.rethrowFromSystemServer();
+ }
+ }
+
private boolean checkFeaturePresent() {
boolean featurePresent = mService != null;
if (!featurePresent && DEBUG) {
diff --git a/core/java/android/companion/ICompanionDeviceManager.aidl b/core/java/android/companion/ICompanionDeviceManager.aidl
index 010aa8f..cb4baca 100644
--- a/core/java/android/companion/ICompanionDeviceManager.aidl
+++ b/core/java/android/companion/ICompanionDeviceManager.aidl
@@ -88,4 +88,9 @@
void enableSystemDataSync(int associationId, int flags);
void disableSystemDataSync(int associationId, int flags);
+
+ /**
+ * Enable or disable secure transport for transports attached afterwards. Defaults to enabled.
+ */
+ void enableSecureTransport(boolean enabled);
}
diff --git a/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java b/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java
index d633843..2b4123a 100644
--- a/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java
+++ b/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java
@@ -60,6 +60,7 @@
mContext = getInstrumentation().getTargetContext();
mCdm = mContext.getSystemService(CompanionDeviceManager.class);
mAssociationId = createAssociation();
+ mCdm.enableSecureTransport(false);
}
@Override
@@ -67,6 +68,7 @@
super.tearDown();
mCdm.disassociate(mAssociationId);
+ mCdm.enableSecureTransport(true);
}
public void testPingHandRolled() {
diff --git a/services/Android.bp b/services/Android.bp
index f8097ec..6e6c553 100644
--- a/services/Android.bp
+++ b/services/Android.bp
@@ -195,6 +195,10 @@
"manifest_services.xml",
],
+ required: [
+ "libukey2_jni_shared",
+ ],
+
// Uncomment to enable output of certain warnings (deprecated, unchecked)
//javacflags: ["-Xlint"],
}
diff --git a/services/companion/Android.bp b/services/companion/Android.bp
index cdeb2dc..a248d9e5 100644
--- a/services/companion/Android.bp
+++ b/services/companion/Android.bp
@@ -24,4 +24,7 @@
"app-compat-annotations",
"services.core",
],
+ static_libs: [
+ "ukey2_jni",
+ ],
}
diff --git a/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java b/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java
index 0f2ba35..a35cae9 100644
--- a/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java
+++ b/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java
@@ -726,6 +726,15 @@
}
@Override
+ public void enableSecureTransport(boolean enabled) {
+ // On debuggable builds this makes newly attached transports fall back to RawTransport
+ // instead of SecureTransport, so only privileged callers may toggle it.
+ getContext().enforceCallingOrSelfPermission(
+ android.Manifest.permission.MANAGE_COMPANION_DEVICES, "enableSecureTransport");
+ mTransportManager.enableSecureTransport(enabled);
+ }
+
+ @Override
public void notifyDeviceAppeared(int associationId) {
if (DEBUG) Log.i(TAG, "notifyDevice_Appeared() id=" + associationId);
diff --git a/services/companion/java/com/android/server/companion/securechannel/AttestationVerifier.java b/services/companion/java/com/android/server/companion/securechannel/AttestationVerifier.java
new file mode 100644
index 0000000..adaee75
--- /dev/null
+++ b/services/companion/java/com/android/server/companion/securechannel/AttestationVerifier.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.companion.securechannel;
+
+import static android.security.attestationverification.AttestationVerificationManager.PARAM_CHALLENGE;
+import static android.security.attestationverification.AttestationVerificationManager.PROFILE_PEER_DEVICE;
+import static android.security.attestationverification.AttestationVerificationManager.TYPE_CHALLENGE;
+
+import android.annotation.NonNull;
+import android.content.Context;
+import android.os.Bundle;
+import android.security.attestationverification.AttestationProfile;
+import android.security.attestationverification.AttestationVerificationManager;
+import android.security.attestationverification.VerificationToken;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
+
+/**
+ * Helper class to perform attestation verification synchronously.
+ */
+class AttestationVerifier {
+ private static final long ATTESTATION_VERIFICATION_TIMEOUT_SECONDS = 10; // 10 seconds
+ private static final String PARAM_OWNED_BY_SYSTEM = "android.key_owned_by_system";
+
+ private final Context mContext;
+
+ AttestationVerifier(Context context) {
+ this.mContext = context;
+ }
+
+ /**
+ * Synchronously verify remote attestation as a suitable peer device on current thread.
+ *
+ * The peer device must be owned by the Android system and be protected with appropriate
+ * public key that this device can verify as attestation challenge.
+ *
+ * @param remoteAttestation the full certificate chain containing attestation extension.
+ * @param attestationChallenge attestation challenge for authentication.
+ * @return the result code delivered to the verification callback; compare it against
+ * {@link AttestationVerificationManager#RESULT_SUCCESS} to check for success.
+ * @throws SecureChannelException if verification does not finish within
+ * {@link #ATTESTATION_VERIFICATION_TIMEOUT_SECONDS} seconds, or if the calling thread is
+ * interrupted while waiting (the interrupt status is restored).
+ */
+ public int verifyAttestation(
+ @NonNull byte[] remoteAttestation,
+ @NonNull byte[] attestationChallenge
+ ) throws SecureChannelException {
+ Bundle requirements = new Bundle();
+ requirements.putByteArray(PARAM_CHALLENGE, attestationChallenge);
+ requirements.putBoolean(PARAM_OWNED_BY_SYSTEM, true); // Custom parameter for CDM
+
+ // Synchronously execute attestation verification.
+ AtomicInteger verificationResult = new AtomicInteger(0);
+ CountDownLatch verificationFinished = new CountDownLatch(1);
+ BiConsumer<Integer, VerificationToken> onVerificationResult = (result, token) -> {
+ verificationResult.set(result);
+ verificationFinished.countDown();
+ };
+
+ mContext.getSystemService(AttestationVerificationManager.class).verifyAttestation(
+ new AttestationProfile(PROFILE_PEER_DEVICE),
+ /* localBindingType */ TYPE_CHALLENGE,
+ requirements,
+ remoteAttestation,
+ Runnable::run,
+ onVerificationResult
+ );
+
+ boolean finished;
+ try {
+ finished = verificationFinished.await(
+ ATTESTATION_VERIFICATION_TIMEOUT_SECONDS,
+ TimeUnit.SECONDS
+ );
+ } catch (InterruptedException e) {
+ // Preserve the interrupt status so callers further up the stack can observe it.
+ Thread.currentThread().interrupt();
+ throw new SecureChannelException("Attestation verification was interrupted", e);
+ }
+
+ if (!finished) {
+ throw new SecureChannelException("Attestation verification timed out.");
+ }
+
+ return verificationResult.get();
+ }
+}
diff --git a/services/companion/java/com/android/server/companion/securechannel/KeyStoreUtils.java b/services/companion/java/com/android/server/companion/securechannel/KeyStoreUtils.java
new file mode 100644
index 0000000..18ebec4
--- /dev/null
+++ b/services/companion/java/com/android/server/companion/securechannel/KeyStoreUtils.java
@@ -0,0 +1,143 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.companion.securechannel;
+
+import static android.security.keystore.KeyProperties.DIGEST_SHA256;
+import static android.security.keystore.KeyProperties.KEY_ALGORITHM_EC;
+import static android.security.keystore.KeyProperties.PURPOSE_SIGN;
+import static android.security.keystore.KeyProperties.PURPOSE_VERIFY;
+
+import android.security.keystore.KeyGenParameterSpec;
+import android.security.keystore2.AndroidKeyStoreSpi;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.security.GeneralSecurityException;
+import java.security.KeyPairGenerator;
+import java.security.KeyStore;
+import java.security.KeyStoreException;
+import java.security.cert.Certificate;
+
+/**
+ * Utility class to help generate, store, and access key-pair for the secure channel. Uses
+ * Android Keystore.
+ */
+final class KeyStoreUtils {
+ private static final String ANDROID_KEYSTORE = AndroidKeyStoreSpi.NAME;
+
+ private KeyStoreUtils() {}
+
+ /**
+ * Load Android keystore to be used by the secure channel.
+ *
+ * @return loaded keystore instance
+ */
+ static KeyStore loadKeyStore() throws GeneralSecurityException {
+ KeyStore androidKeyStore = KeyStore.getInstance(ANDROID_KEYSTORE);
+
+ try {
+ androidKeyStore.load(null);
+ } catch (IOException e) {
+ // Should not happen
+ throw new KeyStoreException("Failed to load Android Keystore.", e);
+ }
+
+ return androidKeyStore;
+ }
+
+ /**
+ * Fetch the certificate chain encoded as byte array in the form of concatenated
+ * X509 certificates.
+ *
+ * @param alias unique alias for the key-pair entry
+ * @return a single byte-array containing the entire certificate chain
+ * @throws KeyStoreException if no certificate chain is stored under {@code alias}
+ */
+ static byte[] getEncodedCertificateChain(String alias) throws GeneralSecurityException {
+ KeyStore ks = loadKeyStore();
+
+ Certificate[] certificateChain = ks.getCertificateChain(alias);
+ if (certificateChain == null) {
+ // KeyStore#getCertificateChain returns null (rather than throwing) for unknown aliases.
+ throw new KeyStoreException("No certificate chain found for alias: " + alias);
+ }
+
+ ByteArrayOutputStream buffer = new ByteArrayOutputStream();
+ for (Certificate certificate : certificateChain) {
+ buffer.writeBytes(certificate.getEncoded());
+ }
+ return buffer.toByteArray();
+ }
+
+ /**
+ * Generate a new attestation key-pair.
+ *
+ * @param alias unique alias for the key-pair entry
+ * @param attestationChallenge challenge value to check against for authentication
+ */
+ static void generateAttestationKeyPair(String alias, byte[] attestationChallenge)
+ throws GeneralSecurityException {
+ KeyGenParameterSpec parameterSpec =
+ new KeyGenParameterSpec.Builder(alias, PURPOSE_SIGN | PURPOSE_VERIFY)
+ .setAttestationChallenge(attestationChallenge)
+ .setDigests(DIGEST_SHA256)
+ .build();
+
+ KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(
+ /* algorithm */ KEY_ALGORITHM_EC,
+ /* provider */ ANDROID_KEYSTORE);
+ keyPairGenerator.initialize(parameterSpec);
+ keyPairGenerator.generateKeyPair();
+ }
+
+ /**
+ * Check if alias exists.
+ *
+ * @param alias unique alias for the key-pair entry
+ * @return true if given alias already exists in the keystore; false if it does not or the
+ * keystore could not be queried
+ */
+ static boolean aliasExists(String alias) {
+ try {
+ KeyStore ks = loadKeyStore();
+ return ks.containsAlias(alias);
+ } catch (GeneralSecurityException e) {
+ return false;
+ }
+ }
+
+ /**
+ * Delete the key-pair entry stored under the given alias, if present. This is best-effort:
+ * any failure to load or modify the keystore is ignored.
+ *
+ * @param alias unique alias for the key-pair entry; {@code null} is a no-op
+ */
+ static void cleanUp(String alias) {
+ if (alias == null) {
+ return;
+ }
+ try {
+ KeyStore ks = loadKeyStore();
+
+ if (ks.containsAlias(alias)) {
+ ks.deleteEntry(alias);
+ }
+ } catch (Exception ignored) {
+ // Best-effort cleanup; failures are deliberately ignored.
+ }
+ }
+}
diff --git a/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java b/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java
new file mode 100644
index 0000000..13dba84
--- /dev/null
+++ b/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java
@@ -0,0 +1,553 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.companion.securechannel;
+
+import static android.security.attestationverification.AttestationVerificationManager.RESULT_SUCCESS;
+
+import android.annotation.NonNull;
+import android.content.Context;
+import android.os.Build;
+import android.util.Slog;
+
+import com.google.security.cryptauth.lib.securegcm.BadHandleException;
+import com.google.security.cryptauth.lib.securegcm.CryptoException;
+import com.google.security.cryptauth.lib.securegcm.D2DConnectionContextV1;
+import com.google.security.cryptauth.lib.securegcm.D2DHandshakeContext;
+import com.google.security.cryptauth.lib.securegcm.D2DHandshakeContext.Role;
+import com.google.security.cryptauth.lib.securegcm.DefaultUkey2Logger;
+import com.google.security.cryptauth.lib.securegcm.HandshakeException;
+
+import libcore.io.IoUtils;
+import libcore.io.Streams;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.security.GeneralSecurityException;
+import java.security.MessageDigest;
+import java.util.UUID;
+
+/**
+ * Data stream channel that establishes secure connection between two peer devices.
+ */
+public class SecureChannel {
+ private static final String TAG = "CDM_SecureChannel";
+ private static final boolean DEBUG = Build.IS_DEBUGGABLE;
+
+ private static final int VERSION = 1;
+ private static final int HEADER_LENGTH = 6;
+
+ private static final String HANDSHAKE_PROTOCOL = "AES_256_CBC-HMAC_SHA256";
+
+ private final InputStream mInput;
+ private final OutputStream mOutput;
+ private final Callback mCallback;
+ private final byte[] mPreSharedKey;
+ private final AttestationVerifier mVerifier;
+
+ private volatile boolean mStopped;
+ private boolean mInProgress;
+
+ private Role mRole;
+ private D2DHandshakeContext mHandshakeContext;
+ private D2DConnectionContextV1 mConnectionContext;
+
+ private String mAlias;
+ private int mVerificationResult;
+
+
+ /**
+ * Create a new secure channel object. This secure channel allows secure messages to be
+ * exchanged with unattested devices. The pre-shared key must have been distributed to both
+ * participants of the channel in a secure way previously.
+ *
+ * @param in input stream from which data is received
+ * @param out output stream to which data is sent
+ * @param callback subscription to received messages from the channel
+ * @param preSharedKey pre-shared key to authenticate unattested participant
+ */
+ public SecureChannel(
+ @NonNull final InputStream in,
+ @NonNull final OutputStream out,
+ @NonNull Callback callback,
+ @NonNull byte[] preSharedKey
+ ) {
+ this(in, out, callback, preSharedKey, null);
+ }
+
+ /**
+ * Create a new secure channel object. This secure channel allows secure messages to be
+ * exchanged with Android devices that were authenticated and verified with an attestation key.
+ *
+ * @param in input stream from which data is received
+ * @param out output stream to which data is sent
+ * @param callback subscription to received messages from the channel
+ * @param context context for fetching the Attestation Verifier Framework system service
+ */
+ public SecureChannel(
+ @NonNull final InputStream in,
+ @NonNull final OutputStream out,
+ @NonNull Callback callback,
+ @NonNull Context context
+ ) {
+ this(in, out, callback, null, new AttestationVerifier(context));
+ }
+
+ private SecureChannel(
+ final InputStream in,
+ final OutputStream out,
+ Callback callback,
+ byte[] preSharedKey,
+ AttestationVerifier verifier
+ ) {
+ this.mInput = in;
+ this.mOutput = out;
+ this.mCallback = callback;
+ this.mPreSharedKey = preSharedKey;
+ this.mVerifier = verifier;
+ }
+
+ /**
+ * Start listening for incoming messages on a newly created background thread.
+ */
+ public void start() {
+ new Thread(() -> {
+ try {
+ // 1. Wait for the next handshake message and process it.
+ exchangeHandshake();
+
+ // 2. Authenticate remote actor via attestation or pre-shared key.
+ exchangeAuthentication();
+
+ // 3. Notify secure channel is ready.
+ mInProgress = false;
+ mCallback.onSecureConnection();
+
+ // Listen for secure messages.
+ while (!mStopped) {
+ receiveSecureMessage();
+ }
+ } catch (Exception e) {
+ if (mStopped) {
+ return;
+ }
+ // TODO: Handle different types errors.
+
+ Slog.e(TAG, "Secure channel encountered an error.", e);
+ stop();
+ mCallback.onError(e);
+ }
+ }).start();
+ }
+
+ /**
+ * Stop listening to incoming messages and close the channel.
+ */
+ public void stop() {
+ if (DEBUG) {
+ Slog.d(TAG, "Stopping secure channel.");
+ }
+ mStopped = true;
+ mInProgress = false;
+
+ IoUtils.closeQuietly(mInput);
+ IoUtils.closeQuietly(mOutput);
+ // mAlias is only generated once attestation starts; nothing to clean up before that.
+ if (mAlias != null) {
+ KeyStoreUtils.cleanUp(mAlias);
+ }
+ }
+
+ /**
+ * Start exchanging handshakes to create a secure layer asynchronously. When the handshake is
+ * completed successfully, then the {@link Callback#onSecureConnection()} will trigger. Any
+ * error that occurs during the handshake will be passed by {@link Callback#onError(Throwable)}.
+ *
+ * This method must only be called from one of the two participants.
+ */
+ public void establishSecureConnection() throws IOException, SecureChannelException {
+ if (isSecured()) {
+ Slog.d(TAG, "Channel is already secure.");
+ return;
+ }
+ if (mInProgress) {
+ Slog.w(TAG, "Channel has already started establishing secure connection.");
+ return;
+ }
+
+ try {
+ initiateHandshake();
+ mInProgress = true;
+ } catch (BadHandleException e) {
+ throw new SecureChannelException("Failed to initiate handshake protocol.", e);
+ }
+ }
+
+ /**
+ * Send an encrypted, authenticated message via this channel.
+ *
+ * @param data data to be sent to the other side.
+ * @throws IOException if the output stream fails to write given data.
+ */
+ public void sendSecureMessage(byte[] data) throws IOException {
+ if (!isSecured()) {
+ Slog.d(TAG, "Cannot send a message without a secure connection.");
+ throw new IllegalStateException("Channel is not secured yet.");
+ }
+
+ // Encrypt constructed message
+ try {
+ sendMessage(MessageType.SECURE_MESSAGE, data);
+ } catch (BadHandleException e) {
+ throw new SecureChannelException("Failed to encrypt data.", e);
+ }
+ }
+
+ private void receiveSecureMessage() throws IOException, CryptoException {
+ // Check if channel is secured. Trigger error callback. Let user handle it.
+ if (!isSecured()) {
+ Slog.d(TAG, "Received a message without a secure connection. "
+ + "Message will be ignored.");
+ mCallback.onError(new IllegalStateException("Connection is not secure."));
+ return;
+ }
+
+ try {
+ byte[] receivedMessage = readMessage(MessageType.SECURE_MESSAGE);
+ mCallback.onSecureMessageReceived(receivedMessage);
+ } catch (SecureChannelException e) {
+ Slog.w(TAG, "Ignoring received message.", e);
+ }
+ }
+
+ private byte[] readMessage(MessageType expected)
+ throws IOException, SecureChannelException, CryptoException {
+ if (DEBUG) {
+ if (isSecured()) {
+ Slog.d(TAG, "Waiting to receive next secure message.");
+ } else {
+ Slog.d(TAG, "Waiting to receive next message.");
+ }
+ }
+
+ // TODO: Handle message timeout
+
+ // Header is _not_ encrypted, but will be covered by MAC
+ final byte[] headerBytes = new byte[HEADER_LENGTH];
+ Streams.readFully(mInput, headerBytes);
+ final ByteBuffer header = ByteBuffer.wrap(headerBytes);
+ final int version = header.getInt();
+ final short type = header.getShort();
+
+ if (version != VERSION) {
+ Streams.skipByReading(mInput, Long.MAX_VALUE);
+ throw new SecureChannelException("Secure channel version mismatch. "
+ + "Currently on version " + VERSION + ". Skipping rest of data.");
+ }
+
+ if (type != expected.mValue) {
+ Streams.skipByReading(mInput, Long.MAX_VALUE);
+ throw new SecureChannelException("Unexpected message type. Expected " + expected.name()
+ + "; Found " + MessageType.from(type).name() + ". Skipping rest of data.");
+ }
+
+ // Length of attached data is prepended as plaintext
+ final byte[] lengthBytes = new byte[4];
+ Streams.readFully(mInput, lengthBytes);
+ final int length = ByteBuffer.wrap(lengthBytes).getInt();
+ if (length < 0) {
+ // Framing is corrupt and cannot be recovered, so fail the channel instead of guessing.
+ throw new IOException("Invalid payload length: " + length);
+ }
+
+ // Read data based on the length
+ final byte[] data;
+ try {
+ data = new byte[length];
+ } catch (OutOfMemoryError error) {
+ // Consume the unread payload so the stream stays aligned on message boundaries;
+ // receiveSecureMessage() logs and ignores this exception, then keeps reading.
+ Streams.skipByReading(mInput, length);
+ throw new SecureChannelException("Payload is too large.", error);
+ }
+
+ Streams.readFully(mInput, data);
+ if (!MessageType.shouldEncrypt(expected)) {
+ return data;
+ }
+
+ return mConnectionContext.decodeMessageFromPeer(data, headerBytes);
+ }
+
+ private void sendMessage(MessageType messageType, byte[] payload)
+ throws IOException, BadHandleException {
+ synchronized (mOutput) {
+ byte[] header = ByteBuffer.allocate(HEADER_LENGTH)
+ .putInt(VERSION)
+ .putShort(messageType.mValue)
+ .array();
+ byte[] data = MessageType.shouldEncrypt(messageType)
+ ? mConnectionContext.encodeMessageToPeer(payload, header)
+ : payload;
+ mOutput.write(header);
+ mOutput.write(ByteBuffer.allocate(4)
+ .putInt(data.length)
+ .array());
+ mOutput.write(data);
+ mOutput.flush();
+ }
+ }
+
+ private void initiateHandshake() throws IOException, BadHandleException {
+ if (mConnectionContext != null) {
+ Slog.d(TAG, "Ukey2 handshake is already completed.");
+ return;
+ }
+
+ mRole = Role.Initiator;
+ mHandshakeContext = D2DHandshakeContext.forInitiator(DefaultUkey2Logger.INSTANCE);
+
+ // Send Client Init
+ if (DEBUG) {
+ Slog.d(TAG, "Sending Ukey2 Client Init message");
+ }
+ sendMessage(MessageType.HANDSHAKE_INIT, mHandshakeContext.getNextHandshakeMessage());
+ }
+
+ private void exchangeHandshake()
+ throws IOException, HandshakeException, BadHandleException, CryptoException {
+ if (mConnectionContext != null) {
+ Slog.d(TAG, "Ukey2 handshake is already completed.");
+ return;
+ }
+
+ // Waiting for message
+ byte[] handshakeMessage = readMessage(MessageType.HANDSHAKE_INIT);
+
+ if (mHandshakeContext == null) { // Server-side logic
+ mRole = Role.Responder;
+ mHandshakeContext = D2DHandshakeContext.forResponder(DefaultUkey2Logger.INSTANCE);
+
+ // Receive Client Init
+ if (DEBUG) {
+ Slog.d(TAG, "Receiving Ukey2 Client Init message");
+ }
+ mHandshakeContext.parseHandshakeMessage(handshakeMessage);
+
+ // Send Server Init
+ if (DEBUG) {
+ Slog.d(TAG, "Sending Ukey2 Server Init message");
+ }
+ sendMessage(MessageType.HANDSHAKE_INIT, mHandshakeContext.getNextHandshakeMessage());
+
+ // Receive Client Finish
+ if (DEBUG) {
+ Slog.d(TAG, "Receiving Ukey2 Client Finish message");
+ }
+ mHandshakeContext.parseHandshakeMessage(readMessage(MessageType.HANDSHAKE_FINISH));
+ } else { // Client-side logic
+
+ // Receive Server Init
+ if (DEBUG) {
+ Slog.d(TAG, "Receiving Ukey2 Server Init message");
+ }
+ mHandshakeContext.parseHandshakeMessage(handshakeMessage);
+
+ // Send Client Finish
+ if (DEBUG) {
+ Slog.d(TAG, "Sending Ukey2 Client Finish message");
+ }
+ sendMessage(MessageType.HANDSHAKE_FINISH, mHandshakeContext.getNextHandshakeMessage());
+ }
+
+ // Convert secrets to connection context
+ if (mHandshakeContext.isHandshakeComplete()) {
+ if (DEBUG) {
+ Slog.d(TAG, "Ukey2 Handshake completed successfully");
+ }
+ mConnectionContext = mHandshakeContext.toConnectionContext();
+ } else {
+ Slog.e(TAG, "Failed to complete Ukey2 Handshake");
+ throw new IllegalStateException("Ukey2 Handshake did not complete as expected.");
+ }
+ }
+
+ private void exchangeAuthentication()
+ throws IOException, GeneralSecurityException, BadHandleException, CryptoException {
+ if (mVerifier == null) {
+ exchangePreSharedKey();
+ } else {
+ exchangeAttestation();
+ }
+ }
+
+ private void exchangePreSharedKey()
+ throws IOException, GeneralSecurityException, BadHandleException, CryptoException {
+
+ // Exchange hashed pre-shared keys
+ if (DEBUG) {
+ Slog.d(TAG, "Exchanging pre-shared keys.");
+ }
+ sendMessage(MessageType.PRE_SHARED_KEY, constructToken(mRole, mPreSharedKey));
+ byte[] receivedAuthToken = readMessage(MessageType.PRE_SHARED_KEY);
+ byte[] expectedAuthToken = constructToken(mRole == Role.Initiator
+ ? Role.Responder
+ : Role.Initiator,
+ mPreSharedKey);
+ // Compare in constant time so a mismatch does not leak how many leading bytes matched.
+ boolean authenticated = MessageDigest.isEqual(receivedAuthToken, expectedAuthToken);
+
+ if (!authenticated) {
+ throw new SecureChannelException("Failed to verify the hash of pre-shared key.");
+ }
+
+ if (DEBUG) {
+ Slog.d(TAG, "The pre-shared key was successfully authenticated.");
+ }
+ }
+
+ private void exchangeAttestation()
+ throws IOException, GeneralSecurityException, BadHandleException, CryptoException {
+ if (mVerificationResult == RESULT_SUCCESS) {
+ Slog.d(TAG, "Remote attestation was already verified.");
+ return;
+ }
+
+ // Send local attestation
+ if (DEBUG) {
+ Slog.d(TAG, "Exchanging device attestation.");
+ }
+ if (mAlias == null) {
+ mAlias = generateAlias();
+ }
+ byte[] localChallenge = constructToken(mRole, mConnectionContext.getSessionUnique());
+ KeyStoreUtils.generateAttestationKeyPair(mAlias, localChallenge);
+ byte[] localAttestation = KeyStoreUtils.getEncodedCertificateChain(mAlias);
+ sendMessage(MessageType.ATTESTATION, localAttestation);
+ byte[] remoteAttestation = readMessage(MessageType.ATTESTATION);
+
+ // Verifying remote attestation with public key local binding param
+ byte[] expectedChallenge = constructToken(mRole == Role.Initiator
+ ? Role.Responder
+ : Role.Initiator,
+ mConnectionContext.getSessionUnique());
+ mVerificationResult = mVerifier.verifyAttestation(remoteAttestation, expectedChallenge);
+
+ // Exchange attestation verification result and finish
+ byte[] verificationResult = ByteBuffer.allocate(4)
+ .putInt(mVerificationResult)
+ .array();
+ sendMessage(MessageType.AVF_RESULT, verificationResult);
+ byte[] remoteVerificationResult = readMessage(MessageType.AVF_RESULT);
+
+ if (ByteBuffer.wrap(remoteVerificationResult).getInt() != RESULT_SUCCESS) {
+ throw new SecureChannelException("Remote device failed to verify local attestation.");
+ }
+
+ if (mVerificationResult != RESULT_SUCCESS) {
+ throw new SecureChannelException("Failed to verify remote attestation.");
+ }
+
+ if (DEBUG) {
+ Slog.d(TAG, "Remote attestation was successfully verified.");
+ }
+ }
+
+ private boolean isSecured() {
+ if (mConnectionContext == null) {
+ return false;
+ }
+ return mVerifier == null || mVerificationResult == RESULT_SUCCESS;
+ }
+
+ private byte[] constructToken(D2DHandshakeContext.Role role, byte[] authValue)
+ throws GeneralSecurityException {
+ MessageDigest hash = MessageDigest.getInstance("SHA-256");
+ byte[] roleUtf8 = role.name().getBytes(StandardCharsets.UTF_8);
+ int tokenLength = roleUtf8.length + authValue.length;
+ return hash.digest(ByteBuffer.allocate(tokenLength)
+ .put(roleUtf8)
+ .put(authValue)
+ .array());
+ }
+
+ private String generateAlias() {
+ String alias;
+ do {
+ alias = "secure-channel-" + UUID.randomUUID();
+ } while (KeyStoreUtils.aliasExists(alias));
+ return alias;
+ }
+
+ private enum MessageType {
+ HANDSHAKE_INIT(0x4849), // HI
+ HANDSHAKE_FINISH(0x4846), // HF
+ PRE_SHARED_KEY(0x504b), // PK
+ ATTESTATION(0x4154), // AT
+ AVF_RESULT(0x5652), // VR
+ SECURE_MESSAGE(0x534d), // SM
+ UNKNOWN(0); // X
+
+ private final short mValue;
+
+ MessageType(int value) {
+ this.mValue = (short) value;
+ }
+
+ static MessageType from(short value) {
+ for (MessageType messageType : values()) {
+ if (value == messageType.mValue) {
+ return messageType;
+ }
+ }
+ return UNKNOWN;
+ }
+
+ // Encrypt every message besides Ukey2 handshake messages
+ private static boolean shouldEncrypt(MessageType type) {
+ return type != HANDSHAKE_INIT && type != HANDSHAKE_FINISH;
+ }
+ }
+
+ /**
+ * Callback that passes securely received message to the subscribed user.
+ */
+ public interface Callback {
+ /**
+ * Triggered after {@link SecureChannel#establishSecureConnection()} finishes exchanging
+ * every required handshakes to fully establish a secure connection.
+ */
+ void onSecureConnection();
+
+ /**
+ * Callback that passes securely received and decrypted data to the subscribed user.
+ *
+ * @param data securely received plaintext data.
+ */
+ void onSecureMessageReceived(byte[] data);
+
+ /**
+ * Callback that passes error that occurred during handshakes or while listening to
+ * messages in the secure channel.
+ *
+ * @param error the exception that caused the failure
+ */
+ void onError(Throwable error);
+ }
+}
diff --git a/services/companion/java/com/android/server/companion/securechannel/SecureChannelException.java b/services/companion/java/com/android/server/companion/securechannel/SecureChannelException.java
new file mode 100644
index 0000000..68db97e
--- /dev/null
+++ b/services/companion/java/com/android/server/companion/securechannel/SecureChannelException.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.companion.securechannel;
+
+/**
+ * Catch-all exception for any error in the secure channel.
+ */
+public class SecureChannelException extends RuntimeException {
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Create an exception with the given detail message.
+ *
+ * @param message description of the failure
+ */
+ public SecureChannelException(String message) {
+ super(message);
+ }
+
+ /**
+ * Create an exception with the given detail message and underlying cause.
+ *
+ * @param message description of the failure
+ * @param t the underlying cause of the failure
+ */
+ public SecureChannelException(String message, Throwable t) {
+ super(message, t);
+ }
+}
diff --git a/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java b/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java
index 6db99a0..494c5a6 100644
--- a/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java
+++ b/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java
@@ -34,6 +34,7 @@
import com.android.internal.annotations.GuardedBy;
import com.android.server.LocalServices;
+import com.android.server.companion.securechannel.SecureChannel;
import libcore.io.IoUtils;
import libcore.io.Streams;
@@ -43,6 +44,8 @@
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
@@ -54,8 +57,6 @@
private static final boolean DEBUG = true;
private static final int HEADER_LENGTH = 12;
- // TODO: refactor message processing to use streams to remove this limit
- private static final int MAX_PAYLOAD_LENGTH = 1_000_000;
private static final int MESSAGE_REQUEST_PING = 0x63807378; // ?PIN
private static final int MESSAGE_REQUEST_PERMISSION_RESTORE = 0x63826983; // ?RES
@@ -63,6 +64,8 @@
private static final int MESSAGE_RESPONSE_SUCCESS = 0x33838567; // !SUC
private static final int MESSAGE_RESPONSE_FAILURE = 0x33706573; // !FAI
+    // Test-only override toggled through enableSecureTransport(). volatile because it is
+    // written and read by separate incoming calls that may run on different threads.
+    private volatile boolean mSecureTransportEnabled = true;
+
private static boolean isRequest(int message) {
return (message & 0xFF000000) == 0x63000000;
}
@@ -122,7 +125,13 @@
detachSystemDataTransport(packageName, userId, associationId);
}
- final Transport transport = new Transport(associationId, fd);
+ final Transport transport;
+ if (isSecureTransportEnabled(associationId)) {
+ transport = new SecureTransport(associationId, fd);
+ } else {
+ transport = new RawTransport(associationId, fd);
+ }
+
transport.start();
mTransports.put(associationId, transport);
}
@@ -142,61 +151,65 @@
public Future<?> requestPermissionRestore(int associationId, byte[] data) {
synchronized (mTransports) {
final Transport transport = mTransports.get(associationId);
- if (transport != null) {
- return transport.requestForResponse(MESSAGE_REQUEST_PERMISSION_RESTORE, data);
- } else {
+ // No transport attached for this association: report the failure through the
+ // returned future rather than throwing to the caller.
+ if (transport == null) {
return CompletableFuture.failedFuture(new IOException("Missing transport"));
}
+
+ // Dispatches through whichever Transport (raw or secure) was created at attach time.
+ return transport.requestForResponse(MESSAGE_REQUEST_PERMISSION_RESTORE, data);
}
}
- private class Transport {
- private final int mAssociationId;
+    /**
+     * Enables or disables the secure transport for transports attached after this call.
+     *
+     * <p>Only honored on debuggable builds; non-debuggable builds always use the secure
+     * transport regardless of this setting. Intended for testing.
+     *
+     * @param enabled true to use the secure transport, false to fall back to the raw one
+     * @hide
+     */
+    public void enableSecureTransport(boolean enabled) {
+        mSecureTransportEnabled = enabled;
+    }
- private final InputStream mRemoteIn;
- private final OutputStream mRemoteOut;
- private final AtomicInteger mNextSequence = new AtomicInteger();
+    /**
+     * Decides whether a transport being attached for {@code associationId} should be secure.
+     * The secure transport can only be switched off on debuggable builds.
+     */
+    private boolean isSecureTransportEnabled(int associationId) {
+        // TODO: version comparison logic
+        if (!Build.IS_DEBUGGABLE) {
+            return true;
+        }
+        return mSecureTransportEnabled;
+    }
+
+ // TODO: Make Transport inner classes into standalone classes.
+ private abstract class Transport {
+ protected final int mAssociationId;
+ protected final InputStream mRemoteIn;
+ protected final OutputStream mRemoteOut;
@GuardedBy("mPendingRequests")
- private final SparseArray<CompletableFuture<byte[]>> mPendingRequests = new SparseArray<>();
+ protected final SparseArray<CompletableFuture<byte[]>> mPendingRequests =
+ new SparseArray<>();
+ protected final AtomicInteger mNextSequence = new AtomicInteger();
- private volatile boolean mStopped;
-
- public Transport(int associationId, ParcelFileDescriptor fd) {
- mAssociationId = associationId;
- mRemoteIn = new ParcelFileDescriptor.AutoCloseInputStream(fd);
- mRemoteOut = new ParcelFileDescriptor.AutoCloseOutputStream(fd);
+ Transport(int associationId, ParcelFileDescriptor fd) {
+ this(associationId,
+ new ParcelFileDescriptor.AutoCloseInputStream(fd),
+ new ParcelFileDescriptor.AutoCloseOutputStream(fd));
}
- public void start() {
- new Thread(() -> {
- try {
- while (!mStopped) {
- receiveMessage();
- }
- } catch (IOException e) {
- if (!mStopped) {
- Slog.w(TAG, "Trouble during transport", e);
- stop();
- }
- }
- }).start();
+ Transport(int associationId, InputStream in, OutputStream out) {
+ this.mAssociationId = associationId;
+ this.mRemoteIn = in;
+ this.mRemoteOut = out;
}
- public void stop() {
- mStopped = true;
+ public abstract void start();
+ public abstract void stop();
- IoUtils.closeQuietly(mRemoteIn);
- IoUtils.closeQuietly(mRemoteOut);
- }
+ protected abstract void sendMessage(int message, int sequence, @NonNull byte[] data)
+ throws IOException;
public Future<byte[]> requestForResponse(int message, byte[] data) {
+ if (DEBUG) Slog.d(TAG, "Requesting for response");
final int sequence = mNextSequence.incrementAndGet();
final CompletableFuture<byte[]> pending = new CompletableFuture<>();
synchronized (mPendingRequests) {
mPendingRequests.put(sequence, pending);
}
+
try {
sendMessage(message, sequence, data);
} catch (IOException e) {
@@ -205,58 +218,24 @@
}
pending.completeExceptionally(e);
}
+
return pending;
}
- private void sendMessage(int message, int sequence, @NonNull byte[] data)
+ protected final void handleMessage(int message, int sequence, @NonNull byte[] data)
throws IOException {
if (DEBUG) {
- Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message)
- + " sequence " + sequence + " length " + data.length
- + " to association " + mAssociationId);
- }
-
- synchronized (mRemoteOut) {
- final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH)
- .putInt(message)
- .putInt(sequence)
- .putInt(data.length);
- mRemoteOut.write(header.array());
- mRemoteOut.write(data);
- mRemoteOut.flush();
- }
- }
-
- private void receiveMessage() throws IOException {
- if (DEBUG) {
- Slog.d(TAG, "Waiting for next message...");
- }
-
- final byte[] headerBytes = new byte[HEADER_LENGTH];
- Streams.readFully(mRemoteIn, headerBytes);
- final ByteBuffer header = ByteBuffer.wrap(headerBytes);
- final int message = header.getInt();
- final int sequence = header.getInt();
- final int length = header.getInt();
-
- if (DEBUG) {
Slog.d(TAG, "Received message 0x" + Integer.toHexString(message)
- + " sequence " + sequence + " length " + length
+ + " sequence " + sequence + " length " + data.length
+ " from association " + mAssociationId);
}
- if (length > MAX_PAYLOAD_LENGTH) {
- Slog.w(TAG, "Ignoring message 0x" + Integer.toHexString(message)
- + " sequence " + sequence + " length " + length
- + " from association " + mAssociationId + " beyond maximum length");
- Streams.skipByReading(mRemoteIn, length);
- return;
- }
-
- final byte[] data = new byte[length];
- Streams.readFully(mRemoteIn, data);
if (isRequest(message)) {
- processRequest(message, sequence, data);
+ try {
+ processRequest(message, sequence, data);
+ } catch (IOException e) {
+ Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e);
+ }
} else if (isResponse(message)) {
processResponse(message, sequence, data);
} else {
@@ -320,4 +299,169 @@
}
}
}
+
+    /**
+     * Transport that exchanges length-prefixed frames directly over the file descriptor,
+     * without the secure channel.
+     */
+    private class RawTransport extends Transport {
+        /**
+         * Upper bound on a single incoming payload. Larger payloads are skipped rather than
+         * buffered, because the length comes from the remote device.
+         * TODO: refactor message processing to use streams to remove this limit
+         */
+        private static final int MAX_PAYLOAD_LENGTH = 1_000_000;
+
+        private volatile boolean mStopped;
+
+        RawTransport(int associationId, ParcelFileDescriptor fd) {
+            super(associationId, fd);
+        }
+
+        /** Starts a reader thread that dispatches incoming messages until stopped or failed. */
+        @Override
+        public void start() {
+            new Thread(() -> {
+                try {
+                    while (!mStopped) {
+                        receiveMessage();
+                    }
+                } catch (IOException e) {
+                    // An IOException after stop() is expected, since the streams were closed.
+                    if (!mStopped) {
+                        Slog.w(TAG, "Trouble during transport", e);
+                        stop();
+                    }
+                }
+            }).start();
+        }
+
+        /** Marks the transport stopped and quietly closes both streams. */
+        @Override
+        public void stop() {
+            mStopped = true;
+
+            IoUtils.closeQuietly(mRemoteIn);
+            IoUtils.closeQuietly(mRemoteOut);
+        }
+
+        /**
+         * Writes one frame: a header of message, sequence and payload length (big-endian ints)
+         * followed by the payload.
+         */
+        @Override
+        protected void sendMessage(int message, int sequence, @NonNull byte[] data)
+                throws IOException {
+            if (DEBUG) {
+                Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message)
+                        + " sequence " + sequence + " length " + data.length
+                        + " to association " + mAssociationId);
+            }
+
+            // Serialize writers so header and payload of concurrent messages don't interleave.
+            synchronized (mRemoteOut) {
+                final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH)
+                        .putInt(message)
+                        .putInt(sequence)
+                        .putInt(data.length);
+                mRemoteOut.write(header.array());
+                mRemoteOut.write(data);
+                mRemoteOut.flush();
+            }
+        }
+
+        /**
+         * Reads one frame from the remote device and hands it to handleMessage().
+         *
+         * @throws IOException on stream failure or a frame with a negative length
+         */
+        private void receiveMessage() throws IOException {
+            final byte[] headerBytes = new byte[HEADER_LENGTH];
+            Streams.readFully(mRemoteIn, headerBytes);
+            final ByteBuffer header = ByteBuffer.wrap(headerBytes);
+            final int message = header.getInt();
+            final int sequence = header.getInt();
+            final int length = header.getInt();
+
+            // The length is supplied by the remote device, so never allocate it unchecked.
+            // A negative value would throw NegativeArraySizeException, which the reader loop
+            // in start() does not catch. The stream can't be resynchronized, so fail with an
+            // IOException and let start() tear the transport down.
+            if (length < 0) {
+                throw new IOException("Invalid length " + length + " for message 0x"
+                        + Integer.toHexString(message) + " from association "
+                        + mAssociationId);
+            }
+            if (length > MAX_PAYLOAD_LENGTH) {
+                Slog.w(TAG, "Ignoring message 0x" + Integer.toHexString(message)
+                        + " sequence " + sequence + " length " + length
+                        + " from association " + mAssociationId + " beyond maximum length");
+                Streams.skipByReading(mRemoteIn, length);
+                return;
+            }
+
+            final byte[] data = new byte[length];
+            Streams.readFully(mRemoteIn, data);
+
+            handleMessage(message, sequence, data);
+        }
+    }
+
+    /**
+     * Transport that routes framed messages through a {@link SecureChannel}. Outgoing messages
+     * are queued and drained by a sender thread started once the handshake completes.
+     */
+    private class SecureTransport extends Transport implements SecureChannel.Callback {
+        /** Maximum number of outgoing frames waiting for the sender thread. */
+        private static final int REQUEST_QUEUE_CAPACITY = 100;
+
+        /**
+         * How long the sender thread blocks waiting for a frame before re-checking
+         * mShouldProcessRequests. Bounds how long it lingers after stop() or onError().
+         */
+        private static final long REQUEST_POLL_TIMEOUT_MS = 500;
+
+        private final SecureChannel mSecureChannel;
+
+        /** Set when the handshake completes; cleared by stop() and onError(). */
+        private volatile boolean mShouldProcessRequests = false;
+
+        /** Outgoing frames (header + payload) awaiting the sender thread. */
+        private final BlockingQueue<byte[]> mRequestQueue =
+                new ArrayBlockingQueue<>(REQUEST_QUEUE_CAPACITY);
+
+        SecureTransport(int associationId, ParcelFileDescriptor fd) {
+            super(associationId, fd);
+            mSecureChannel = new SecureChannel(mRemoteIn, mRemoteOut, this, mContext);
+        }
+
+        @Override
+        public void start() {
+            mSecureChannel.start();
+        }
+
+        /** Stops the secure channel and signals the sender thread, if any, to exit. */
+        @Override
+        public void stop() {
+            mSecureChannel.stop();
+            mShouldProcessRequests = false;
+        }
+
+        /**
+         * Starts the handshake if the channel is not yet secured, then queues the request.
+         *
+         * <p>The request is queued even if the handshake could not be initiated. It is only
+         * sent after a later handshake succeeds.
+         * NOTE(review): every request issued before onSecureConnection() fires calls
+         * establishSecureConnection() again. Confirm SecureChannel tolerates repeated or
+         * concurrent handshake attempts.
+         */
+        @Override
+        public Future<byte[]> requestForResponse(int message, byte[] data) {
+            // Check if channel is secured and start securing
+            if (!mShouldProcessRequests) {
+                Slog.d(TAG, "Establishing secure connection.");
+                try {
+                    mSecureChannel.establishSecureConnection();
+                } catch (Exception e) {
+                    Slog.w(TAG, "Failed to initiate secure channel handshake.", e);
+                    onError(e);
+                }
+            }
+
+            return super.requestForResponse(message, data);
+        }
+
+        /**
+         * Frames the message and queues it for the sender thread.
+         *
+         * @throws IOException if the outgoing queue is full
+         */
+        @Override
+        protected void sendMessage(int message, int sequence, @NonNull byte[] data)
+                throws IOException {
+            if (DEBUG) {
+                Slog.d(TAG, "Queueing message 0x" + Integer.toHexString(message)
+                        + " sequence " + sequence + " length " + data.length
+                        + " to association " + mAssociationId);
+            }
+
+            final byte[] frame = ByteBuffer.allocate(HEADER_LENGTH + data.length)
+                    .putInt(message)
+                    .putInt(sequence)
+                    .putInt(data.length)
+                    .put(data)
+                    .array();
+
+            // Use offer() rather than add(). add() throws IllegalStateException when the queue
+            // is full; that escapes the IOException handling in requestForResponse() and
+            // leaves the pending request registered forever.
+            if (!mRequestQueue.offer(frame)) {
+                throw new IOException("Request queue full; dropping message 0x"
+                        + Integer.toHexString(message) + " sequence " + sequence
+                        + " to association " + mAssociationId);
+            }
+        }
+
+        /** Marks the channel usable and starts the thread that drains mRequestQueue. */
+        @Override
+        public void onSecureConnection() {
+            mShouldProcessRequests = true;
+            Slog.d(TAG, "Secure connection established.");
+
+            // TODO: find a better way to handle incoming requests than a dedicated thread.
+            new Thread(() -> {
+                try {
+                    while (mShouldProcessRequests) {
+                        // Block for a bounded time instead of spinning on a non-blocking
+                        // poll(), which kept a core busy while the channel was idle. The
+                        // timeout lets the loop notice stop()/onError() promptly.
+                        final byte[] request =
+                                mRequestQueue.poll(REQUEST_POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS);
+                        if (request != null) {
+                            mSecureChannel.sendSecureMessage(request);
+                        }
+                    }
+                } catch (IOException e) {
+                    onError(e);
+                } catch (InterruptedException e) {
+                    Thread.currentThread().interrupt();
+                    onError(e);
+                }
+            }).start();
+        }
+
+        /** Decodes one decrypted frame and dispatches it; malformed frames are dropped. */
+        @Override
+        public void onSecureMessageReceived(byte[] data) {
+            final ByteBuffer payload = ByteBuffer.wrap(data);
+
+            // Validate framing before reading. A short frame would throw
+            // BufferUnderflowException, and a bad length NegativeArraySizeException or
+            // BufferUnderflowException. Either would otherwise propagate unchecked into the
+            // SecureChannel code that invoked this callback.
+            if (payload.remaining() < HEADER_LENGTH) {
+                Slog.w(TAG, "Ignoring truncated secure message of length " + data.length
+                        + " from association " + mAssociationId);
+                return;
+            }
+            final int message = payload.getInt();
+            final int sequence = payload.getInt();
+            final int length = payload.getInt();
+            if (length < 0 || length > payload.remaining()) {
+                Slog.w(TAG, "Ignoring message 0x" + Integer.toHexString(message)
+                        + " sequence " + sequence + " with invalid length " + length
+                        + " from association " + mAssociationId);
+                return;
+            }
+            final byte[] content = new byte[length];
+            payload.get(content);
+
+            try {
+                handleMessage(message, sequence, content);
+            } catch (IOException error) {
+                onError(error);
+            }
+        }
+
+        /** Stops the sender thread and logs the failure. */
+        @Override
+        public void onError(Throwable error) {
+            mShouldProcessRequests = false;
+            Slog.e(TAG, error.getMessage(), error);
+        }
+    }
}