blob: d8a7534461c0d1cf80e25e94260e8ebdf9e15cbd [file] [log] [blame]
Lev Proleev13fdfcd2019-08-30 11:35:34 +01001/*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19#include <android-base/logging.h>
20#include <fcntl.h>
21#include <ftw.h>
22#include <gtest/gtest.h>
23#include <hidlmemory/mapping.h>
24#include <unistd.h>
25
26#include <cstdio>
27#include <cstdlib>
28#include <random>
29#include <thread>
30
Xusong Wangcc47dff2019-10-23 10:35:07 -070031#include "1.3/Callbacks.h"
Lev Proleev13fdfcd2019-08-30 11:35:34 +010032#include "GeneratedTestHarness.h"
33#include "MemoryUtils.h"
34#include "TestHarness.h"
35#include "Utils.h"
36#include "VtsHalNeuralnetworks.h"
37
38// Forward declaration of the mobilenet generated test models in
39// frameworks/ml/nn/runtime/test/generated/.
40namespace generated_tests::mobilenet_224_gender_basic_fixed {
41const test_helper::TestModel& get_test_model();
42} // namespace generated_tests::mobilenet_224_gender_basic_fixed
43
44namespace generated_tests::mobilenet_quantized {
45const test_helper::TestModel& get_test_model();
46} // namespace generated_tests::mobilenet_quantized
47
Lev Proleev26d1bc82019-08-30 11:57:18 +010048namespace android::hardware::neuralnetworks::V1_3::vts::functional {
Lev Proleev13fdfcd2019-08-30 11:35:34 +010049
50using namespace test_helper;
Xusong Wangcc47dff2019-10-23 10:35:07 -070051using implementation::PreparedModelCallback;
Lev Proleev13fdfcd2019-08-30 11:35:34 +010052using V1_0::ErrorStatus;
53using V1_1::ExecutionPreference;
Lev Proleev26d1bc82019-08-30 11:57:18 +010054using V1_2::Constant;
Lev Proleev26d1bc82019-08-30 11:57:18 +010055using V1_2::OperationType;
Lev Proleev13fdfcd2019-08-30 11:35:34 +010056
namespace float32_model {

// Alias to the float32 mobilenet model generator; selected by the fixture when
// kOperandType == OperandType::TENSOR_FLOAT32.
constexpr auto get_test_model = generated_tests::mobilenet_224_gender_basic_fixed::get_test_model;

}  // namespace float32_model
62
namespace quant8_model {

// Alias to the quantized mobilenet model generator; selected by the fixture for
// the quant8 pass of each test.
constexpr auto get_test_model = generated_tests::mobilenet_quantized::get_test_model;

}  // namespace quant8_model
68
69namespace {
70
// open(2) access mode used for each cache file group passed to createCacheHandles.
enum class AccessMode { READ_WRITE, READ_ONLY, WRITE_ONLY };
72
73// Creates cache handles based on provided file groups.
74// The outer vector corresponds to handles and the inner vector is for fds held by each handle.
75void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups,
76 const std::vector<AccessMode>& mode, hidl_vec<hidl_handle>* handles) {
77 handles->resize(fileGroups.size());
78 for (uint32_t i = 0; i < fileGroups.size(); i++) {
79 std::vector<int> fds;
80 for (const auto& file : fileGroups[i]) {
81 int fd;
82 if (mode[i] == AccessMode::READ_ONLY) {
83 fd = open(file.c_str(), O_RDONLY);
84 } else if (mode[i] == AccessMode::WRITE_ONLY) {
85 fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
86 } else if (mode[i] == AccessMode::READ_WRITE) {
87 fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
88 } else {
89 FAIL();
90 }
91 ASSERT_GE(fd, 0);
92 fds.push_back(fd);
93 }
94 native_handle_t* cacheNativeHandle = native_handle_create(fds.size(), 0);
95 ASSERT_NE(cacheNativeHandle, nullptr);
96 std::copy(fds.begin(), fds.end(), &cacheNativeHandle->data[0]);
97 (*handles)[i].setTo(cacheNativeHandle, /*shouldOwn=*/true);
98 }
99}
100
101void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups, AccessMode mode,
102 hidl_vec<hidl_handle>* handles) {
103 createCacheHandles(fileGroups, std::vector<AccessMode>(fileGroups.size(), mode), handles);
104}
105
// Create a chain of broadcast operations. The second operand is always constant tensor [1].
// For simplicity, activation scalar is shared. The second operand is not shared
// in the model to let driver maintain a non-trivial size of constant data and the corresponding
// data locations in cache.
//
//              --------- activation ---------
//              ↓      ↓      ↓            ↓
// E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
//              ↑      ↑      ↑            ↑
//             [1]    [1]    [1]          [1]
//
// This function assumes the operation is either ADD or MUL.
//
// CppType/operandType select the element type of the tensors (e.g. float with TENSOR_FLOAT32, or
// uint8_t with TENSOR_QUANT8_ASYMM); op is the chained operation and len the number of operations.
template <typename CppType, TestOperandType operandType>
TestModel createLargeTestModelImpl(TestOperationType op, uint32_t len) {
    EXPECT_TRUE(op == TestOperationType::ADD || op == TestOperationType::MUL);

    // Model operations and operands.
    // Operand layout: [0] = shared activation scalar, then for operation i the pair
    // (first input, constant second input) at indexes 2i+1 and 2i+2; the last slot is the output.
    std::vector<TestOperation> operations(len);
    std::vector<TestOperand> operands(len * 2 + 2);

    // The activation scalar, value = 0.
    operands[0] = {
            .type = TestOperandType::INT32,
            .dimensions = {},
            .numberOfConsumers = len,
            .scale = 0.0f,
            .zeroPoint = 0,
            .lifetime = TestOperandLifeTime::CONSTANT_COPY,
            .data = TestBuffer::createFromVector<int32_t>({0}),
    };

    // The buffer value of the constant second operand. The logical value is always 1.0f.
    CppType bufferValue;
    // The scale of the first and second operand.
    float scale1, scale2;
    if (operandType == TestOperandType::TENSOR_FLOAT32) {
        bufferValue = 1.0f;
        scale1 = 0.0f;
        scale2 = 0.0f;
    } else if (op == TestOperationType::ADD) {
        bufferValue = 1;
        scale1 = 1.0f;
        scale2 = 1.0f;
    } else {
        // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
        // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
        bufferValue = 2;
        scale1 = 1.0f;
        scale2 = 0.5f;
    }

    for (uint32_t i = 0; i < len; i++) {
        const uint32_t firstInputIndex = i * 2 + 1;
        const uint32_t secondInputIndex = firstInputIndex + 1;
        const uint32_t outputIndex = secondInputIndex + 1;

        // The first operation input.
        operands[firstInputIndex] = {
                .type = operandType,
                .dimensions = {1},
                .numberOfConsumers = 1,
                .scale = scale1,
                .zeroPoint = 0,
                .lifetime = (i == 0 ? TestOperandLifeTime::MODEL_INPUT
                                    : TestOperandLifeTime::TEMPORARY_VARIABLE),
                .data = (i == 0 ? TestBuffer::createFromVector<CppType>({1}) : TestBuffer()),
        };

        // The second operation input, value = 1.
        operands[secondInputIndex] = {
                .type = operandType,
                .dimensions = {1},
                .numberOfConsumers = 1,
                .scale = scale2,
                .zeroPoint = 0,
                .lifetime = TestOperandLifeTime::CONSTANT_COPY,
                .data = TestBuffer::createFromVector<CppType>({bufferValue}),
        };

        // The operation. All operations share the same activation scalar.
        // The output operand is created as an input in the next iteration of the loop, in the case
        // of all but the last member of the chain; and after the loop as a model output, in the
        // case of the last member of the chain.
        operations[i] = {
                .type = op,
                .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
                .outputs = {outputIndex},
        };
    }

    // For TestOperationType::ADD, output = 1 + 1 * len = len + 1
    // For TestOperationType::MUL, output = 1 * 1 ^ len = 1
    CppType outputResult = static_cast<CppType>(op == TestOperationType::ADD ? len + 1u : 1u);

    // The model output.
    operands.back() = {
            .type = operandType,
            .dimensions = {1},
            .numberOfConsumers = 0,
            .scale = scale1,
            .zeroPoint = 0,
            .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
            .data = TestBuffer::createFromVector<CppType>({outputResult}),
    };

    return {
            .operands = std::move(operands),
            .operations = std::move(operations),
            .inputIndexes = {1},
            .outputIndexes = {len * 2 + 1},
            .isRelaxed = false,
    };
}
219
220} // namespace
221
222// Tag for the compilation caching tests.
223class CompilationCachingTestBase : public testing::Test {
224 protected:
225 CompilationCachingTestBase(sp<IDevice> device, OperandType type)
226 : kDevice(std::move(device)), kOperandType(type) {}
227
228 void SetUp() override {
229 testing::Test::SetUp();
230 ASSERT_NE(kDevice.get(), nullptr);
231
232 // Create cache directory. The cache directory and a temporary cache file is always created
Xusong Wangcc47dff2019-10-23 10:35:07 -0700233 // to test the behavior of prepareModelFromCache_1_3, even when caching is not supported.
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100234 char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
235 char* cacheDir = mkdtemp(cacheDirTemp);
236 ASSERT_NE(cacheDir, nullptr);
237 mCacheDir = cacheDir;
238 mCacheDir.push_back('/');
239
240 Return<void> ret = kDevice->getNumberOfCacheFilesNeeded(
241 [this](ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
242 EXPECT_EQ(ErrorStatus::NONE, status);
243 mNumModelCache = numModelCache;
244 mNumDataCache = numDataCache;
245 });
246 EXPECT_TRUE(ret.isOk());
247 mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;
248
249 // Create empty cache files.
250 mTmpCache = mCacheDir + "tmp";
251 for (uint32_t i = 0; i < mNumModelCache; i++) {
252 mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
253 }
254 for (uint32_t i = 0; i < mNumDataCache; i++) {
255 mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
256 }
257 // Dummy handles, use AccessMode::WRITE_ONLY for createCacheHandles to create files.
258 hidl_vec<hidl_handle> modelHandle, dataHandle, tmpHandle;
259 createCacheHandles(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
260 createCacheHandles(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
261 createCacheHandles({{mTmpCache}}, AccessMode::WRITE_ONLY, &tmpHandle);
262
263 if (!mIsCachingSupported) {
264 LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
265 "support compilation caching.";
266 std::cout << "[ ] Early termination of test because vendor service does not "
267 "support compilation caching."
268 << std::endl;
269 }
270 }
271
272 void TearDown() override {
273 // If the test passes, remove the tmp directory. Otherwise, keep it for debugging purposes.
274 if (!testing::Test::HasFailure()) {
275 // Recursively remove the cache directory specified by mCacheDir.
276 auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
277 return remove(entry);
278 };
279 nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
280 }
281 testing::Test::TearDown();
282 }
283
284 // Model and examples creators. According to kOperandType, the following methods will return
285 // either float32 model/examples or the quant8 variant.
286 TestModel createTestModel() {
287 if (kOperandType == OperandType::TENSOR_FLOAT32) {
288 return float32_model::get_test_model();
289 } else {
290 return quant8_model::get_test_model();
291 }
292 }
293
294 TestModel createLargeTestModel(OperationType op, uint32_t len) {
295 if (kOperandType == OperandType::TENSOR_FLOAT32) {
296 return createLargeTestModelImpl<float, TestOperandType::TENSOR_FLOAT32>(
297 static_cast<TestOperationType>(op), len);
298 } else {
299 return createLargeTestModelImpl<uint8_t, TestOperandType::TENSOR_QUANT8_ASYMM>(
300 static_cast<TestOperationType>(op), len);
301 }
302 }
303
304 // See if the service can handle the model.
305 bool isModelFullySupported(const Model& model) {
306 bool fullySupportsModel = false;
Lev Proleev26d1bc82019-08-30 11:57:18 +0100307 Return<void> supportedCall = kDevice->getSupportedOperations_1_3(
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100308 model,
309 [&fullySupportsModel, &model](ErrorStatus status, const hidl_vec<bool>& supported) {
310 ASSERT_EQ(ErrorStatus::NONE, status);
311 ASSERT_EQ(supported.size(), model.operations.size());
312 fullySupportsModel = std::all_of(supported.begin(), supported.end(),
313 [](bool valid) { return valid; });
314 });
315 EXPECT_TRUE(supportedCall.isOk());
316 return fullySupportsModel;
317 }
318
319 void saveModelToCache(const Model& model, const hidl_vec<hidl_handle>& modelCache,
320 const hidl_vec<hidl_handle>& dataCache,
321 sp<IPreparedModel>* preparedModel = nullptr) {
322 if (preparedModel != nullptr) *preparedModel = nullptr;
323
324 // Launch prepare model.
325 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
326 hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
327 Return<ErrorStatus> prepareLaunchStatus =
Lev Proleev26d1bc82019-08-30 11:57:18 +0100328 kDevice->prepareModel_1_3(model, ExecutionPreference::FAST_SINGLE_ANSWER,
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100329 modelCache, dataCache, cacheToken, preparedModelCallback);
330 ASSERT_TRUE(prepareLaunchStatus.isOk());
331 ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus), ErrorStatus::NONE);
332
333 // Retrieve prepared model.
334 preparedModelCallback->wait();
335 ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
336 if (preparedModel != nullptr) {
337 *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
338 .withDefault(nullptr);
339 }
340 }
341
342 bool checkEarlyTermination(ErrorStatus status) {
343 if (status == ErrorStatus::GENERAL_FAILURE) {
344 LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
345 "save the prepared model that it does not support.";
346 std::cout << "[ ] Early termination of test because vendor service cannot "
347 "save the prepared model that it does not support."
348 << std::endl;
349 return true;
350 }
351 return false;
352 }
353
354 bool checkEarlyTermination(const Model& model) {
355 if (!isModelFullySupported(model)) {
356 LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
357 "prepare model that it does not support.";
358 std::cout << "[ ] Early termination of test because vendor service cannot "
359 "prepare model that it does not support."
360 << std::endl;
361 return true;
362 }
363 return false;
364 }
365
366 void prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
367 const hidl_vec<hidl_handle>& dataCache,
368 sp<IPreparedModel>* preparedModel, ErrorStatus* status) {
369 // Launch prepare model from cache.
370 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
371 hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
Xusong Wangcc47dff2019-10-23 10:35:07 -0700372 Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModelFromCache_1_3(
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100373 modelCache, dataCache, cacheToken, preparedModelCallback);
374 ASSERT_TRUE(prepareLaunchStatus.isOk());
375 if (static_cast<ErrorStatus>(prepareLaunchStatus) != ErrorStatus::NONE) {
376 *preparedModel = nullptr;
377 *status = static_cast<ErrorStatus>(prepareLaunchStatus);
378 return;
379 }
380
381 // Retrieve prepared model.
382 preparedModelCallback->wait();
383 *status = preparedModelCallback->getStatus();
384 *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
385 .withDefault(nullptr);
386 }
387
388 // Absolute path to the temporary cache directory.
389 std::string mCacheDir;
390
391 // Groups of file paths for model and data cache in the tmp cache directory, initialized with
392 // outer_size = mNum{Model|Data}Cache, inner_size = 1. The outer vector corresponds to handles
393 // and the inner vector is for fds held by each handle.
394 std::vector<std::vector<std::string>> mModelCache;
395 std::vector<std::vector<std::string>> mDataCache;
396
397 // A separate temporary file path in the tmp cache directory.
398 std::string mTmpCache;
399
400 uint8_t mToken[static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
401 uint32_t mNumModelCache;
402 uint32_t mNumDataCache;
403 uint32_t mIsCachingSupported;
404
405 const sp<IDevice> kDevice;
406 // The primary data type of the testModel.
407 const OperandType kOperandType;
408};
409
410using CompilationCachingTestParam = std::tuple<NamedDevice, OperandType>;
411
// A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
// pass running with float32 models and the second pass running with quant8 models.
class CompilationCachingTest : public CompilationCachingTestBase,
                               public testing::WithParamInterface<CompilationCachingTestParam> {
  protected:
    // Extracts the device and the operand type from the gtest parameter tuple and forwards them
    // to the base fixture.
    CompilationCachingTest()
        : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
                                     std::get<OperandType>(GetParam())) {}
};
421
422TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
423 // Create test HIDL model and compile.
424 const TestModel& testModel = createTestModel();
425 const Model model = createModel(testModel);
426 if (checkEarlyTermination(model)) return;
427 sp<IPreparedModel> preparedModel = nullptr;
428
429 // Save the compilation to cache.
430 {
431 hidl_vec<hidl_handle> modelCache, dataCache;
432 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
433 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
434 saveModelToCache(model, modelCache, dataCache);
435 }
436
437 // Retrieve preparedModel from cache.
438 {
439 preparedModel = nullptr;
440 ErrorStatus status;
441 hidl_vec<hidl_handle> modelCache, dataCache;
442 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
443 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
444 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
445 if (!mIsCachingSupported) {
446 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
447 ASSERT_EQ(preparedModel, nullptr);
448 return;
449 } else if (checkEarlyTermination(status)) {
450 ASSERT_EQ(preparedModel, nullptr);
451 return;
452 } else {
453 ASSERT_EQ(status, ErrorStatus::NONE);
454 ASSERT_NE(preparedModel, nullptr);
455 }
456 }
457
458 // Execute and verify results.
459 EvaluatePreparedModel(preparedModel, testModel,
460 /*testDynamicOutputShape=*/false);
461}
462
463TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
464 // Create test HIDL model and compile.
465 const TestModel& testModel = createTestModel();
466 const Model model = createModel(testModel);
467 if (checkEarlyTermination(model)) return;
468 sp<IPreparedModel> preparedModel = nullptr;
469
470 // Save the compilation to cache.
471 {
472 hidl_vec<hidl_handle> modelCache, dataCache;
473 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
474 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
475 uint8_t dummyBytes[] = {0, 0};
476 // Write a dummy integer to the cache.
477 // The driver should be able to handle non-empty cache and non-zero fd offset.
478 for (uint32_t i = 0; i < modelCache.size(); i++) {
479 ASSERT_EQ(write(modelCache[i].getNativeHandle()->data[0], &dummyBytes,
480 sizeof(dummyBytes)),
481 sizeof(dummyBytes));
482 }
483 for (uint32_t i = 0; i < dataCache.size(); i++) {
484 ASSERT_EQ(
485 write(dataCache[i].getNativeHandle()->data[0], &dummyBytes, sizeof(dummyBytes)),
486 sizeof(dummyBytes));
487 }
488 saveModelToCache(model, modelCache, dataCache);
489 }
490
491 // Retrieve preparedModel from cache.
492 {
493 preparedModel = nullptr;
494 ErrorStatus status;
495 hidl_vec<hidl_handle> modelCache, dataCache;
496 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
497 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
498 uint8_t dummyByte = 0;
499 // Advance the offset of each handle by one byte.
500 // The driver should be able to handle non-zero fd offset.
501 for (uint32_t i = 0; i < modelCache.size(); i++) {
502 ASSERT_GE(read(modelCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
503 }
504 for (uint32_t i = 0; i < dataCache.size(); i++) {
505 ASSERT_GE(read(dataCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
506 }
507 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
508 if (!mIsCachingSupported) {
509 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
510 ASSERT_EQ(preparedModel, nullptr);
511 return;
512 } else if (checkEarlyTermination(status)) {
513 ASSERT_EQ(preparedModel, nullptr);
514 return;
515 } else {
516 ASSERT_EQ(status, ErrorStatus::NONE);
517 ASSERT_NE(preparedModel, nullptr);
518 }
519 }
520
521 // Execute and verify results.
522 EvaluatePreparedModel(preparedModel, testModel,
523 /*testDynamicOutputShape=*/false);
524}
525
// Saving with the wrong number of model/data cache files must still succeed and yield a usable
// prepared model, but retrieving from the same mismatched handles must fail with
// INVALID_ARGUMENT or GENERAL_FAILURE and must not return a prepared model.
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Test with number of model cache files greater than mNumModelCache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an additional cache file for model cache.
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.pop_back();  // restore fixture state for the following sections
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of model cache files smaller than mNumModelCache.
    if (mModelCache.size() > 0) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pop out the last cache file.
        auto tmp = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.push_back(tmp);  // restore fixture state
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files greater than mNumDataCache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an additional cache file for data cache.
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.pop_back();  // restore fixture state
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files smaller than mNumDataCache.
    if (mDataCache.size() > 0) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pop out the last cache file.
        auto tmp = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.push_back(tmp);  // restore fixture state
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
630
// After a valid save, retrieving with the wrong number of model/data cache files must fail with
// GENERAL_FAILURE or INVALID_ARGUMENT and must not return a prepared model.
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Test with number of model cache files greater than mNumModelCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.pop_back();  // restore fixture state for the following sections
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of model cache files smaller than mNumModelCache.
    if (mModelCache.size() > 0) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.push_back(tmp);  // restore fixture state
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files greater than mNumDataCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.pop_back();  // restore fixture state
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files smaller than mNumDataCache.
    if (mDataCache.size() > 0) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.push_back(tmp);  // restore fixture state
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
711
// Saving with an invalid number of fds inside a single cache handle (two fds, or zero) must still
// succeed and yield a usable prepared model, but retrieving from those handles must fail with
// INVALID_ARGUMENT or GENERAL_FAILURE and must not return a prepared model.
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumFd) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Go through each handle in model cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        mModelCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].pop_back();  // restore fixture state for the next iteration
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in model cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        auto tmp = mModelCache[i].back();
        mModelCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].push_back(tmp);  // restore fixture state
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        mDataCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].pop_back();  // restore fixture state
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        auto tmp = mDataCache[i].back();
        mDataCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].push_back(tmp);  // restore fixture state
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
816
817TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumFd) {
818 // Create test HIDL model and compile.
819 const TestModel& testModel = createTestModel();
820 const Model model = createModel(testModel);
821 if (checkEarlyTermination(model)) return;
822
823 // Save the compilation to cache.
824 {
825 hidl_vec<hidl_handle> modelCache, dataCache;
826 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
827 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
828 saveModelToCache(model, modelCache, dataCache);
829 }
830
831 // Go through each handle in model cache, test with NumFd greater than 1.
832 for (uint32_t i = 0; i < mNumModelCache; i++) {
833 sp<IPreparedModel> preparedModel = nullptr;
834 ErrorStatus status;
835 hidl_vec<hidl_handle> modelCache, dataCache;
836 mModelCache[i].push_back(mTmpCache);
837 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
838 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
839 mModelCache[i].pop_back();
840 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
841 if (status != ErrorStatus::GENERAL_FAILURE) {
842 ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
843 }
844 ASSERT_EQ(preparedModel, nullptr);
845 }
846
847 // Go through each handle in model cache, test with NumFd equal to 0.
848 for (uint32_t i = 0; i < mNumModelCache; i++) {
849 sp<IPreparedModel> preparedModel = nullptr;
850 ErrorStatus status;
851 hidl_vec<hidl_handle> modelCache, dataCache;
852 auto tmp = mModelCache[i].back();
853 mModelCache[i].pop_back();
854 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
855 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
856 mModelCache[i].push_back(tmp);
857 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
858 if (status != ErrorStatus::GENERAL_FAILURE) {
859 ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
860 }
861 ASSERT_EQ(preparedModel, nullptr);
862 }
863
864 // Go through each handle in data cache, test with NumFd greater than 1.
865 for (uint32_t i = 0; i < mNumDataCache; i++) {
866 sp<IPreparedModel> preparedModel = nullptr;
867 ErrorStatus status;
868 hidl_vec<hidl_handle> modelCache, dataCache;
869 mDataCache[i].push_back(mTmpCache);
870 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
871 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
872 mDataCache[i].pop_back();
873 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
874 if (status != ErrorStatus::GENERAL_FAILURE) {
875 ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
876 }
877 ASSERT_EQ(preparedModel, nullptr);
878 }
879
880 // Go through each handle in data cache, test with NumFd equal to 0.
881 for (uint32_t i = 0; i < mNumDataCache; i++) {
882 sp<IPreparedModel> preparedModel = nullptr;
883 ErrorStatus status;
884 hidl_vec<hidl_handle> modelCache, dataCache;
885 auto tmp = mDataCache[i].back();
886 mDataCache[i].pop_back();
887 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
888 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
889 mDataCache[i].push_back(tmp);
890 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
891 if (status != ErrorStatus::GENERAL_FAILURE) {
892 ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
893 }
894 ASSERT_EQ(preparedModel, nullptr);
895 }
896}
897
898TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
899 // Create test HIDL model and compile.
900 const TestModel& testModel = createTestModel();
901 const Model model = createModel(testModel);
902 if (checkEarlyTermination(model)) return;
903 std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
904 std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);
905
906 // Go through each handle in model cache, test with invalid access mode.
907 for (uint32_t i = 0; i < mNumModelCache; i++) {
908 hidl_vec<hidl_handle> modelCache, dataCache;
909 modelCacheMode[i] = AccessMode::READ_ONLY;
910 createCacheHandles(mModelCache, modelCacheMode, &modelCache);
911 createCacheHandles(mDataCache, dataCacheMode, &dataCache);
912 modelCacheMode[i] = AccessMode::READ_WRITE;
913 sp<IPreparedModel> preparedModel = nullptr;
914 saveModelToCache(model, modelCache, dataCache, &preparedModel);
915 ASSERT_NE(preparedModel, nullptr);
916 // Execute and verify results.
917 EvaluatePreparedModel(preparedModel, testModel,
918 /*testDynamicOutputShape=*/false);
919 // Check if prepareModelFromCache fails.
920 preparedModel = nullptr;
921 ErrorStatus status;
922 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
923 if (status != ErrorStatus::INVALID_ARGUMENT) {
924 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
925 }
926 ASSERT_EQ(preparedModel, nullptr);
927 }
928
929 // Go through each handle in data cache, test with invalid access mode.
930 for (uint32_t i = 0; i < mNumDataCache; i++) {
931 hidl_vec<hidl_handle> modelCache, dataCache;
932 dataCacheMode[i] = AccessMode::READ_ONLY;
933 createCacheHandles(mModelCache, modelCacheMode, &modelCache);
934 createCacheHandles(mDataCache, dataCacheMode, &dataCache);
935 dataCacheMode[i] = AccessMode::READ_WRITE;
936 sp<IPreparedModel> preparedModel = nullptr;
937 saveModelToCache(model, modelCache, dataCache, &preparedModel);
938 ASSERT_NE(preparedModel, nullptr);
939 // Execute and verify results.
940 EvaluatePreparedModel(preparedModel, testModel,
941 /*testDynamicOutputShape=*/false);
942 // Check if prepareModelFromCache fails.
943 preparedModel = nullptr;
944 ErrorStatus status;
945 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
946 if (status != ErrorStatus::INVALID_ARGUMENT) {
947 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
948 }
949 ASSERT_EQ(preparedModel, nullptr);
950 }
951}
952
953TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
954 // Create test HIDL model and compile.
955 const TestModel& testModel = createTestModel();
956 const Model model = createModel(testModel);
957 if (checkEarlyTermination(model)) return;
958 std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
959 std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);
960
961 // Save the compilation to cache.
962 {
963 hidl_vec<hidl_handle> modelCache, dataCache;
964 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
965 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
966 saveModelToCache(model, modelCache, dataCache);
967 }
968
969 // Go through each handle in model cache, test with invalid access mode.
970 for (uint32_t i = 0; i < mNumModelCache; i++) {
971 sp<IPreparedModel> preparedModel = nullptr;
972 ErrorStatus status;
973 hidl_vec<hidl_handle> modelCache, dataCache;
974 modelCacheMode[i] = AccessMode::WRITE_ONLY;
975 createCacheHandles(mModelCache, modelCacheMode, &modelCache);
976 createCacheHandles(mDataCache, dataCacheMode, &dataCache);
977 modelCacheMode[i] = AccessMode::READ_WRITE;
978 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
979 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
980 ASSERT_EQ(preparedModel, nullptr);
981 }
982
983 // Go through each handle in data cache, test with invalid access mode.
984 for (uint32_t i = 0; i < mNumDataCache; i++) {
985 sp<IPreparedModel> preparedModel = nullptr;
986 ErrorStatus status;
987 hidl_vec<hidl_handle> modelCache, dataCache;
988 dataCacheMode[i] = AccessMode::WRITE_ONLY;
989 createCacheHandles(mModelCache, modelCacheMode, &modelCache);
990 createCacheHandles(mDataCache, dataCacheMode, &dataCache);
991 dataCacheMode[i] = AccessMode::READ_WRITE;
992 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
993 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
994 ASSERT_EQ(preparedModel, nullptr);
995 }
996}
997
998// Copy file contents between file groups.
999// The outer vector corresponds to handles and the inner vector is for fds held by each handle.
1000// The outer vector sizes must match and the inner vectors must have size = 1.
1001static void copyCacheFiles(const std::vector<std::vector<std::string>>& from,
1002 const std::vector<std::vector<std::string>>& to) {
1003 constexpr size_t kBufferSize = 1000000;
1004 uint8_t buffer[kBufferSize];
1005
1006 ASSERT_EQ(from.size(), to.size());
1007 for (uint32_t i = 0; i < from.size(); i++) {
1008 ASSERT_EQ(from[i].size(), 1u);
1009 ASSERT_EQ(to[i].size(), 1u);
1010 int fromFd = open(from[i][0].c_str(), O_RDONLY);
1011 int toFd = open(to[i][0].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
1012 ASSERT_GE(fromFd, 0);
1013 ASSERT_GE(toFd, 0);
1014
1015 ssize_t readBytes;
1016 while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
1017 ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
1018 }
1019 ASSERT_GE(readBytes, 0);
1020
1021 close(fromFd);
1022 close(toFd);
1023 }
1024}
1025
// Number of operations in the large test model.
constexpr uint32_t kLargeModelSize = 100;
// Number of repetitions of the probabilistic TOCTOU tests below; each iteration
// is one attempt to hit the save/prepare race window.
constexpr uint32_t kNumIterationsTOCTOU = 100;
1029
// TOCTOU (time-of-check to time-of-use) test: a malicious process rewrites the cache
// files *while* the driver is saving to them. The driver must never crash, and any
// successfully prepared model must still compute correct results.
TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache, using "_mul"-suffixed file names so it
    // does not collide with the modelAdd cache written below. These files supply the
    // "attacker" content copied over modelAdd's cache during the race.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the modelAdd compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while saving to cache.
            // The thread must be started before saveModelToCache and joined after it so
            // the copy genuinely races with the driver's writes.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            saveModelToCache(modelAdd, modelCache, dataCache);
            thread.join();
        }

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);

            // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
            // the prepared model must be executed with the correct result and not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                EvaluatePreparedModel(preparedModel, testModelAdd,
                                      /*testDynamicOutputShape=*/false);
            }
        }
    }
}
1091
// TOCTOU (time-of-check to time-of-use) test: a malicious process rewrites the cache
// files *while* the driver is preparing from them. The driver must never crash, and any
// successfully prepared model must still compute correct results.
TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache, using "_mul"-suffixed file names so it
    // does not collide with the modelAdd cache written below. These files supply the
    // "attacker" content copied over modelAdd's cache during the race.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the modelAdd compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            saveModelToCache(modelAdd, modelCache, dataCache);
        }

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while preparing from cache.
            // The thread must be started before prepareModelFromCache and joined after it so
            // the copy genuinely races with the driver's reads.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
            thread.join();

            // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
            // the prepared model must be executed with the correct result and not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                EvaluatePreparedModel(preparedModel, testModelAdd,
                                      /*testDynamicOutputShape=*/false);
            }
        }
    }
}
1153
1154TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
1155 if (!mIsCachingSupported) return;
1156
1157 // Create test models and check if fully supported by the service.
1158 const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
1159 const Model modelMul = createModel(testModelMul);
1160 if (checkEarlyTermination(modelMul)) return;
1161 const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
1162 const Model modelAdd = createModel(testModelAdd);
1163 if (checkEarlyTermination(modelAdd)) return;
1164
1165 // Save the modelMul compilation to cache.
1166 auto modelCacheMul = mModelCache;
1167 for (auto& cache : modelCacheMul) {
1168 cache[0].append("_mul");
1169 }
1170 {
1171 hidl_vec<hidl_handle> modelCache, dataCache;
1172 createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
1173 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1174 saveModelToCache(modelMul, modelCache, dataCache);
1175 }
1176
1177 // Use a different token for modelAdd.
1178 mToken[0]++;
1179
1180 // Save the modelAdd compilation to cache.
1181 {
1182 hidl_vec<hidl_handle> modelCache, dataCache;
1183 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1184 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1185 saveModelToCache(modelAdd, modelCache, dataCache);
1186 }
1187
1188 // Replace the model cache of modelAdd with modelMul.
1189 copyCacheFiles(modelCacheMul, mModelCache);
1190
1191 // Retrieve the preparedModel from cache, expect failure.
1192 {
1193 sp<IPreparedModel> preparedModel = nullptr;
1194 ErrorStatus status;
1195 hidl_vec<hidl_handle> modelCache, dataCache;
1196 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1197 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1198 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1199 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1200 ASSERT_EQ(preparedModel, nullptr);
1201 }
1202}
1203
// Parameter value sets shared by the test instantiations below: every registered
// named device, crossed with the two operand types exercised by these tests.
static const auto kNamedDeviceChoices = testing::ValuesIn(getNamedDevices());
static const auto kOperandTypeChoices =
        testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);
