Clean up NN callback error handling
This CL introduces a new templated class CallbackValue to handle HIDL
"return value" callbacks in a terser and more readable way.
This CL also introduces a new macro HANDLE_HAL_STATUS that returns from the
current function when an error is present, with the ability to append a
more descriptive error message.
Finally, this CL changes the behavior of synchronous executions. Prior
to this CL, IPreparedModel fell back to an asynchronous execution if the
synchronous execution was allowed and failed. This change instead
returns a failure if synchronous execution is allowed and fails.
Bug: 173084343
Test: mma
Change-Id: I62714a932e71dfc77401bbcb9eaaaf3d94fb9707
diff --git a/neuralnetworks/1.2/utils/src/Device.cpp b/neuralnetworks/1.2/utils/src/Device.cpp
index 6cca841..9fe0de2 100644
--- a/neuralnetworks/1.2/utils/src/Device.cpp
+++ b/neuralnetworks/1.2/utils/src/Device.cpp
@@ -47,109 +47,102 @@
namespace android::hardware::neuralnetworks::V1_2::utils {
namespace {
-nn::GeneralResult<nn::Capabilities> initCapabilities(V1_2::IDevice* device) {
+nn::GeneralResult<nn::Capabilities> capabilitiesCallback(V1_0::ErrorStatus status,
+ const Capabilities& capabilities) {
+ HANDLE_HAL_STATUS(status) << "getting capabilities failed with " << toString(status);
+ return nn::convert(capabilities);
+}
+
+nn::GeneralResult<std::string> versionStringCallback(V1_0::ErrorStatus status,
+ const hidl_string& versionString) {
+ HANDLE_HAL_STATUS(status) << "getVersionString failed with " << toString(status);
+ return versionString;
+}
+
+nn::GeneralResult<nn::DeviceType> deviceTypeCallback(V1_0::ErrorStatus status,
+ DeviceType deviceType) {
+ HANDLE_HAL_STATUS(status) << "getDeviceType failed with " << toString(status);
+ return nn::convert(deviceType);
+}
+
+nn::GeneralResult<std::vector<nn::Extension>> supportedExtensionsCallback(
+ V1_0::ErrorStatus status, const hidl_vec<Extension>& extensions) {
+ HANDLE_HAL_STATUS(status) << "getExtensions failed with " << toString(status);
+ return nn::convert(extensions);
+}
+
+nn::GeneralResult<std::pair<uint32_t, uint32_t>> numberOfCacheFilesNeededCallback(
+ V1_0::ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
+ HANDLE_HAL_STATUS(status) << "getNumberOfCacheFilesNeeded failed with " << toString(status);
+ if (numModelCache > nn::kMaxNumberOfCacheFiles) {
+ return NN_ERROR() << "getNumberOfCacheFilesNeeded returned numModelCache files greater "
+ "than allowed max ("
+ << numModelCache << " vs " << nn::kMaxNumberOfCacheFiles << ")";
+ }
+ if (numDataCache > nn::kMaxNumberOfCacheFiles) {
+ return NN_ERROR() << "getNumberOfCacheFilesNeeded returned numDataCache files greater "
+ "than allowed max ("
+ << numDataCache << " vs " << nn::kMaxNumberOfCacheFiles << ")";
+ }
+ return std::make_pair(numModelCache, numDataCache);
+}
+
+nn::GeneralResult<nn::Capabilities> getCapabilitiesFrom(V1_2::IDevice* device) {
CHECK(device != nullptr);
- nn::GeneralResult<nn::Capabilities> result = NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
- << "uninitialized";
- const auto cb = [&result](V1_0::ErrorStatus status, const Capabilities& capabilities) {
- if (status != V1_0::ErrorStatus::NONE) {
- const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
- result = NN_ERROR(canonical) << "getCapabilities_1_2 failed with " << toString(status);
- } else {
- result = nn::convert(capabilities);
- }
- };
+ auto cb = hal::utils::CallbackValue(capabilitiesCallback);
const auto ret = device->getCapabilities_1_2(cb);
HANDLE_TRANSPORT_FAILURE(ret);
- return result;
+ return cb.take();
}
} // namespace
-nn::GeneralResult<std::string> initVersionString(V1_2::IDevice* device) {
+nn::GeneralResult<std::string> getVersionStringFrom(V1_2::IDevice* device) {
CHECK(device != nullptr);
- nn::GeneralResult<std::string> result = NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
- << "uninitialized";
- const auto cb = [&result](V1_0::ErrorStatus status, const hidl_string& versionString) {
- if (status != V1_0::ErrorStatus::NONE) {
- const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
- result = NN_ERROR(canonical) << "getVersionString failed with " << toString(status);
- } else {
- result = versionString;
- }
- };
+ auto cb = hal::utils::CallbackValue(versionStringCallback);
const auto ret = device->getVersionString(cb);
HANDLE_TRANSPORT_FAILURE(ret);
- return result;
+ return cb.take();
}
-nn::GeneralResult<nn::DeviceType> initDeviceType(V1_2::IDevice* device) {
+nn::GeneralResult<nn::DeviceType> getDeviceTypeFrom(V1_2::IDevice* device) {
CHECK(device != nullptr);
- nn::GeneralResult<nn::DeviceType> result = NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
- << "uninitialized";
- const auto cb = [&result](V1_0::ErrorStatus status, DeviceType deviceType) {
- if (status != V1_0::ErrorStatus::NONE) {
- const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
- result = NN_ERROR(canonical) << "getDeviceType failed with " << toString(status);
- } else {
- result = nn::convert(deviceType);
- }
- };
+ auto cb = hal::utils::CallbackValue(deviceTypeCallback);
const auto ret = device->getType(cb);
HANDLE_TRANSPORT_FAILURE(ret);
- return result;
+ return cb.take();
}
-nn::GeneralResult<std::vector<nn::Extension>> initExtensions(V1_2::IDevice* device) {
+nn::GeneralResult<std::vector<nn::Extension>> getSupportedExtensionsFrom(V1_2::IDevice* device) {
CHECK(device != nullptr);
- nn::GeneralResult<std::vector<nn::Extension>> result =
- NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "uninitialized";
- const auto cb = [&result](V1_0::ErrorStatus status, const hidl_vec<Extension>& extensions) {
- if (status != V1_0::ErrorStatus::NONE) {
- const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
- result = NN_ERROR(canonical) << "getExtensions failed with " << toString(status);
- } else {
- result = nn::convert(extensions);
- }
- };
+ auto cb = hal::utils::CallbackValue(supportedExtensionsCallback);
const auto ret = device->getSupportedExtensions(cb);
HANDLE_TRANSPORT_FAILURE(ret);
- return result;
+ return cb.take();
}
-nn::GeneralResult<std::pair<uint32_t, uint32_t>> initNumberOfCacheFilesNeeded(
+nn::GeneralResult<std::pair<uint32_t, uint32_t>> getNumberOfCacheFilesNeededFrom(
V1_2::IDevice* device) {
CHECK(device != nullptr);
- nn::GeneralResult<std::pair<uint32_t, uint32_t>> result =
- NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "uninitialized";
- const auto cb = [&result](V1_0::ErrorStatus status, uint32_t numModelCache,
- uint32_t numDataCache) {
- if (status != V1_0::ErrorStatus::NONE) {
- const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
- result = NN_ERROR(canonical)
- << "getNumberOfCacheFilesNeeded failed with " << toString(status);
- } else {
- result = std::make_pair(numModelCache, numDataCache);
- }
- };
+ auto cb = hal::utils::CallbackValue(numberOfCacheFilesNeededCallback);
const auto ret = device->getNumberOfCacheFilesNeeded(cb);
HANDLE_TRANSPORT_FAILURE(ret);
- return result;
+ return cb.take();
}
nn::GeneralResult<std::shared_ptr<const Device>> Device::create(std::string name,
@@ -163,11 +156,11 @@
<< "V1_2::utils::Device::create must have non-null device";
}
- auto versionString = NN_TRY(initVersionString(device.get()));
- const auto deviceType = NN_TRY(initDeviceType(device.get()));
- auto extensions = NN_TRY(initExtensions(device.get()));
- auto capabilities = NN_TRY(initCapabilities(device.get()));
- const auto numberOfCacheFilesNeeded = NN_TRY(initNumberOfCacheFilesNeeded(device.get()));
+ auto versionString = NN_TRY(getVersionStringFrom(device.get()));
+ const auto deviceType = NN_TRY(getDeviceTypeFrom(device.get()));
+ auto extensions = NN_TRY(getSupportedExtensionsFrom(device.get()));
+ auto capabilities = NN_TRY(getCapabilitiesFrom(device.get()));
+ const auto numberOfCacheFilesNeeded = NN_TRY(getNumberOfCacheFilesNeededFrom(device.get()));
auto deathHandler = NN_TRY(hal::utils::DeathHandler::create(device));
return std::make_shared<const Device>(
@@ -232,28 +225,12 @@
const auto hidlModel = NN_TRY(convert(modelInShared));
- nn::GeneralResult<std::vector<bool>> result = NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
- << "uninitialized";
- auto cb = [&result, &model](V1_0::ErrorStatus status,
- const hidl_vec<bool>& supportedOperations) {
- if (status != V1_0::ErrorStatus::NONE) {
- const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
- result = NN_ERROR(canonical)
- << "getSupportedOperations_1_2 failed with " << toString(status);
- } else if (supportedOperations.size() != model.main.operations.size()) {
- result = NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
- << "getSupportedOperations_1_2 returned vector of size "
- << supportedOperations.size() << " but expected "
- << model.main.operations.size();
- } else {
- result = supportedOperations;
- }
- };
+ auto cb = hal::utils::CallbackValue(V1_0::utils::supportedOperationsCallback);
const auto ret = kDevice->getSupportedOperations_1_2(hidlModel, cb);
HANDLE_TRANSPORT_FAILURE(ret);
- return result;
+ return cb.take();
}
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
@@ -266,10 +243,10 @@
NN_TRY(hal::utils::flushDataFromPointerToShared(&model, &maybeModelInShared));
const auto hidlModel = NN_TRY(convert(modelInShared));
- const auto hidlPreference = NN_TRY(V1_1::utils::convert(preference));
+ const auto hidlPreference = NN_TRY(convert(preference));
const auto hidlModelCache = NN_TRY(convert(modelCache));
const auto hidlDataCache = NN_TRY(convert(dataCache));
- const auto hidlToken = token;
+ const auto hidlToken = CacheToken{token};
const auto cb = sp<PreparedModelCallback>::make();
const auto scoped = kDeathHandler.protectCallback(cb.get());
@@ -277,10 +254,7 @@
const auto ret = kDevice->prepareModel_1_2(hidlModel, hidlPreference, hidlModelCache,
hidlDataCache, hidlToken, cb);
const auto status = HANDLE_TRANSPORT_FAILURE(ret);
- if (status != V1_0::ErrorStatus::NONE) {
- const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
- return NN_ERROR(canonical) << "prepareModel_1_2 failed with " << toString(status);
- }
+ HANDLE_HAL_STATUS(status) << "model preparation failed with " << toString(status);
return cb->get();
}
@@ -290,17 +264,14 @@
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
const auto hidlModelCache = NN_TRY(convert(modelCache));
const auto hidlDataCache = NN_TRY(convert(dataCache));
- const auto hidlToken = token;
+ const auto hidlToken = CacheToken{token};
const auto cb = sp<PreparedModelCallback>::make();
const auto scoped = kDeathHandler.protectCallback(cb.get());
const auto ret = kDevice->prepareModelFromCache(hidlModelCache, hidlDataCache, hidlToken, cb);
const auto status = HANDLE_TRANSPORT_FAILURE(ret);
- if (status != V1_0::ErrorStatus::NONE) {
- const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
- return NN_ERROR(canonical) << "prepareModelFromCache failed with " << toString(status);
- }
+ HANDLE_HAL_STATUS(status) << "model preparation from cache failed with " << toString(status);
return cb->get();
}