Merge "Keystore 2.0: km_compat: Fix operation slot accounting."
diff --git a/keystore2/src/km_compat/km_compat.cpp b/keystore2/src/km_compat/km_compat.cpp
index a6ca179..3ade2cf 100644
--- a/keystore2/src/km_compat/km_compat.cpp
+++ b/keystore2/src/km_compat/km_compat.cpp
@@ -384,29 +384,39 @@
     return ssps;
 }
 
-void OperationSlots::setNumFreeSlots(uint8_t numFreeSlots) {
+void OperationSlotManager::setNumFreeSlots(uint8_t numFreeSlots) {
     std::lock_guard<std::mutex> lock(mNumFreeSlotsMutex);
     mNumFreeSlots = numFreeSlots;
 }
 
-bool OperationSlots::claimSlot() {
-    std::lock_guard<std::mutex> lock(mNumFreeSlotsMutex);
-    if (mNumFreeSlots > 0) {
-        mNumFreeSlots--;
-        return true;
+std::optional<OperationSlot>
+OperationSlotManager::claimSlot(std::shared_ptr<OperationSlotManager> operationSlots) {
+    std::lock_guard<std::mutex> lock(operationSlots->mNumFreeSlotsMutex);
+    if (operationSlots->mNumFreeSlots > 0) {
+        operationSlots->mNumFreeSlots--;
+        return OperationSlot(std::move(operationSlots), std::nullopt);
     }
-    return false;
+    return std::nullopt;
 }
 
-void OperationSlots::freeSlot() {
+OperationSlot
+OperationSlotManager::claimReservedSlot(std::shared_ptr<OperationSlotManager> operationSlots) {
+    std::unique_lock<std::mutex> reservedGuard(operationSlots->mReservedSlotMutex);
+    return OperationSlot(std::move(operationSlots), std::move(reservedGuard));
+}
+
+OperationSlot::OperationSlot(std::shared_ptr<OperationSlotManager> slots,
+                             std::optional<std::unique_lock<std::mutex>> reservedGuard)
+    : mOperationSlots(std::move(slots)), mReservedGuard(std::move(reservedGuard)) {}
+
+void OperationSlotManager::freeSlot() {
     std::lock_guard<std::mutex> lock(mNumFreeSlotsMutex);
     mNumFreeSlots++;
 }
 
-void OperationSlot::freeSlot() {
-    if (mIsActive) {
+OperationSlot::~OperationSlot() {
+    if (!mReservedGuard && mOperationSlots) {
         mOperationSlots->freeSlot();
-        mIsActive = false;
     }
 }
 
@@ -613,9 +623,15 @@
                                    const std::vector<KeyParameter>& in_inParams,
                                    const std::optional<HardwareAuthToken>& in_inAuthToken,
                                    BeginResult* _aidl_return) {
-    if (!mOperationSlots.claimSlot()) {
-        return convertErrorCode(V4_0_ErrorCode::TOO_MANY_OPERATIONS);
-    }
+    return beginInternal(in_inPurpose, prefixedKeyBlob, in_inParams, in_inAuthToken,
+                         false /* useReservedSlot */, _aidl_return);
+}
+
+ScopedAStatus KeyMintDevice::beginInternal(KeyPurpose in_inPurpose,
+                                           const std::vector<uint8_t>& prefixedKeyBlob,
+                                           const std::vector<KeyParameter>& in_inParams,
+                                           const std::optional<HardwareAuthToken>& in_inAuthToken,
+                                           bool useReservedSlot, BeginResult* _aidl_return) {
 
     const std::vector<uint8_t>& in_inKeyBlob = prefixedKeyBlobRemovePrefix(prefixedKeyBlob);
     if (prefixedKeyBlobIsSoftKeyMint(prefixedKeyBlob)) {
@@ -623,28 +639,41 @@
                                          _aidl_return);
     }
 
+    OperationSlot slot;
+    // Software-backed key blobs were already handled above, so no slot is claimed for them.
+    if (useReservedSlot) {
+        // There is only one reserved slot. This function blocks until
+        // the reserved slot becomes available.
+        slot = OperationSlotManager::claimReservedSlot(mOperationSlots);
+    } else {
+        if (auto opt_slot = OperationSlotManager::claimSlot(mOperationSlots)) {
+            slot = std::move(*opt_slot);
+        } else {
+            return convertErrorCode(V4_0_ErrorCode::TOO_MANY_OPERATIONS);
+        }
+    }
+
     auto legacyPurpose =
         static_cast<::android::hardware::keymaster::V4_0::KeyPurpose>(in_inPurpose);
     auto legacyParams = convertKeyParametersToLegacy(in_inParams);
     auto legacyAuthToken = convertAuthTokenToLegacy(in_inAuthToken);
     KMV1::ErrorCode errorCode;
-    auto result = mDevice->begin(
-        legacyPurpose, in_inKeyBlob, legacyParams, legacyAuthToken,
-        [&](V4_0_ErrorCode error, const hidl_vec<V4_0_KeyParameter>& outParams,
-            uint64_t operationHandle) {
-            errorCode = convert(error);
-            _aidl_return->challenge = operationHandle;
-            _aidl_return->params = convertKeyParametersFromLegacy(outParams);
-            _aidl_return->operation = ndk::SharedRefBase::make<KeyMintOperation>(
-                mDevice, operationHandle, &mOperationSlots, error == V4_0_ErrorCode::OK);
-        });
+    auto result =
+        mDevice->begin(legacyPurpose, in_inKeyBlob, legacyParams, legacyAuthToken,
+                       [&](V4_0_ErrorCode error, const hidl_vec<V4_0_KeyParameter>& outParams,
+                           uint64_t operationHandle) {
+                           errorCode = convert(error);
+                           if (error == V4_0_ErrorCode::OK) {
+                               _aidl_return->challenge = operationHandle;
+                               _aidl_return->params = convertKeyParametersFromLegacy(outParams);
+                               _aidl_return->operation = ndk::SharedRefBase::make<KeyMintOperation>(
+                                   mDevice, operationHandle, std::move(slot));
+                           }
+                       });
     if (!result.isOk()) {
         LOG(ERROR) << __func__ << " transaction failed. " << result.description();
         errorCode = KMV1::ErrorCode::UNKNOWN_ERROR;
     }
-    if (errorCode != KMV1::ErrorCode::OK) {
-        mOperationSlots.freeSlot();
-    }
     return convertErrorCode(errorCode);
 }
 
@@ -704,8 +733,9 @@
         LOG(ERROR) << __func__ << " export_key failed: " << ret.description();
         return convertErrorCode(KMV1::ErrorCode::UNKNOWN_ERROR);
     }
-    if (km_error != KMV1::ErrorCode::OK)
+    if (km_error != KMV1::ErrorCode::OK) {
         LOG(ERROR) << __func__ << " export_key failed, code " << int32_t(km_error);
+    }
 
     return convertErrorCode(km_error);
 }
@@ -757,7 +787,11 @@
         LOG(ERROR) << __func__ << " transaction failed. " << result.description();
         errorCode = KMV1::ErrorCode::UNKNOWN_ERROR;
     }
-    if (errorCode != KMV1::ErrorCode::OK) mOperationSlot.freeSlot();
+
+    // Operation slot is no longer occupied.
+    if (errorCode != KMV1::ErrorCode::OK) {
+        mOperationSlot = std::nullopt;
+    }
 
     return convertErrorCode(errorCode);
 }
