blob: f121acaf7bbaa15d0911b2babac235839a635d28 [file] [log] [blame]
/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#include "MockBuffer.h"
#include "MockDevice.h"
#include "MockPreparedModel.h"

#include <aidl/android/hardware/neuralnetworks/BnDevice.h>
#include <android/binder_auto_utils.h>
#include <android/binder_status.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <nnapi/IDevice.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/hal/aidl/Device.h>

#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
34
35namespace aidl::android::hardware::neuralnetworks::utils {
36namespace {
37
38namespace nn = ::android::nn;
39using ::testing::_;
40using ::testing::DoAll;
41using ::testing::Invoke;
42using ::testing::InvokeWithoutArgs;
43using ::testing::SetArgPointee;
44
45const nn::Model kSimpleModel = {
46 .main = {.operands = {{.type = nn::OperandType::TENSOR_FLOAT32,
47 .dimensions = {1},
48 .lifetime = nn::Operand::LifeTime::SUBGRAPH_INPUT},
49 {.type = nn::OperandType::TENSOR_FLOAT32,
50 .dimensions = {1},
51 .lifetime = nn::Operand::LifeTime::SUBGRAPH_OUTPUT}},
52 .operations = {{.type = nn::OperationType::RELU, .inputs = {0}, .outputs = {1}}},
53 .inputIndexes = {0},
54 .outputIndexes = {1}}};
55
56const std::string kName = "Google-MockV1";
57const std::string kInvalidName = "";
58const std::shared_ptr<BnDevice> kInvalidDevice;
59constexpr PerformanceInfo kNoPerformanceInfo = {.execTime = std::numeric_limits<float>::max(),
60 .powerUsage = std::numeric_limits<float>::max()};
Lev Proleeva31aff12021-06-28 13:10:54 +010061constexpr NumberOfCacheFiles kNumberOfCacheFiles = {.numModelCache = nn::kMaxNumberOfCacheFiles - 1,
Lev Proleev900c28a2021-01-26 19:40:20 +000062 .numDataCache = nn::kMaxNumberOfCacheFiles};
63
64constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
65
66std::shared_ptr<MockDevice> createMockDevice() {
67 const auto mockDevice = MockDevice::create();
68
69 // Setup default actions for each relevant call.
70 ON_CALL(*mockDevice, getVersionString(_))
71 .WillByDefault(DoAll(SetArgPointee<0>(kName), InvokeWithoutArgs(makeStatusOk)));
72 ON_CALL(*mockDevice, getType(_))
73 .WillByDefault(
74 DoAll(SetArgPointee<0>(DeviceType::OTHER), InvokeWithoutArgs(makeStatusOk)));
75 ON_CALL(*mockDevice, getSupportedExtensions(_))
76 .WillByDefault(DoAll(SetArgPointee<0>(std::vector<Extension>{}),
77 InvokeWithoutArgs(makeStatusOk)));
78 ON_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
79 .WillByDefault(
80 DoAll(SetArgPointee<0>(kNumberOfCacheFiles), InvokeWithoutArgs(makeStatusOk)));
81 ON_CALL(*mockDevice, getCapabilities(_))
82 .WillByDefault(
83 DoAll(SetArgPointee<0>(Capabilities{
84 .relaxedFloat32toFloat16PerformanceScalar = kNoPerformanceInfo,
85 .relaxedFloat32toFloat16PerformanceTensor = kNoPerformanceInfo,
86 .ifPerformance = kNoPerformanceInfo,
87 .whilePerformance = kNoPerformanceInfo,
88 }),
89 InvokeWithoutArgs(makeStatusOk)));
90
91 // These EXPECT_CALL(...).Times(testing::AnyNumber()) calls are to suppress warnings on the
92 // uninteresting methods calls.
93 EXPECT_CALL(*mockDevice, getVersionString(_)).Times(testing::AnyNumber());
94 EXPECT_CALL(*mockDevice, getType(_)).Times(testing::AnyNumber());
95 EXPECT_CALL(*mockDevice, getSupportedExtensions(_)).Times(testing::AnyNumber());
96 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_)).Times(testing::AnyNumber());
97 EXPECT_CALL(*mockDevice, getCapabilities(_)).Times(testing::AnyNumber());
98
99 return mockDevice;
100}
101
102constexpr auto makePreparedModelReturnImpl =
103 [](ErrorStatus launchStatus, ErrorStatus returnStatus,
104 const std::shared_ptr<MockPreparedModel>& preparedModel,
105 const std::shared_ptr<IPreparedModelCallback>& cb) {
106 cb->notify(returnStatus, preparedModel);
107 if (launchStatus == ErrorStatus::NONE) {
108 return ndk::ScopedAStatus::ok();
109 }
110 return ndk::ScopedAStatus::fromServiceSpecificError(static_cast<int32_t>(launchStatus));
111 };
112
113auto makePreparedModelReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
114 const std::shared_ptr<MockPreparedModel>& preparedModel) {
115 return [launchStatus, returnStatus, preparedModel](
116 const Model& /*model*/, ExecutionPreference /*preference*/,
117 Priority /*priority*/, const int64_t& /*deadline*/,
118 const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
119 const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
120 const std::vector<uint8_t>& /*token*/,
121 const std::shared_ptr<IPreparedModelCallback>& cb) -> ndk::ScopedAStatus {
122 return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
123 };
124}
125
126auto makePreparedModelFromCacheReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
127 const std::shared_ptr<MockPreparedModel>& preparedModel) {
128 return [launchStatus, returnStatus, preparedModel](
129 const int64_t& /*deadline*/,
130 const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
131 const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
132 const std::vector<uint8_t>& /*token*/,
133 const std::shared_ptr<IPreparedModelCallback>& cb) {
134 return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
135 };
136}
137
138constexpr auto makeGeneralFailure = [] {
139 return ndk::ScopedAStatus::fromServiceSpecificError(
140 static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
141};
142constexpr auto makeGeneralTransportFailure = [] {
143 return ndk::ScopedAStatus::fromStatus(STATUS_NO_MEMORY);
144};
145constexpr auto makeDeadObjectFailure = [] {
146 return ndk::ScopedAStatus::fromStatus(STATUS_DEAD_OBJECT);
147};
148
149} // namespace
150
151TEST(DeviceTest, invalidName) {
152 // run test
153 const auto device = MockDevice::create();
154 const auto result = Device::create(kInvalidName, device);
155
156 // verify result
157 ASSERT_FALSE(result.has_value());
158 EXPECT_EQ(result.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
159}
160
161TEST(DeviceTest, invalidDevice) {
162 // run test
163 const auto result = Device::create(kName, kInvalidDevice);
164
165 // verify result
166 ASSERT_FALSE(result.has_value());
167 EXPECT_EQ(result.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
168}
169
170TEST(DeviceTest, getVersionStringError) {
171 // setup call
172 const auto mockDevice = createMockDevice();
173 EXPECT_CALL(*mockDevice, getVersionString(_))
174 .Times(1)
175 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
176
177 // run test
178 const auto result = Device::create(kName, mockDevice);
179
180 // verify result
181 ASSERT_FALSE(result.has_value());
182 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
183}
184
185TEST(DeviceTest, getVersionStringTransportFailure) {
186 // setup call
187 const auto mockDevice = createMockDevice();
188 EXPECT_CALL(*mockDevice, getVersionString(_))
189 .Times(1)
190 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
191
192 // run test
193 const auto result = Device::create(kName, mockDevice);
194
195 // verify result
196 ASSERT_FALSE(result.has_value());
197 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
198}
199
200TEST(DeviceTest, getVersionStringDeadObject) {
201 // setup call
202 const auto mockDevice = createMockDevice();
203 EXPECT_CALL(*mockDevice, getVersionString(_))
204 .Times(1)
205 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
206
207 // run test
208 const auto result = Device::create(kName, mockDevice);
209
210 // verify result
211 ASSERT_FALSE(result.has_value());
212 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
213}
214
215TEST(DeviceTest, getTypeError) {
216 // setup call
217 const auto mockDevice = createMockDevice();
218 EXPECT_CALL(*mockDevice, getType(_)).Times(1).WillOnce(InvokeWithoutArgs(makeGeneralFailure));
219
220 // run test
221 const auto result = Device::create(kName, mockDevice);
222
223 // verify result
224 ASSERT_FALSE(result.has_value());
225 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
226}
227
228TEST(DeviceTest, getTypeTransportFailure) {
229 // setup call
230 const auto mockDevice = createMockDevice();
231 EXPECT_CALL(*mockDevice, getType(_))
232 .Times(1)
233 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
234
235 // run test
236 const auto result = Device::create(kName, mockDevice);
237
238 // verify result
239 ASSERT_FALSE(result.has_value());
240 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
241}
242
243TEST(DeviceTest, getTypeDeadObject) {
244 // setup call
245 const auto mockDevice = createMockDevice();
246 EXPECT_CALL(*mockDevice, getType(_))
247 .Times(1)
248 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
249
250 // run test
251 const auto result = Device::create(kName, mockDevice);
252
253 // verify result
254 ASSERT_FALSE(result.has_value());
255 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
256}
257
258TEST(DeviceTest, getSupportedExtensionsError) {
259 // setup call
260 const auto mockDevice = createMockDevice();
261 EXPECT_CALL(*mockDevice, getSupportedExtensions(_))
262 .Times(1)
263 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
264
265 // run test
266 const auto result = Device::create(kName, mockDevice);
267
268 // verify result
269 ASSERT_FALSE(result.has_value());
270 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
271}
272
273TEST(DeviceTest, getSupportedExtensionsTransportFailure) {
274 // setup call
275 const auto mockDevice = createMockDevice();
276 EXPECT_CALL(*mockDevice, getSupportedExtensions(_))
277 .Times(1)
278 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
279
280 // run test
281 const auto result = Device::create(kName, mockDevice);
282
283 // verify result
284 ASSERT_FALSE(result.has_value());
285 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
286}
287
288TEST(DeviceTest, getSupportedExtensionsDeadObject) {
289 // setup call
290 const auto mockDevice = createMockDevice();
291 EXPECT_CALL(*mockDevice, getSupportedExtensions(_))
292 .Times(1)
293 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
294
295 // run test
296 const auto result = Device::create(kName, mockDevice);
297
298 // verify result
299 ASSERT_FALSE(result.has_value());
300 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
301}
302
Lev Proleeva31aff12021-06-28 13:10:54 +0100303TEST(DeviceTest, getNumberOfCacheFilesNeeded) {
304 // setup call
305 const auto mockDevice = createMockDevice();
306 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_)).Times(1);
307
308 // run test
309 const auto result = Device::create(kName, mockDevice);
310
311 // verify result
312 ASSERT_TRUE(result.has_value());
313 constexpr auto kNumberOfCacheFilesPair = std::make_pair<uint32_t, uint32_t>(
314 kNumberOfCacheFiles.numModelCache, kNumberOfCacheFiles.numDataCache);
315 EXPECT_EQ(result.value()->getNumberOfCacheFilesNeeded(), kNumberOfCacheFilesPair);
316}
317
Lev Proleev900c28a2021-01-26 19:40:20 +0000318TEST(DeviceTest, getNumberOfCacheFilesNeededError) {
319 // setup call
320 const auto mockDevice = createMockDevice();
321 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
322 .Times(1)
323 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
324
325 // run test
326 const auto result = Device::create(kName, mockDevice);
327
328 // verify result
329 ASSERT_FALSE(result.has_value());
330 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
331}
332
333TEST(DeviceTest, dataCacheFilesExceedsSpecifiedMax) {
334 // setup test
335 const auto mockDevice = createMockDevice();
336 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
337 .Times(1)
338 .WillOnce(DoAll(SetArgPointee<0>(NumberOfCacheFiles{
339 .numModelCache = nn::kMaxNumberOfCacheFiles + 1,
340 .numDataCache = nn::kMaxNumberOfCacheFiles}),
341 InvokeWithoutArgs(makeStatusOk)));
342
343 // run test
344 const auto result = Device::create(kName, mockDevice);
345
346 // verify result
347 ASSERT_FALSE(result.has_value());
348 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
349}
350
351TEST(DeviceTest, modelCacheFilesExceedsSpecifiedMax) {
352 // setup test
353 const auto mockDevice = createMockDevice();
354 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
355 .Times(1)
356 .WillOnce(DoAll(SetArgPointee<0>(NumberOfCacheFiles{
357 .numModelCache = nn::kMaxNumberOfCacheFiles,
358 .numDataCache = nn::kMaxNumberOfCacheFiles + 1}),
359 InvokeWithoutArgs(makeStatusOk)));
360
361 // run test
362 const auto result = Device::create(kName, mockDevice);
363
364 // verify result
365 ASSERT_FALSE(result.has_value());
366 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
367}
368
369TEST(DeviceTest, getNumberOfCacheFilesNeededTransportFailure) {
370 // setup call
371 const auto mockDevice = createMockDevice();
372 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
373 .Times(1)
374 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
375
376 // run test
377 const auto result = Device::create(kName, mockDevice);
378
379 // verify result
380 ASSERT_FALSE(result.has_value());
381 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
382}
383
384TEST(DeviceTest, getNumberOfCacheFilesNeededDeadObject) {
385 // setup call
386 const auto mockDevice = createMockDevice();
387 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
388 .Times(1)
389 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
390
391 // run test
392 const auto result = Device::create(kName, mockDevice);
393
394 // verify result
395 ASSERT_FALSE(result.has_value());
396 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
397}
398
399TEST(DeviceTest, getCapabilitiesError) {
400 // setup call
401 const auto mockDevice = createMockDevice();
402 EXPECT_CALL(*mockDevice, getCapabilities(_))
403 .Times(1)
404 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
405
406 // run test
407 const auto result = Device::create(kName, mockDevice);
408
409 // verify result
410 ASSERT_FALSE(result.has_value());
411 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
412}
413
414TEST(DeviceTest, getCapabilitiesTransportFailure) {
415 // setup call
416 const auto mockDevice = createMockDevice();
417 EXPECT_CALL(*mockDevice, getCapabilities(_))
418 .Times(1)
419 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
420
421 // run test
422 const auto result = Device::create(kName, mockDevice);
423
424 // verify result
425 ASSERT_FALSE(result.has_value());
426 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
427}
428
429TEST(DeviceTest, getCapabilitiesDeadObject) {
430 // setup call
431 const auto mockDevice = createMockDevice();
432 EXPECT_CALL(*mockDevice, getCapabilities(_))
433 .Times(1)
434 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
435
436 // run test
437 const auto result = Device::create(kName, mockDevice);
438
439 // verify result
440 ASSERT_FALSE(result.has_value());
441 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
442}
443
444TEST(DeviceTest, getName) {
445 // setup call
446 const auto mockDevice = createMockDevice();
447 const auto device = Device::create(kName, mockDevice).value();
448
449 // run test
450 const auto& name = device->getName();
451
452 // verify result
453 EXPECT_EQ(name, kName);
454}
455
456TEST(DeviceTest, getFeatureLevel) {
457 // setup call
458 const auto mockDevice = createMockDevice();
459 const auto device = Device::create(kName, mockDevice).value();
460
461 // run test
462 const auto featureLevel = device->getFeatureLevel();
463
464 // verify result
465 EXPECT_EQ(featureLevel, nn::Version::ANDROID_S);
466}
467
468TEST(DeviceTest, getCachedData) {
469 // setup call
470 const auto mockDevice = createMockDevice();
471 EXPECT_CALL(*mockDevice, getVersionString(_)).Times(1);
472 EXPECT_CALL(*mockDevice, getType(_)).Times(1);
473 EXPECT_CALL(*mockDevice, getSupportedExtensions(_)).Times(1);
474 EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_)).Times(1);
475 EXPECT_CALL(*mockDevice, getCapabilities(_)).Times(1);
476
477 const auto result = Device::create(kName, mockDevice);
478 ASSERT_TRUE(result.has_value())
479 << "Failed with " << result.error().code << ": " << result.error().message;
480 const auto& device = result.value();
481
482 // run test and verify results
483 EXPECT_EQ(device->getVersionString(), device->getVersionString());
484 EXPECT_EQ(device->getType(), device->getType());
485 EXPECT_EQ(device->getSupportedExtensions(), device->getSupportedExtensions());
486 EXPECT_EQ(device->getNumberOfCacheFilesNeeded(), device->getNumberOfCacheFilesNeeded());
487 EXPECT_EQ(device->getCapabilities(), device->getCapabilities());
488}
489
490TEST(DeviceTest, getSupportedOperations) {
491 // setup call
492 const auto mockDevice = createMockDevice();
493 const auto device = Device::create(kName, mockDevice).value();
494 EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
495 .Times(1)
496 .WillOnce(DoAll(
497 SetArgPointee<1>(std::vector<bool>(kSimpleModel.main.operations.size(), true)),
498 InvokeWithoutArgs(makeStatusOk)));
499
500 // run test
501 const auto result = device->getSupportedOperations(kSimpleModel);
502
503 // verify result
504 ASSERT_TRUE(result.has_value())
505 << "Failed with " << result.error().code << ": " << result.error().message;
506 const auto& supportedOperations = result.value();
507 EXPECT_EQ(supportedOperations.size(), kSimpleModel.main.operations.size());
508 EXPECT_THAT(supportedOperations, Each(testing::IsTrue()));
509}
510
511TEST(DeviceTest, getSupportedOperationsError) {
512 // setup call
513 const auto mockDevice = createMockDevice();
514 const auto device = Device::create(kName, mockDevice).value();
515 EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
516 .Times(1)
517 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
518
519 // run test
520 const auto result = device->getSupportedOperations(kSimpleModel);
521
522 // verify result
523 ASSERT_FALSE(result.has_value());
524 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
525}
526
527TEST(DeviceTest, getSupportedOperationsTransportFailure) {
528 // setup call
529 const auto mockDevice = createMockDevice();
530 const auto device = Device::create(kName, mockDevice).value();
531 EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
532 .Times(1)
533 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
534
535 // run test
536 const auto result = device->getSupportedOperations(kSimpleModel);
537
538 // verify result
539 ASSERT_FALSE(result.has_value());
540 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
541}
542
543TEST(DeviceTest, getSupportedOperationsDeadObject) {
544 // setup call
545 const auto mockDevice = createMockDevice();
546 const auto device = Device::create(kName, mockDevice).value();
547 EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
548 .Times(1)
549 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
550
551 // run test
552 const auto result = device->getSupportedOperations(kSimpleModel);
553
554 // verify result
555 ASSERT_FALSE(result.has_value());
556 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
557}
558
559TEST(DeviceTest, prepareModel) {
560 // setup call
561 const auto mockDevice = createMockDevice();
562 const auto device = Device::create(kName, mockDevice).value();
563 const auto mockPreparedModel = MockPreparedModel::create();
564 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
565 .Times(1)
566 .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::NONE,
567 mockPreparedModel)));
568
569 // run test
570 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
571 nn::Priority::DEFAULT, {}, {}, {}, {});
572
573 // verify result
574 ASSERT_TRUE(result.has_value())
575 << "Failed with " << result.error().code << ": " << result.error().message;
576 EXPECT_NE(result.value(), nullptr);
577}
578
579TEST(DeviceTest, prepareModelLaunchError) {
580 // setup call
581 const auto mockDevice = createMockDevice();
582 const auto device = Device::create(kName, mockDevice).value();
583 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
584 .Times(1)
585 .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::GENERAL_FAILURE,
586 ErrorStatus::GENERAL_FAILURE, nullptr)));
587
588 // run test
589 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
590 nn::Priority::DEFAULT, {}, {}, {}, {});
591
592 // verify result
593 ASSERT_FALSE(result.has_value());
594 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
595}
596
597TEST(DeviceTest, prepareModelReturnError) {
598 // setup call
599 const auto mockDevice = createMockDevice();
600 const auto device = Device::create(kName, mockDevice).value();
601 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
602 .Times(1)
603 .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::NONE,
604 ErrorStatus::GENERAL_FAILURE, nullptr)));
605
606 // run test
607 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
608 nn::Priority::DEFAULT, {}, {}, {}, {});
609
610 // verify result
611 ASSERT_FALSE(result.has_value());
612 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
613}
614
615TEST(DeviceTest, prepareModelNullptrError) {
616 // setup call
617 const auto mockDevice = createMockDevice();
618 const auto device = Device::create(kName, mockDevice).value();
619 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
620 .Times(1)
621 .WillOnce(
622 Invoke(makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::NONE, nullptr)));
623
624 // run test
625 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
626 nn::Priority::DEFAULT, {}, {}, {}, {});
627
628 // verify result
629 ASSERT_FALSE(result.has_value());
630 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
631}
632
633TEST(DeviceTest, prepareModelTransportFailure) {
634 // setup call
635 const auto mockDevice = createMockDevice();
636 const auto device = Device::create(kName, mockDevice).value();
637 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
638 .Times(1)
639 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
640
641 // run test
642 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
643 nn::Priority::DEFAULT, {}, {}, {}, {});
644
645 // verify result
646 ASSERT_FALSE(result.has_value());
647 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
648}
649
650TEST(DeviceTest, prepareModelDeadObject) {
651 // setup call
652 const auto mockDevice = createMockDevice();
653 const auto device = Device::create(kName, mockDevice).value();
654 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
655 .Times(1)
656 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
657
658 // run test
659 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
660 nn::Priority::DEFAULT, {}, {}, {}, {});
661
662 // verify result
663 ASSERT_FALSE(result.has_value());
664 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
665}
666
667TEST(DeviceTest, prepareModelAsyncCrash) {
668 // setup test
669 const auto mockDevice = createMockDevice();
670 const auto device = Device::create(kName, mockDevice).value();
671 const auto ret = [&device]() {
672 DeathMonitor::serviceDied(device->getDeathMonitor());
673 return ndk::ScopedAStatus::ok();
674 };
675 EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
676 .Times(1)
677 .WillOnce(InvokeWithoutArgs(ret));
678
679 // run test
680 const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
681 nn::Priority::DEFAULT, {}, {}, {}, {});
682
683 // verify result
684 ASSERT_FALSE(result.has_value());
685 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
686}
687
688TEST(DeviceTest, prepareModelFromCache) {
689 // setup call
690 const auto mockDevice = createMockDevice();
691 const auto device = Device::create(kName, mockDevice).value();
692 const auto mockPreparedModel = MockPreparedModel::create();
693 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
694 .Times(1)
695 .WillOnce(Invoke(makePreparedModelFromCacheReturn(ErrorStatus::NONE, ErrorStatus::NONE,
696 mockPreparedModel)));
697
698 // run test
699 const auto result = device->prepareModelFromCache({}, {}, {}, {});
700
701 // verify result
702 ASSERT_TRUE(result.has_value())
703 << "Failed with " << result.error().code << ": " << result.error().message;
704 EXPECT_NE(result.value(), nullptr);
705}
706
707TEST(DeviceTest, prepareModelFromCacheLaunchError) {
708 // setup call
709 const auto mockDevice = createMockDevice();
710 const auto device = Device::create(kName, mockDevice).value();
711 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
712 .Times(1)
713 .WillOnce(Invoke(makePreparedModelFromCacheReturn(
714 ErrorStatus::GENERAL_FAILURE, ErrorStatus::GENERAL_FAILURE, nullptr)));
715
716 // run test
717 const auto result = device->prepareModelFromCache({}, {}, {}, {});
718
719 // verify result
720 ASSERT_FALSE(result.has_value());
721 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
722}
723
724TEST(DeviceTest, prepareModelFromCacheReturnError) {
725 // setup call
726 const auto mockDevice = createMockDevice();
727 const auto device = Device::create(kName, mockDevice).value();
728 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
729 .Times(1)
730 .WillOnce(Invoke(makePreparedModelFromCacheReturn(
731 ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE, nullptr)));
732
733 // run test
734 const auto result = device->prepareModelFromCache({}, {}, {}, {});
735
736 // verify result
737 ASSERT_FALSE(result.has_value());
738 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
739}
740
741TEST(DeviceTest, prepareModelFromCacheNullptrError) {
742 // setup call
743 const auto mockDevice = createMockDevice();
744 const auto device = Device::create(kName, mockDevice).value();
745 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
746 .Times(1)
747 .WillOnce(Invoke(makePreparedModelFromCacheReturn(ErrorStatus::NONE, ErrorStatus::NONE,
748 nullptr)));
749
750 // run test
751 const auto result = device->prepareModelFromCache({}, {}, {}, {});
752
753 // verify result
754 ASSERT_FALSE(result.has_value());
755 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
756}
757
758TEST(DeviceTest, prepareModelFromCacheTransportFailure) {
759 // setup call
760 const auto mockDevice = createMockDevice();
761 const auto device = Device::create(kName, mockDevice).value();
762 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
763 .Times(1)
764 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
765
766 // run test
767 const auto result = device->prepareModelFromCache({}, {}, {}, {});
768
769 // verify result
770 ASSERT_FALSE(result.has_value());
771 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
772}
773
774TEST(DeviceTest, prepareModelFromCacheDeadObject) {
775 // setup call
776 const auto mockDevice = createMockDevice();
777 const auto device = Device::create(kName, mockDevice).value();
778 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
779 .Times(1)
780 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
781
782 // run test
783 const auto result = device->prepareModelFromCache({}, {}, {}, {});
784
785 // verify result
786 ASSERT_FALSE(result.has_value());
787 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
788}
789
790TEST(DeviceTest, prepareModelFromCacheAsyncCrash) {
791 // setup test
792 const auto mockDevice = createMockDevice();
793 const auto device = Device::create(kName, mockDevice).value();
794 const auto ret = [&device]() {
795 DeathMonitor::serviceDied(device->getDeathMonitor());
796 return ndk::ScopedAStatus::ok();
797 };
798 EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
799 .Times(1)
800 .WillOnce(InvokeWithoutArgs(ret));
801
802 // run test
803 const auto result = device->prepareModelFromCache({}, {}, {}, {});
804
805 // verify result
806 ASSERT_FALSE(result.has_value());
807 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
808}
809
810TEST(DeviceTest, allocate) {
811 // setup call
812 const auto mockDevice = createMockDevice();
813 const auto device = Device::create(kName, mockDevice).value();
814 const auto mockBuffer = DeviceBuffer{.buffer = MockBuffer::create(), .token = 1};
815 EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
816 .Times(1)
817 .WillOnce(DoAll(SetArgPointee<4>(mockBuffer), InvokeWithoutArgs(makeStatusOk)));
818
819 // run test
820 const auto result = device->allocate({}, {}, {}, {});
821
822 // verify result
823 ASSERT_TRUE(result.has_value())
824 << "Failed with " << result.error().code << ": " << result.error().message;
825 EXPECT_NE(result.value(), nullptr);
826}
827
828TEST(DeviceTest, allocateError) {
829 // setup call
830 const auto mockDevice = createMockDevice();
831 const auto device = Device::create(kName, mockDevice).value();
832 EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
833 .Times(1)
834 .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
835
836 // run test
837 const auto result = device->allocate({}, {}, {}, {});
838
839 // verify result
840 ASSERT_FALSE(result.has_value());
841 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
842}
843
844TEST(DeviceTest, allocateTransportFailure) {
845 // setup call
846 const auto mockDevice = createMockDevice();
847 const auto device = Device::create(kName, mockDevice).value();
848 EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
849 .Times(1)
850 .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
851
852 // run test
853 const auto result = device->allocate({}, {}, {}, {});
854
855 // verify result
856 ASSERT_FALSE(result.has_value());
857 EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
858}
859
860TEST(DeviceTest, allocateDeadObject) {
861 // setup call
862 const auto mockDevice = createMockDevice();
863 const auto device = Device::create(kName, mockDevice).value();
864 EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
865 .Times(1)
866 .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
867
868 // run test
869 const auto result = device->allocate({}, {}, {}, {});
870
871 // verify result
872 ASSERT_FALSE(result.has_value());
873 EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
874}
875
876} // namespace aidl::android::hardware::neuralnetworks::utils