1207
1208std::string printCompilationCachingTest(
1209 const testing::TestParamInfo<CompilationCachingTestParam>& info) {
1210 const auto& [namedDevice, operandType] = info.param;
1211 const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1212 return gtestCompliantName(getName(namedDevice) + "_" + type);
1213}
1214
// NOTE(review): INSTANTIATE_TEST_CASE_P is deprecated in newer googletest in favor of
// INSTANTIATE_TEST_SUITE_P -- presumably kept for the toolchain's gtest version; confirm
// before migrating.
INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest,
                        testing::Combine(kNamedDeviceChoices, kOperandTypeChoices),
                        printCompilationCachingTest);

// Parameters for the security-focused cache-corruption tests: device, operand type, RNG seed.
using CompilationCachingSecurityTestParam = std::tuple<NamedDevice, OperandType, uint32_t>;
1220
1221class CompilationCachingSecurityTest
1222 : public CompilationCachingTestBase,
1223 public testing::WithParamInterface<CompilationCachingSecurityTestParam> {
1224 protected:
1225 CompilationCachingSecurityTest()
1226 : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
1227 std::get<OperandType>(GetParam())) {}
1228
1229 void SetUp() {
1230 CompilationCachingTestBase::SetUp();
1231 generator.seed(kSeed);
1232 }
1233
1234 // Get a random integer within a closed range [lower, upper].
1235 template <typename T>
1236 T getRandomInt(T lower, T upper) {
1237 std::uniform_int_distribution<T> dis(lower, upper);
1238 return dis(generator);
1239 }
1240
1241 // Randomly flip one single bit of the cache entry.
1242 void flipOneBitOfCache(const std::string& filename, bool* skip) {
1243 FILE* pFile = fopen(filename.c_str(), "r+");
1244 ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
1245 long int fileSize = ftell(pFile);
1246 if (fileSize == 0) {
1247 fclose(pFile);
1248 *skip = true;
1249 return;
1250 }
1251 ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
1252 int readByte = fgetc(pFile);
1253 ASSERT_NE(readByte, EOF);
1254 ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
1255 ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
1256 fclose(pFile);
1257 *skip = false;
1258 }
1259
1260 // Randomly append bytes to the cache entry.
1261 void appendBytesToCache(const std::string& filename, bool* skip) {
1262 FILE* pFile = fopen(filename.c_str(), "a");
1263 uint32_t appendLength = getRandomInt(1, 256);
1264 for (uint32_t i = 0; i < appendLength; i++) {
1265 ASSERT_NE(fputc(getRandomInt<uint8_t>(0, 255), pFile), EOF);
1266 }
1267 fclose(pFile);
1268 *skip = false;
1269 }
1270
1271 enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };
1272
1273 // Test if the driver behaves as expected when given corrupted cache or token.
1274 // The modifier will be invoked after save to cache but before prepare from cache.
1275 // The modifier accepts one pointer argument "skip" as the returning value, indicating
1276 // whether the test should be skipped or not.
1277 void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
1278 const TestModel& testModel = createTestModel();
1279 const Model model = createModel(testModel);
1280 if (checkEarlyTermination(model)) return;
1281
1282 // Save the compilation to cache.
1283 {
1284 hidl_vec<hidl_handle> modelCache, dataCache;
1285 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1286 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1287 saveModelToCache(model, modelCache, dataCache);
1288 }
1289
1290 bool skip = false;
1291 modifier(&skip);
1292 if (skip) return;
1293
1294 // Retrieve preparedModel from cache.
1295 {
1296 sp<IPreparedModel> preparedModel = nullptr;
1297 ErrorStatus status;
1298 hidl_vec<hidl_handle> modelCache, dataCache;
1299 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1300 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1301 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1302
1303 switch (expected) {
1304 case ExpectedResult::GENERAL_FAILURE:
1305 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1306 ASSERT_EQ(preparedModel, nullptr);
1307 break;
1308 case ExpectedResult::NOT_CRASH:
1309 ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
1310 break;
1311 default:
1312 FAIL();
1313 }
1314 }
1315 }
1316
1317 const uint32_t kSeed = std::get<uint32_t>(GetParam());
1318 std::mt19937 generator;
1319};
1320
// Flipping a single bit in any model cache file must be detected as GENERAL_FAILURE,
// since the model cache may contain security-sensitive compiled code.
TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE,
                           [this, i](bool* skip) { flipOneBitOfCache(mModelCache[i][0], skip); });
    }
}
1328
// Appending random bytes to any model cache file must be detected as GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE,
                           [this, i](bool* skip) { appendBytesToCache(mModelCache[i][0], skip); });
    }
}
1336
// The data cache is not required to be integrity-checked: a flipped bit may go
// undetected, but the driver must not crash either way.
TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        testCorruptedCache(ExpectedResult::NOT_CRASH,
                           [this, i](bool* skip) { flipOneBitOfCache(mDataCache[i][0], skip); });
    }
}
1344
// The data cache is not required to be integrity-checked: appended bytes may go
// undetected, but the driver must not crash either way.
TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        testCorruptedCache(ExpectedResult::NOT_CRASH,
                           [this, i](bool* skip) { appendBytesToCache(mDataCache[i][0], skip); });
    }
}
1352
// A token that differs by even one bit from the one used when saving must cause
// prepareModelFromCache to fail with GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, WrongToken) {
    if (!mIsCachingSupported) return;
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, [this](bool* skip) {
        // Randomly flip one single bit in mToken.
        uint32_t ind =
                getRandomInt(0u, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN) - 1);
        mToken[ind] ^= (1U << getRandomInt(0, 7));
        *skip = false;
    });
}
1363
1364std::string printCompilationCachingSecurityTest(
1365 const testing::TestParamInfo<CompilationCachingSecurityTestParam>& info) {
1366 const auto& [namedDevice, operandType, seed] = info.param;
1367 const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1368 return gtestCompliantName(getName(namedDevice) + "_" + type + "_" + std::to_string(seed));
1369}
1370
// Instantiate the security tests with ten RNG seeds (0 through 9) per
// device/operand-type combination.
INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingSecurityTest,
                        testing::Combine(kNamedDeviceChoices, kOperandTypeChoices,
                                         testing::Range(0U, 10U)),
                        printCompilationCachingSecurityTest);
1375
Lev Proleev26d1bc82019-08-30 11:57:18 +01001376} // namespace android::hardware::neuralnetworks::V1_3::vts::functional