blob: e53b0a8df9e043299c689d40db68c959295c7cfe [file] [log] [blame]
/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#include "MockBuffer.h"
#include "MockDevice.h"
#include "MockPreparedModel.h"

#include <aidl/android/hardware/neuralnetworks/BnDevice.h>
#include <android/binder_auto_utils.h>
#include <android/binder_status.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <nnapi/IDevice.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/hal/aidl/Device.h>

#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <vector>
34
35namespace aidl::android::hardware::neuralnetworks::utils {
36namespace {
37
38namespace nn = ::android::nn;
39using ::testing::_;
40using ::testing::DoAll;
41using ::testing::Invoke;
42using ::testing::InvokeWithoutArgs;
43using ::testing::SetArgPointee;
44
45const nn::Model kSimpleModel = {
46 .main = {.operands = {{.type = nn::OperandType::TENSOR_FLOAT32,
47 .dimensions = {1},
48 .lifetime = nn::Operand::LifeTime::SUBGRAPH_INPUT},
49 {.type = nn::OperandType::TENSOR_FLOAT32,
50 .dimensions = {1},
51 .lifetime = nn::Operand::LifeTime::SUBGRAPH_OUTPUT}},
52 .operations = {{.type = nn::OperationType::RELU, .inputs = {0}, .outputs = {1}}},
53 .inputIndexes = {0},
54 .outputIndexes = {1}}};
55
56const std::string kName = "Google-MockV1";
57const std::string kInvalidName = "";
58const std::shared_ptr<BnDevice> kInvalidDevice;
59constexpr PerformanceInfo kNoPerformanceInfo = {.execTime = std::numeric_limits<float>::max(),
60 .powerUsage = std::numeric_limits<float>::max()};
61constexpr NumberOfCacheFiles kNumberOfCacheFiles = {.numModelCache = nn::kMaxNumberOfCacheFiles,
62 .numDataCache = nn::kMaxNumberOfCacheFiles};
63
64constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
65
66std::shared_ptr<MockDevice> createMockDevice() {
67 const auto mockDevice = MockDevice::create();
68
69 // Setup default actions for each relevant call.
70 ON_CALL(*mockDevice, getVersionString(_))
71 .WillByDefault(DoAll(SetArgPointee<0>(kName), InvokeWithoutArgs(makeStatusOk)));
72 ON_CALL(*mockDevice, getType(_))
73 .WillByDefault(
74 DoAll(SetArgPointee<0>(DeviceType::OTHER), InvokeWithoutArgs(makeStatusOk)));
75 ON_CALL(*mockDevice, getSupportedExtensions(_))
76 .WillByDefault(DoAll(SetArgPointee<0>(std::vector<Extension>{}),
77 InvokeWithoutArgs(makeStatusOk)));
78 ON_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
79 .WillByDefault(
80 DoAll(SetArgPointee<0>(kNumberOfCacheFiles), InvokeWithoutArgs(makeStatusOk)));
81 ON_CALL(*mockDevice, getCapabilities(_))
82 .WillByDefault(
83 DoAll(SetArgPointee<0>(Capabilities{
84 .relaxedFloat32toFloat16PerformanceScalar = kNoPerformanceInfo,
85 .relaxedFloat32toFloat16PerformanceTensor = kNoPerformanceInfo,
86 .ifPerformance = kNoPerformanceInfo,
87 .whilePerformance = kNoPerformanceInfo,
88 }),
89 InvokeWithoutArgs(makeStatusOk)));
90
91 // These EXPECT_CALL(...).Times(testing::AnyNumber()) calls are to suppress warnings on the
92 // uninteresting methods calls.
93 EXPECT_CALL(*mockDevice, getVersionString(_)).Times(testing::AnyNumber());
94 EXPECT_CALL(*mockDevice, getType(_)).Times(testing::AnyNumber());
95 EXPECT_CALL(*mockDevice, getSupportedExtensions(_)).Times(testing::AnyNumber());
96 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_)).Times(testing::AnyNumber());
97 EXPECT_CALL(*mockDevice, getCapabilities(_)).Times(testing::AnyNumber());
98
99 return mockDevice;
100}
101
102constexpr auto makePreparedModelReturnImpl =
103 [](ErrorStatus launchStatus, ErrorStatus returnStatus,
104 const std::shared_ptr<MockPreparedModel>& preparedModel,
105 const std::shared_ptr<IPreparedModelCallback>& cb) {
106 cb->notify(returnStatus, preparedModel);
107 if (launchStatus == ErrorStatus::NONE) {
108 return ndk::ScopedAStatus::ok();
109 }
110 return ndk::ScopedAStatus::fromServiceSpecificError(static_cast<int32_t>(launchStatus));
111 };
112
113auto makePreparedModelReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
114 const std::shared_ptr<MockPreparedModel>& preparedModel) {
115 return [launchStatus, returnStatus, preparedModel](
116 const Model& /*model*/, ExecutionPreference /*preference*/,
117 Priority /*priority*/, const int64_t& /*deadline*/,
118 const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
119 const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
120 const std::vector<uint8_t>& /*token*/,
121 const std::shared_ptr<IPreparedModelCallback>& cb) -> ndk::ScopedAStatus {
122 return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
123 };
124}
125
126auto makePreparedModelFromCacheReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
127 const std::shared_ptr<MockPreparedModel>& preparedModel) {
128 return [launchStatus, returnStatus, preparedModel](
129 const int64_t& /*deadline*/,
130 const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
131 const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
132 const std::vector<uint8_t>& /*token*/,
133 const std::shared_ptr<IPreparedModelCallback>& cb) {
134 return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
135 };
136}
137
138constexpr auto makeGeneralFailure = [] {
139 return ndk::ScopedAStatus::fromServiceSpecificError(
140 static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
141};
142constexpr auto makeGeneralTransportFailure = [] {
143 return ndk::ScopedAStatus::fromStatus(STATUS_NO_MEMORY);
144};
145constexpr auto makeDeadObjectFailure = [] {
146 return ndk::ScopedAStatus::fromStatus(STATUS_DEAD_OBJECT);
147};
148
149} // namespace
150
151TEST(DeviceTest, invalidName) {
152 // run test
153 const auto device = MockDevice::create();
154 const auto result = Device::create(kInvalidName, device);
155
156 // verify result
157 ASSERT_FALSE(result.has_value());
158 EXPECT_EQ(result.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
159}
160
161TEST(DeviceTest, invalidDevice) {
162 // run test
163 const auto result = Device::create(kName, kInvalidDevice);
164
165 // verify result
166 ASSERT_FALSE(result.has_value());
167 EXPECT_EQ(result.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
168}
169
170TEST(DeviceTest, getVersionStringError) {
171 // setup call
172 const auto mockDevice = createMockDevice();
173 EXPECT_CALL(*mockDevice, getVersionString(_))
174 .Times(1)
175 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
176
177 // run test
178 const auto result = Device::create(kName, mockDevice);
179
180 // verify result
181 ASSERT_FALSE(result.has_value());
182 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
183}
184
185TEST(DeviceTest, getVersionStringTransportFailure) {
186 // setup call
187 const auto mockDevice = createMockDevice();
188 EXPECT_CALL(*mockDevice, getVersionString(_))
189 .Times(1)
190 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
191
192 // run test
193 const auto result = Device::create(kName, mockDevice);
194
195 // verify result
196 ASSERT_FALSE(result.has_value());
197 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
198}
199
200TEST(DeviceTest, getVersionStringDeadObject) {
201 // setup call
202 const auto mockDevice = createMockDevice();
203 EXPECT_CALL(*mockDevice, getVersionString(_))
204 .Times(1)
205 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
206
207 // run test
208 const auto result = Device::create(kName, mockDevice);
209
210 // verify result
211 ASSERT_FALSE(result.has_value());
212 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
213}
214
215TEST(DeviceTest, getTypeError) {
216 // setup call
217 const auto mockDevice = createMockDevice();
218 EXPECT_CALL(*mockDevice, getType(_)).Times(1).WillOnce(InvokeWithoutArgs(makeGeneralFailure));
219
220 // run test
221 const auto result = Device::create(kName, mockDevice);
222
223 // verify result
224 ASSERT_FALSE(result.has_value());
225 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
226}
227
228TEST(DeviceTest, getTypeTransportFailure) {
229 // setup call
230 const auto mockDevice = createMockDevice();
231 EXPECT_CALL(*mockDevice, getType(_))
232 .Times(1)
233 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
234
235 // run test
236 const auto result = Device::create(kName, mockDevice);
237
238 // verify result
239 ASSERT_FALSE(result.has_value());
240 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
241}
242
243TEST(DeviceTest, getTypeDeadObject) {
244 // setup call
245 const auto mockDevice = createMockDevice();
246 EXPECT_CALL(*mockDevice, getType(_))
247 .Times(1)
248 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
249
250 // run test
251 const auto result = Device::create(kName, mockDevice);
252
253 // verify result
254 ASSERT_FALSE(result.has_value());
255 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
256}
257
258TEST(DeviceTest, getSupportedExtensionsError) {
259 // setup call
260 const auto mockDevice = createMockDevice();
261 EXPECT_CALL(*mockDevice, getSupportedExtensions(_))
262 .Times(1)
263 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
264
265 // run test
266 const auto result = Device::create(kName, mockDevice);
267
268 // verify result
269 ASSERT_FALSE(result.has_value());
270 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
271}
272
273TEST(DeviceTest, getSupportedExtensionsTransportFailure) {
274 // setup call
275 const auto mockDevice = createMockDevice();
276 EXPECT_CALL(*mockDevice, getSupportedExtensions(_))
277 .Times(1)
278 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
279
280 // run test
281 const auto result = Device::create(kName, mockDevice);
282
283 // verify result
284 ASSERT_FALSE(result.has_value());
285 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
286}
287
288TEST(DeviceTest, getSupportedExtensionsDeadObject) {
289 // setup call
290 const auto mockDevice = createMockDevice();
291 EXPECT_CALL(*mockDevice, getSupportedExtensions(_))
292 .Times(1)
293 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
294
295 // run test
296 const auto result = Device::create(kName, mockDevice);
297
298 // verify result
299 ASSERT_FALSE(result.has_value());
300 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
301}
302
303TEST(DeviceTest, getNumberOfCacheFilesNeededError) {
304 // setup call
305 const auto mockDevice = createMockDevice();
306 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
307 .Times(1)
308 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
309
310 // run test
311 const auto result = Device::create(kName, mockDevice);
312
313 // verify result
314 ASSERT_FALSE(result.has_value());
315 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
316}
317
318TEST(DeviceTest, dataCacheFilesExceedsSpecifiedMax) {
319 // setup test
320 const auto mockDevice = createMockDevice();
321 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
322 .Times(1)
323 .WillOnce(DoAll(SetArgPointee<0>(NumberOfCacheFiles{
324 .numModelCache = nn::kMaxNumberOfCacheFiles + 1,
325 .numDataCache = nn::kMaxNumberOfCacheFiles}),
326 InvokeWithoutArgs(makeStatusOk)));
327
328 // run test
329 const auto result = Device::create(kName, mockDevice);
330
331 // verify result
332 ASSERT_FALSE(result.has_value());
333 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
334}
335
336TEST(DeviceTest, modelCacheFilesExceedsSpecifiedMax) {
337 // setup test
338 const auto mockDevice = createMockDevice();
339 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
340 .Times(1)
341 .WillOnce(DoAll(SetArgPointee<0>(NumberOfCacheFiles{
342 .numModelCache = nn::kMaxNumberOfCacheFiles,
343 .numDataCache = nn::kMaxNumberOfCacheFiles + 1}),
344 InvokeWithoutArgs(makeStatusOk)));
345
346 // run test
347 const auto result = Device::create(kName, mockDevice);
348
349 // verify result
350 ASSERT_FALSE(result.has_value());
351 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
352}
353
354TEST(DeviceTest, getNumberOfCacheFilesNeededTransportFailure) {
355 // setup call
356 const auto mockDevice = createMockDevice();
357 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
358 .Times(1)
359 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
360
361 // run test
362 const auto result = Device::create(kName, mockDevice);
363
364 // verify result
365 ASSERT_FALSE(result.has_value());
366 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
367}
368
369TEST(DeviceTest, getNumberOfCacheFilesNeededDeadObject) {
370 // setup call
371 const auto mockDevice = createMockDevice();
372 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
373 .Times(1)
374 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
375
376 // run test
377 const auto result = Device::create(kName, mockDevice);
378
379 // verify result
380 ASSERT_FALSE(result.has_value());
381 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
382}
383
384TEST(DeviceTest, getCapabilitiesError) {
385 // setup call
386 const auto mockDevice = createMockDevice();
387 EXPECT_CALL(*mockDevice, getCapabilities(_))
388 .Times(1)
389 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
390
391 // run test
392 const auto result = Device::create(kName, mockDevice);
393
394 // verify result
395 ASSERT_FALSE(result.has_value());
396 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
397}
398
399TEST(DeviceTest, getCapabilitiesTransportFailure) {
400 // setup call
401 const auto mockDevice = createMockDevice();
402 EXPECT_CALL(*mockDevice, getCapabilities(_))
403 .Times(1)
404 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
405
406 // run test
407 const auto result = Device::create(kName, mockDevice);
408
409 // verify result
410 ASSERT_FALSE(result.has_value());
411 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
412}
413
414TEST(DeviceTest, getCapabilitiesDeadObject) {
415 // setup call
416 const auto mockDevice = createMockDevice();
417 EXPECT_CALL(*mockDevice, getCapabilities(_))
418 .Times(1)
419 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
420
421 // run test
422 const auto result = Device::create(kName, mockDevice);
423
424 // verify result
425 ASSERT_FALSE(result.has_value());
426 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
427}
428
429TEST(DeviceTest, getName) {
430 // setup call
431 const auto mockDevice = createMockDevice();
432 const auto device = Device::create(kName, mockDevice).value();
433
434 // run test
435 const auto& name = device->getName();
436
437 // verify result
438 EXPECT_EQ(name, kName);
439}
440
441TEST(DeviceTest, getFeatureLevel) {
442 // setup call
443 const auto mockDevice = createMockDevice();
444 const auto device = Device::create(kName, mockDevice).value();
445
446 // run test
447 const auto featureLevel = device->getFeatureLevel();
448
449 // verify result
450 EXPECT_EQ(featureLevel, nn::Version::ANDROID_S);
451}
452
453TEST(DeviceTest, getCachedData) {
454 // setup call
455 const auto mockDevice = createMockDevice();
456 EXPECT_CALL(*mockDevice, getVersionString(_)).Times(1);
457 EXPECT_CALL(*mockDevice, getType(_)).Times(1);
458 EXPECT_CALL(*mockDevice, getSupportedExtensions(_)).Times(1);
459 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_)).Times(1);
460 EXPECT_CALL(*mockDevice, getCapabilities(_)).Times(1);
461
462 const auto result = Device::create(kName, mockDevice);
463 ASSERT_TRUE(result.has_value())
464 << "Failed with " << result.error().code << ": " << result.error().message;
465 const auto& device = result.value();
466
467 // run test and verify results
468 EXPECT_EQ(device->getVersionString(), device->getVersionString());
469 EXPECT_EQ(device->getType(), device->getType());
470 EXPECT_EQ(device->getSupportedExtensions(), device->getSupportedExtensions());
471 EXPECT_EQ(device->getNumberOfCacheFilesNeeded(), device->getNumberOfCacheFilesNeeded());
472 EXPECT_EQ(device->getCapabilities(), device->getCapabilities());
473}
474
475TEST(DeviceTest, getSupportedOperations) {
476 // setup call
477 const auto mockDevice = createMockDevice();
478 const auto device = Device::create(kName, mockDevice).value();
479 EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
480 .Times(1)
481 .WillOnce(DoAll(
482 SetArgPointee<1>(std::vector<bool>(kSimpleModel.main.operations.size(), true)),
483 InvokeWithoutArgs(makeStatusOk)));
484
485 // run test
486 const auto result = device->getSupportedOperations(kSimpleModel);
487
488 // verify result
489 ASSERT_TRUE(result.has_value())
490 << "Failed with " << result.error().code << ": " << result.error().message;
491 const auto& supportedOperations = result.value();
492 EXPECT_EQ(supportedOperations.size(), kSimpleModel.main.operations.size());
493 EXPECT_THAT(supportedOperations, Each(testing::IsTrue()));
494}
495
496TEST(DeviceTest, getSupportedOperationsError) {
497 // setup call
498 const auto mockDevice = createMockDevice();
499 const auto device = Device::create(kName, mockDevice).value();
500 EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
501 .Times(1)
502 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
503
504 // run test
505 const auto result = device->getSupportedOperations(kSimpleModel);
506
507 // verify result
508 ASSERT_FALSE(result.has_value());
509 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
510}
511
512TEST(DeviceTest, getSupportedOperationsTransportFailure) {
513 // setup call
514 const auto mockDevice = createMockDevice();
515 const auto device = Device::create(kName, mockDevice).value();
516 EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
517 .Times(1)
518 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
519
520 // run test
521 const auto result = device->getSupportedOperations(kSimpleModel);
522
523 // verify result
524 ASSERT_FALSE(result.has_value());
525 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
526}
527
528TEST(DeviceTest, getSupportedOperationsDeadObject) {
529 // setup call
530 const auto mockDevice = createMockDevice();
531 const auto device = Device::create(kName, mockDevice).value();
532 EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
533 .Times(1)
534 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
535
536 // run test
537 const auto result = device->getSupportedOperations(kSimpleModel);
538
539 // verify result
540 ASSERT_FALSE(result.has_value());
541 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
542}
543
544TEST(DeviceTest, prepareModel) {
545 // setup call
546 const auto mockDevice = createMockDevice();
547 const auto device = Device::create(kName, mockDevice).value();
548 const auto mockPreparedModel = MockPreparedModel::create();
549 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
550 .Times(1)
551 .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::NONE,
552 mockPreparedModel)));
553
554 // run test
555 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
556 nn::Priority::DEFAULT, {}, {}, {}, {});
557
558 // verify result
559 ASSERT_TRUE(result.has_value())
560 << "Failed with " << result.error().code << ": " << result.error().message;
561 EXPECT_NE(result.value(), nullptr);
562}
563
564TEST(DeviceTest, prepareModelLaunchError) {
565 // setup call
566 const auto mockDevice = createMockDevice();
567 const auto device = Device::create(kName, mockDevice).value();
568 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
569 .Times(1)
570 .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::GENERAL_FAILURE,
571 ErrorStatus::GENERAL_FAILURE, nullptr)));
572
573 // run test
574 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
575 nn::Priority::DEFAULT, {}, {}, {}, {});
576
577 // verify result
578 ASSERT_FALSE(result.has_value());
579 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
580}
581
582TEST(DeviceTest, prepareModelReturnError) {
583 // setup call
584 const auto mockDevice = createMockDevice();
585 const auto device = Device::create(kName, mockDevice).value();
586 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
587 .Times(1)
588 .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::NONE,
589 ErrorStatus::GENERAL_FAILURE, nullptr)));
590
591 // run test
592 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
593 nn::Priority::DEFAULT, {}, {}, {}, {});
594
595 // verify result
596 ASSERT_FALSE(result.has_value());
597 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
598}
599
600TEST(DeviceTest, prepareModelNullptrError) {
601 // setup call
602 const auto mockDevice = createMockDevice();
603 const auto device = Device::create(kName, mockDevice).value();
604 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
605 .Times(1)
606 .WillOnce(
607 Invoke(makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::NONE, nullptr)));
608
609 // run test
610 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
611 nn::Priority::DEFAULT, {}, {}, {}, {});
612
613 // verify result
614 ASSERT_FALSE(result.has_value());
615 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
616}
617
618TEST(DeviceTest, prepareModelTransportFailure) {
619 // setup call
620 const auto mockDevice = createMockDevice();
621 const auto device = Device::create(kName, mockDevice).value();
622 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
623 .Times(1)
624 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
625
626 // run test
627 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
628 nn::Priority::DEFAULT, {}, {}, {}, {});
629
630 // verify result
631 ASSERT_FALSE(result.has_value());
632 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
633}
634
635TEST(DeviceTest, prepareModelDeadObject) {
636 // setup call
637 const auto mockDevice = createMockDevice();
638 const auto device = Device::create(kName, mockDevice).value();
639 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
640 .Times(1)
641 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
642
643 // run test
644 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
645 nn::Priority::DEFAULT, {}, {}, {}, {});
646
647 // verify result
648 ASSERT_FALSE(result.has_value());
649 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
650}
651
652TEST(DeviceTest, prepareModelAsyncCrash) {
653 // setup test
654 const auto mockDevice = createMockDevice();
655 const auto device = Device::create(kName, mockDevice).value();
656 const auto ret = [&device]() {
657 DeathMonitor::serviceDied(device->getDeathMonitor());
658 return ndk::ScopedAStatus::ok();
659 };
660 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
661 .Times(1)
662 .WillOnce(InvokeWithoutArgs(ret));
663
664 // run test
665 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
666 nn::Priority::DEFAULT, {}, {}, {}, {});
667
668 // verify result
669 ASSERT_FALSE(result.has_value());
670 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
671}
672
673TEST(DeviceTest, prepareModelFromCache) {
674 // setup call
675 const auto mockDevice = createMockDevice();
676 const auto device = Device::create(kName, mockDevice).value();
677 const auto mockPreparedModel = MockPreparedModel::create();
678 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
679 .Times(1)
680 .WillOnce(Invoke(makePreparedModelFromCacheReturn(ErrorStatus::NONE, ErrorStatus::NONE,
681 mockPreparedModel)));
682
683 // run test
684 const auto result = device->prepareModelFromCache({}, {}, {}, {});
685
686 // verify result
687 ASSERT_TRUE(result.has_value())
688 << "Failed with " << result.error().code << ": " << result.error().message;
689 EXPECT_NE(result.value(), nullptr);
690}
691
692TEST(DeviceTest, prepareModelFromCacheLaunchError) {
693 // setup call
694 const auto mockDevice = createMockDevice();
695 const auto device = Device::create(kName, mockDevice).value();
696 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
697 .Times(1)
698 .WillOnce(Invoke(makePreparedModelFromCacheReturn(
699 ErrorStatus::GENERAL_FAILURE, ErrorStatus::GENERAL_FAILURE, nullptr)));
700
701 // run test
702 const auto result = device->prepareModelFromCache({}, {}, {}, {});
703
704 // verify result
705 ASSERT_FALSE(result.has_value());
706 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
707}
708
709TEST(DeviceTest, prepareModelFromCacheReturnError) {
710 // setup call
711 const auto mockDevice = createMockDevice();
712 const auto device = Device::create(kName, mockDevice).value();
713 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
714 .Times(1)
715 .WillOnce(Invoke(makePreparedModelFromCacheReturn(
716 ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE, nullptr)));
717
718 // run test
719 const auto result = device->prepareModelFromCache({}, {}, {}, {});
720
721 // verify result
722 ASSERT_FALSE(result.has_value());
723 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
724}
725
726TEST(DeviceTest, prepareModelFromCacheNullptrError) {
727 // setup call
728 const auto mockDevice = createMockDevice();
729 const auto device = Device::create(kName, mockDevice).value();
730 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
731 .Times(1)
732 .WillOnce(Invoke(makePreparedModelFromCacheReturn(ErrorStatus::NONE, ErrorStatus::NONE,
733 nullptr)));
734
735 // run test
736 const auto result = device->prepareModelFromCache({}, {}, {}, {});
737
738 // verify result
739 ASSERT_FALSE(result.has_value());
740 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
741}
742
743TEST(DeviceTest, prepareModelFromCacheTransportFailure) {
744 // setup call
745 const auto mockDevice = createMockDevice();
746 const auto device = Device::create(kName, mockDevice).value();
747 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
748 .Times(1)
749 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
750
751 // run test
752 const auto result = device->prepareModelFromCache({}, {}, {}, {});
753
754 // verify result
755 ASSERT_FALSE(result.has_value());
756 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
757}
758
759TEST(DeviceTest, prepareModelFromCacheDeadObject) {
760 // setup call
761 const auto mockDevice = createMockDevice();
762 const auto device = Device::create(kName, mockDevice).value();
763 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
764 .Times(1)
765 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
766
767 // run test
768 const auto result = device->prepareModelFromCache({}, {}, {}, {});
769
770 // verify result
771 ASSERT_FALSE(result.has_value());
772 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
773}
774
775TEST(DeviceTest, prepareModelFromCacheAsyncCrash) {
776 // setup test
777 const auto mockDevice = createMockDevice();
778 const auto device = Device::create(kName, mockDevice).value();
779 const auto ret = [&device]() {
780 DeathMonitor::serviceDied(device->getDeathMonitor());
781 return ndk::ScopedAStatus::ok();
782 };
783 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
784 .Times(1)
785 .WillOnce(InvokeWithoutArgs(ret));
786
787 // run test
788 const auto result = device->prepareModelFromCache({}, {}, {}, {});
789
790 // verify result
791 ASSERT_FALSE(result.has_value());
792 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
793}
794
795TEST(DeviceTest, allocate) {
796 // setup call
797 const auto mockDevice = createMockDevice();
798 const auto device = Device::create(kName, mockDevice).value();
799 const auto mockBuffer = DeviceBuffer{.buffer = MockBuffer::create(), .token = 1};
800 EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
801 .Times(1)
802 .WillOnce(DoAll(SetArgPointee<4>(mockBuffer), InvokeWithoutArgs(makeStatusOk)));
803
804 // run test
805 const auto result = device->allocate({}, {}, {}, {});
806
807 // verify result
808 ASSERT_TRUE(result.has_value())
809 << "Failed with " << result.error().code << ": " << result.error().message;
810 EXPECT_NE(result.value(), nullptr);
811}
812
813TEST(DeviceTest, allocateError) {
814 // setup call
815 const auto mockDevice = createMockDevice();
816 const auto device = Device::create(kName, mockDevice).value();
817 EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
818 .Times(1)
819 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
820
821 // run test
822 const auto result = device->allocate({}, {}, {}, {});
823
824 // verify result
825 ASSERT_FALSE(result.has_value());
826 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
827}
828
829TEST(DeviceTest, allocateTransportFailure) {
830 // setup call
831 const auto mockDevice = createMockDevice();
832 const auto device = Device::create(kName, mockDevice).value();
833 EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
834 .Times(1)
835 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
836
837 // run test
838 const auto result = device->allocate({}, {}, {}, {});
839
840 // verify result
841 ASSERT_FALSE(result.has_value());
842 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
843}
844
845TEST(DeviceTest, allocateDeadObject) {
846 // setup call
847 const auto mockDevice = createMockDevice();
848 const auto device = Device::create(kName, mockDevice).value();
849 EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
850 .Times(1)
851 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
852
853 // run test
854 const auto result = device->allocate({}, {}, {}, {});
855
856 // verify result
857 ASSERT_FALSE(result.has_value());
858 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
859}
860
861} // namespace aidl::android::hardware::neuralnetworks::utils