@@ -815,7 +849,10 @@
         inputPos += consumed;
     }
 
-    if (errorCode != KMV1::ErrorCode::OK) mOperationSlot.freeSlot();
+    // Operation slot is no longer occupied.
+    if (errorCode != KMV1::ErrorCode::OK) {
+        mOperationSlot = std::nullopt;
+    }
 
     return convertErrorCode(errorCode);
 }
@@ -846,17 +883,19 @@
             *out_output = output;
         });
 
-    mOperationSlot.freeSlot();
     if (!result.isOk()) {
         LOG(ERROR) << __func__ << " transaction failed. " << result.description();
         errorCode = KMV1::ErrorCode::UNKNOWN_ERROR;
     }
+
+    mOperationSlot = std::nullopt;
+
     return convertErrorCode(errorCode);
 }
 
 ScopedAStatus KeyMintOperation::abort() {
     auto result = mDevice->abort(mOperationHandle);
-    mOperationSlot.freeSlot();
+    mOperationSlot = std::nullopt;
     if (!result.isOk()) {
         LOG(ERROR) << __func__ << " transaction failed. " << result.description();
         return convertErrorCode(KMV1::ErrorCode::UNKNOWN_ERROR);
@@ -865,7 +904,7 @@
 }
 
 KeyMintOperation::~KeyMintOperation() {
-    if (mOperationSlot.hasSlot()) {
+    if (mOperationSlot) {
         auto error = abort();
         if (!error.isOk()) {
             LOG(WARNING) << "Error calling abort in ~KeyMintOperation: " << error.getMessage();
@@ -1118,8 +1157,8 @@
                 kps.push_back(KMV1::makeKeyParameter(KMV1::TAG_PADDING, origPadding));
             }
             BeginResult beginResult;
-            auto error =
-                begin(KeyPurpose::SIGN, prefixedKeyBlob, kps, HardwareAuthToken(), &beginResult);
+            auto error = beginInternal(KeyPurpose::SIGN, prefixedKeyBlob, kps, HardwareAuthToken(),
+                                       true /* useReservedSlot */, &beginResult);
             if (!error.isOk()) {
                 errorCode = toErrorCode(error);
                 return std::vector<uint8_t>();
@@ -1355,13 +1394,14 @@
 }
 
 void KeyMintDevice::setNumFreeSlots(uint8_t numFreeSlots) {
-    mOperationSlots.setNumFreeSlots(numFreeSlots);
+    mOperationSlots->setNumFreeSlots(numFreeSlots);
 }
 
 // Constructors and helpers.
 
 KeyMintDevice::KeyMintDevice(sp<Keymaster> device, KeyMintSecurityLevel securityLevel)
-    : mDevice(device), securityLevel_(securityLevel) {
+    : mDevice(device), mOperationSlots(std::make_shared<OperationSlotManager>()),
+      securityLevel_(securityLevel) {
     if (securityLevel == KeyMintSecurityLevel::STRONGBOX) {
         setNumFreeSlots(3);
     } else {
diff --git a/keystore2/src/km_compat/km_compat.h b/keystore2/src/km_compat/km_compat.h
index c07470d..f6f5eb4 100644
--- a/keystore2/src/km_compat/km_compat.h
+++ b/keystore2/src/km_compat/km_compat.h
@@ -50,37 +50,49 @@
 using ::android::hardware::keymaster::V4_1::support::Keymaster;
 using ::ndk::ScopedAStatus;
 
-class OperationSlots {
-  private:
-    uint8_t mNumFreeSlots;
-    std::mutex mNumFreeSlotsMutex;
-
-  public:
-    void setNumFreeSlots(uint8_t numFreeSlots);
-    bool claimSlot();
-    void freeSlot();
-};
-
+class OperationSlot;
+class OperationSlotManager;
 // An abstraction for a single operation slot.
 // This contains logic to ensure that we do not free the slot multiple times,
 // e.g., if we call abort twice on the same operation.
 class OperationSlot {
+    friend OperationSlotManager;
+
   private:
-    OperationSlots* mOperationSlots;
-    bool mIsActive;
+    std::shared_ptr<OperationSlotManager> mOperationSlots;
+    std::optional<std::unique_lock<std::mutex>> mReservedGuard;
+
+  protected:
+    OperationSlot(std::shared_ptr<OperationSlotManager>,
+                  std::optional<std::unique_lock<std::mutex>> reservedGuard);
+    OperationSlot(const OperationSlot&) = delete;
+    OperationSlot& operator=(const OperationSlot&) = delete;
 
   public:
-    OperationSlot(OperationSlots* slots, bool isActive)
-        : mOperationSlots(slots), mIsActive(isActive) {}
+    OperationSlot() : mOperationSlots(nullptr), mReservedGuard(std::nullopt) {}
+    OperationSlot(OperationSlot&&) = default;
+    OperationSlot& operator=(OperationSlot&&) = default;
+    ~OperationSlot();
+};
 
+class OperationSlotManager {
+  private:
+    uint8_t mNumFreeSlots;
+    std::mutex mNumFreeSlotsMutex;
+    std::mutex mReservedSlotMutex;
+
+  public:
+    void setNumFreeSlots(uint8_t numFreeSlots);
+    static std::optional<OperationSlot>
+    claimSlot(std::shared_ptr<OperationSlotManager> operationSlots);
+    static OperationSlot claimReservedSlot(std::shared_ptr<OperationSlotManager> operationSlots);
     void freeSlot();
-    bool hasSlot() { return mIsActive; }
 };
 
 class KeyMintDevice : public aidl::android::hardware::security::keymint::BnKeyMintDevice {
   private:
     ::android::sp<Keymaster> mDevice;
-    OperationSlots mOperationSlots;
+    std::shared_ptr<OperationSlotManager> mOperationSlots;
 
   public:
     explicit KeyMintDevice(::android::sp<Keymaster>, KeyMintSecurityLevel);
@@ -109,10 +121,15 @@
     ScopedAStatus deleteKey(const std::vector<uint8_t>& in_inKeyBlob) override;
     ScopedAStatus deleteAllKeys() override;
     ScopedAStatus destroyAttestationIds() override;
+
     ScopedAStatus begin(KeyPurpose in_inPurpose, const std::vector<uint8_t>& in_inKeyBlob,
                         const std::vector<KeyParameter>& in_inParams,
                         const std::optional<HardwareAuthToken>& in_inAuthToken,
                         BeginResult* _aidl_return) override;
+    ScopedAStatus beginInternal(KeyPurpose in_inPurpose, const std::vector<uint8_t>& in_inKeyBlob,
+                                const std::vector<KeyParameter>& in_inParams,
+                                const std::optional<HardwareAuthToken>& in_inAuthToken,
+                                bool useReservedSlot, BeginResult* _aidl_return);
     ScopedAStatus deviceLocked(bool passwordOnly,
                                const std::optional<TimeStampToken>& timestampToken) override;
     ScopedAStatus earlyBootEnded() override;
@@ -143,9 +160,8 @@
 
 class KeyMintOperation : public aidl::android::hardware::security::keymint::BnKeyMintOperation {
   public:
-    KeyMintOperation(::android::sp<Keymaster> device, uint64_t operationHandle,
-                     OperationSlots* slots, bool isActive)
-        : mDevice(device), mOperationHandle(operationHandle), mOperationSlot(slots, isActive) {}
+    KeyMintOperation(::android::sp<Keymaster> device, uint64_t operationHandle, OperationSlot slot)
+        : mDevice(device), mOperationHandle(operationHandle), mOperationSlot(std::move(slot)) {}
     ~KeyMintOperation();
 
     ScopedAStatus updateAad(const std::vector<uint8_t>& input,
@@ -183,7 +199,7 @@
     std::vector<uint8_t> mUpdateBuffer;
     ::android::sp<Keymaster> mDevice;
     uint64_t mOperationHandle;
-    OperationSlot mOperationSlot;
+    std::optional<OperationSlot> mOperationSlot;
 };
 
 class SharedSecret : public aidl::android::hardware::security::sharedsecret::BnSharedSecret {
diff --git a/keystore2/src/km_compat/slot_test.cpp b/keystore2/src/km_compat/slot_test.cpp
index 3539c4d..d734970 100644
--- a/keystore2/src/km_compat/slot_test.cpp
+++ b/keystore2/src/km_compat/slot_test.cpp
@@ -26,6 +26,7 @@
 using ::aidl::android::hardware::security::keymint::BlockMode;
 using ::aidl::android::hardware::security::keymint::Certificate;
 using ::aidl::android::hardware::security::keymint::Digest;
+using ::aidl::android::hardware::security::keymint::EcCurve;
 using ::aidl::android::hardware::security::keymint::ErrorCode;
 using ::aidl::android::hardware::security::keymint::IKeyMintOperation;
 using ::aidl::android::hardware::security::keymint::KeyCharacteristics;
@@ -53,6 +54,25 @@
     return creationResult.keyBlob;
 }
 
+static bool generateECSingingKey(std::shared_ptr<KeyMintDevice> device) {
+    uint64_t now_ms = (uint64_t)time(nullptr) * 1000;
+
+    auto keyParams = std::vector<KeyParameter>({
+        KMV1::makeKeyParameter(KMV1::TAG_ALGORITHM, Algorithm::EC),
+        KMV1::makeKeyParameter(KMV1::TAG_EC_CURVE, EcCurve::P_256),
+        KMV1::makeKeyParameter(KMV1::TAG_NO_AUTH_REQUIRED, true),
+        KMV1::makeKeyParameter(KMV1::TAG_DIGEST, Digest::SHA_2_256),
+        KMV1::makeKeyParameter(KMV1::TAG_PURPOSE, KeyPurpose::SIGN),
+        KMV1::makeKeyParameter(KMV1::TAG_PURPOSE, KeyPurpose::VERIFY),
+        KMV1::makeKeyParameter(KMV1::TAG_CERTIFICATE_NOT_BEFORE, now_ms - 60 * 60 * 1000),
+        KMV1::makeKeyParameter(KMV1::TAG_CERTIFICATE_NOT_AFTER, now_ms + 60 * 60 * 1000),
+    });
+    KeyCreationResult creationResult;
+    auto status = device->generateKey(keyParams, std::nullopt /* attest_key */, &creationResult);
+    EXPECT_TRUE(status.isOk()) << status.getDescription();
+    return status.isOk();
+}
+
 static std::variant<BeginResult, ScopedAStatus> begin(std::shared_ptr<KeyMintDevice> device,
                                                       bool valid) {
     auto blob = generateAESKey(device);
@@ -69,6 +89,36 @@
     return beginResult;
 }
 
+// Generates a P-256 EC key for AGREE_KEY and begins an AGREE_KEY operation with it.
+// Returns the begun operation, or an empty pointer if generateKey() or begin() failed.
+static std::shared_ptr<KMV1::IKeyMintOperation>
+generateAndBeginECDHKeyOperation(std::shared_ptr<KeyMintDevice> device) {
+    uint64_t now_ms = (uint64_t)time(nullptr) * 1000;
+
+    auto keyParams = std::vector<KeyParameter>({
+        KMV1::makeKeyParameter(KMV1::TAG_ALGORITHM, Algorithm::EC),
+        KMV1::makeKeyParameter(KMV1::TAG_EC_CURVE, EcCurve::P_256),
+        KMV1::makeKeyParameter(KMV1::TAG_NO_AUTH_REQUIRED, true),
+        KMV1::makeKeyParameter(KMV1::TAG_DIGEST, Digest::NONE),
+        KMV1::makeKeyParameter(KMV1::TAG_PURPOSE, KeyPurpose::AGREE_KEY),
+        KMV1::makeKeyParameter(KMV1::TAG_CERTIFICATE_NOT_BEFORE, now_ms - 60 * 60 * 1000),
+        KMV1::makeKeyParameter(KMV1::TAG_CERTIFICATE_NOT_AFTER, now_ms + 60 * 60 * 1000),
+    });
+    KeyCreationResult creationResult;
+    auto status = device->generateKey(keyParams, std::nullopt /* attest_key */, &creationResult);
+    EXPECT_TRUE(status.isOk()) << status.getDescription();
+    if (!status.isOk()) {
+        return {};
+    }
+    std::vector<KeyParameter> kps;
+    BeginResult beginResult;
+    auto bstatus = device->begin(KeyPurpose::AGREE_KEY, creationResult.keyBlob, kps,
+                                 HardwareAuthToken(), &beginResult);
+    // Check the status of begin() itself, not the already verified generateKey() status.
+    EXPECT_TRUE(bstatus.isOk()) << bstatus.getDescription();
+    return bstatus.isOk() ? beginResult.operation : nullptr;
+}
+
 static const int NUM_SLOTS = 2;
 
 TEST(SlotTest, TestSlots) {
@@ -82,6 +132,14 @@
     auto result = begin(device, false);
     ASSERT_TRUE(std::holds_alternative<ScopedAStatus>(result));
 
+    // Software emulated operations must not leak virtual slots.
+    ASSERT_TRUE(!!generateAndBeginECDHKeyOperation(device));
+
+    // Software emulated operations must not impact virtual slots accounting.
+    // As opposed to the previous call, the software operation is kept alive.
+    auto software_op = generateAndBeginECDHKeyOperation(device);
+    ASSERT_TRUE(!!software_op);
+
     // Fill up all the slots.
     std::vector<std::shared_ptr<IKeyMintOperation>> operations;
     for (int i = 0; i < NUM_SLOTS; i++) {
@@ -96,6 +154,14 @@
     ASSERT_EQ(std::get<ScopedAStatus>(result).getServiceSpecificError(),
               static_cast<int32_t>(ErrorCode::TOO_MANY_OPERATIONS));
 
+    // At this point all slots are in use. We should still be able to generate keys which
+    // require an operation slot during generation.
+    ASSERT_TRUE(generateECSingingKey(device));
+
+    // Software emulated operations should work despite having all virtual operation slots
+    // depleted.
+    ASSERT_TRUE(generateAndBeginECDHKeyOperation(device));
+
     // TODO: I'm not sure how to generate a failing update call to test that.
 
     // Calling finish should free up a slot.