/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "GeneratedTestHarness.h"

#include <android-base/logging.h>
#include <android/hardware/neuralnetworks/1.0/IDevice.h>
#include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
#include <android/hardware/neuralnetworks/1.0/IPreparedModel.h>
#include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
#include <android/hardware/neuralnetworks/1.1/IDevice.h>
#include <android/hardware/neuralnetworks/1.2/IDevice.h>
#include <android/hardware/neuralnetworks/1.2/IExecutionCallback.h>
#include <android/hardware/neuralnetworks/1.2/IPreparedModel.h>
#include <android/hardware/neuralnetworks/1.2/IPreparedModelCallback.h>
#include <android/hardware/neuralnetworks/1.2/types.h>
#include <android/hardware/neuralnetworks/1.3/IDevice.h>
#include <android/hardware/neuralnetworks/1.3/IPreparedModel.h>
#include <android/hardware/neuralnetworks/1.3/IPreparedModelCallback.h>
#include <android/hardware/neuralnetworks/1.3/types.h>
#include <android/hidl/allocator/1.0/IAllocator.h>
#include <android/hidl/memory/1.0/IMemory.h>
#include <gtest/gtest.h>
#include <hidlmemory/mapping.h>

#include <algorithm>
#include <chrono>
#include <iostream>
#include <numeric>
#include <vector>

#include "1.0/Utils.h"
#include "1.2/Callbacks.h"
#include "1.3/Callbacks.h"
#include "ExecutionBurstController.h"
#include "MemoryUtils.h"
#include "TestHarness.h"
#include "Utils.h"
#include "VtsHalNeuralnetworks.h"

namespace android::hardware::neuralnetworks::V1_3::vts::functional {

using namespace test_helper;
using hidl::memory::V1_0::IMemory;
using implementation::PreparedModelCallback;
using V1_0::DataLocation;
using V1_0::ErrorStatus;
using V1_0::OperandLifeTime;
using V1_1::ExecutionPreference;
using V1_2::Constant;
using V1_2::MeasureTiming;
using V1_2::OutputShape;
using V1_2::SymmPerChannelQuantParams;
using V1_2::Timing;
using V1_2::implementation::ExecutionCallback;
using HidlToken = hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;

namespace {

enum class Executor { ASYNC, SYNC, BURST };

enum class OutputType { FULLY_SPECIFIED, UNSPECIFIED, INSUFFICIENT };

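// Describes how a single generated test execution is driven: which execution
// path to exercise, whether timing is measured, and how output operand
// dimensions are specified in the request.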
struct TestConfig {
    Executor executor;
    MeasureTiming measureTiming;
    OutputType outputType;
    // `reportSkipping` indicates whether a test should print an info message
    // when it is skipped. The field defaults to true and is set to false in
    // quantization coupling tests to suppress skipping an individual test.
    bool reportSkipping;
    TestConfig(Executor executor, MeasureTiming measureTiming, OutputType outputType)
        : executor(executor),
          measureTiming(measureTiming),
          outputType(outputType),
          reportSkipping(true) {}
    TestConfig(Executor executor, MeasureTiming measureTiming, OutputType outputType,
               bool reportSkipping)
        : executor(executor),
          measureTiming(measureTiming),
          outputType(outputType),
          reportSkipping(reportSkipping) {}
};

}  // namespace

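// Converts a test_helper::TestModel into a HIDL V1_3 Model. CONSTANT_COPY
// operand data is packed into Model::operandValues, while CONSTANT_REFERENCE
// operand data is copied into a single shared memory pool.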
Model createModel(const TestModel& testModel) {
    // Model operands.
    hidl_vec<Operand> operands(testModel.operands.size());
    size_t constCopySize = 0, constRefSize = 0;
    for (uint32_t i = 0; i < testModel.operands.size(); i++) {
        const auto& op = testModel.operands[i];

        DataLocation loc = {};
        if (op.lifetime == TestOperandLifeTime::CONSTANT_COPY) {
            loc = {.poolIndex = 0,
                   .offset = static_cast<uint32_t>(constCopySize),
                   .length = static_cast<uint32_t>(op.data.size())};
            constCopySize += op.data.alignedSize();
        } else if (op.lifetime == TestOperandLifeTime::CONSTANT_REFERENCE) {
            loc = {.poolIndex = 0,
                   .offset = static_cast<uint32_t>(constRefSize),
                   .length = static_cast<uint32_t>(op.data.size())};
            constRefSize += op.data.alignedSize();
        }

        Operand::ExtraParams extraParams;
        if (op.type == TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
            extraParams.channelQuant(SymmPerChannelQuantParams{
                    .scales = op.channelQuant.scales, .channelDim = op.channelQuant.channelDim});
        }

        operands[i] = {.type = static_cast<OperandType>(op.type),
                       .dimensions = op.dimensions,
                       .numberOfConsumers = op.numberOfConsumers,
                       .scale = op.scale,
                       .zeroPoint = op.zeroPoint,
                       .lifetime = static_cast<OperandLifeTime>(op.lifetime),
                       .location = loc,
                       .extraParams = std::move(extraParams)};
    }

    // Model operations.
    hidl_vec<Operation> operations(testModel.operations.size());
    std::transform(testModel.operations.begin(), testModel.operations.end(), operations.begin(),
                   [](const TestOperation& op) -> Operation {
                       return {.type = static_cast<OperationType>(op.type),
                               .inputs = op.inputs,
                               .outputs = op.outputs};
                   });

    // Constant copies.
    hidl_vec<uint8_t> operandValues(constCopySize);
    for (uint32_t i = 0; i < testModel.operands.size(); i++) {
        const auto& op = testModel.operands[i];
        if (op.lifetime == TestOperandLifeTime::CONSTANT_COPY) {
            const uint8_t* begin = op.data.get<uint8_t>();
            const uint8_t* end = begin + op.data.size();
            std::copy(begin, end, operandValues.data() + operands[i].location.offset);
        }
    }

    // Shared memory.
    hidl_vec<hidl_memory> pools = {};
    if (constRefSize > 0) {
        hidl_vec_push_back(&pools, nn::allocateSharedMemory(constRefSize));
        CHECK_NE(pools[0].size(), 0u);

        // load data
        sp<IMemory> mappedMemory = mapMemory(pools[0]);
        CHECK(mappedMemory.get() != nullptr);
        uint8_t* mappedPtr =
                reinterpret_cast<uint8_t*>(static_cast<void*>(mappedMemory->getPointer()));
        CHECK(mappedPtr != nullptr);

        for (uint32_t i = 0; i < testModel.operands.size(); i++) {
            const auto& op = testModel.operands[i];
            if (op.lifetime == TestOperandLifeTime::CONSTANT_REFERENCE) {
                const uint8_t* begin = op.data.get<uint8_t>();
                const uint8_t* end = begin + op.data.size();
                std::copy(begin, end, mappedPtr + operands[i].location.offset);
            }
        }
    }

    return {.operands = std::move(operands),
            .operations = std::move(operations),
            .inputIndexes = testModel.inputIndexes,
            .outputIndexes = testModel.outputIndexes,
            .operandValues = std::move(operandValues),
            .pools = std::move(pools),
            .relaxComputationFloat32toFloat16 = testModel.isRelaxed};
}

static bool isOutputSizeGreaterThanOne(const TestModel& testModel, uint32_t index) {
    const auto byteSize = testModel.operands[testModel.outputIndexes[index]].data.size();
    return byteSize > 1u;
}

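// Shrinks the requested length of the given output buffer by one byte so that
// the driver is forced to report an insufficiently sized output.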
static void makeOutputInsufficientSize(uint32_t outputIndex, V1_0::Request* request) {
    auto& length = request->outputs[outputIndex].location.length;
    ASSERT_GT(length, 1u);
    length -= 1u;
}

static void makeOutputDimensionsUnspecified(Model* model) {
    for (auto i : model->outputIndexes) {
        auto& dims = model->operands[i].dimensions;
        std::fill(dims.begin(), dims.end(), 0);
    }
}

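// Thin wrappers over the V1_3 prepared model interface: the first overload
// launches an asynchronous execute_1_3 call, the second performs a blocking
// executeSynchronously_1_3 call and captures the output shapes and timing.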
static Return<ErrorStatus> ExecutePreparedModel(const sp<IPreparedModel>& preparedModel,
                                                const Request& request, MeasureTiming measure,
                                                sp<ExecutionCallback>& callback) {
    return preparedModel->execute_1_3(request, measure, callback);
}
static Return<ErrorStatus> ExecutePreparedModel(const sp<IPreparedModel>& preparedModel,
                                                const Request& request, MeasureTiming measure,
                                                hidl_vec<OutputShape>* outputShapes,
                                                Timing* timing) {
    ErrorStatus result;
    Return<void> ret = preparedModel->executeSynchronously_1_3(
            request, measure,
            [&result, outputShapes, timing](ErrorStatus error, const hidl_vec<OutputShape>& shapes,
                                            const Timing& time) {
                result = error;
                *outputShapes = shapes;
                *timing = time;
            });
    if (!ret.isOk()) {
        return ErrorStatus::GENERAL_FAILURE;
    }
    return result;
}
static std::shared_ptr<::android::nn::ExecutionBurstController> CreateBurst(
        const sp<IPreparedModel>& preparedModel) {
    return android::nn::ExecutionBurstController::create(preparedModel,
                                                         std::chrono::microseconds{0});
}

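// Runs a single execution of `testModel` on `preparedModel` according to
// `testConfig`, then validates the returned status, output shapes, timing, and
// output buffer contents. If `skipped` is provided, it is set to true when the
// driver reports that it cannot execute the model.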
void EvaluatePreparedModel(const sp<IPreparedModel>& preparedModel, const TestModel& testModel,
                           const TestConfig& testConfig, bool* skipped = nullptr) {
    if (skipped != nullptr) {
        *skipped = false;
    }
    // If output0 is not larger than one byte, the insufficient-buffer case cannot be tested.
    if (testConfig.outputType == OutputType::INSUFFICIENT &&
        !isOutputSizeGreaterThanOne(testModel, 0)) {
        return;
    }

    V1_0::Request request10 = createRequest(testModel);
    if (testConfig.outputType == OutputType::INSUFFICIENT) {
        makeOutputInsufficientSize(/*outputIndex=*/0, &request10);
    }
    Request request = nn::convertToV1_3(request10);

    ErrorStatus executionStatus;
    hidl_vec<OutputShape> outputShapes;
    Timing timing;
    switch (testConfig.executor) {
        case Executor::ASYNC: {
            SCOPED_TRACE("asynchronous");

            // launch execution
            sp<ExecutionCallback> executionCallback = new ExecutionCallback();
            Return<ErrorStatus> executionLaunchStatus = ExecutePreparedModel(
                    preparedModel, request, testConfig.measureTiming, executionCallback);
            ASSERT_TRUE(executionLaunchStatus.isOk());
            EXPECT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(executionLaunchStatus));

            // retrieve execution status
            executionCallback->wait();
            executionStatus = executionCallback->getStatus();
            outputShapes = executionCallback->getOutputShapes();
            timing = executionCallback->getTiming();

            break;
        }
        case Executor::SYNC: {
            SCOPED_TRACE("synchronous");

            // execute
            Return<ErrorStatus> executionReturnStatus = ExecutePreparedModel(
                    preparedModel, request, testConfig.measureTiming, &outputShapes, &timing);
            ASSERT_TRUE(executionReturnStatus.isOk());
            executionStatus = static_cast<ErrorStatus>(executionReturnStatus);

            break;
        }
        case Executor::BURST: {
            // TODO(butlermichael): Check if we need to test burst in V1_3 if the interface remains
            // V1_2.
            SCOPED_TRACE("burst");

            // create burst
            const std::shared_ptr<::android::nn::ExecutionBurstController> controller =
                    CreateBurst(preparedModel);
            ASSERT_NE(nullptr, controller.get());

            // create memory keys
            std::vector<intptr_t> keys(request10.pools.size());
            for (size_t i = 0; i < keys.size(); ++i) {
                keys[i] = reinterpret_cast<intptr_t>(&request10.pools[i]);
            }

            // execute burst
            int n;
            std::tie(n, outputShapes, timing, std::ignore) =
                    controller->compute(request10, testConfig.measureTiming, keys);
            executionStatus = nn::convertResultCodeToErrorStatus(n);

            break;
        }
    }

    if (testConfig.outputType != OutputType::FULLY_SPECIFIED &&
        executionStatus == ErrorStatus::GENERAL_FAILURE) {
        if (skipped != nullptr) {
            *skipped = true;
        }
        if (!testConfig.reportSkipping) {
            return;
        }
        LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                     "execute a model that it does not support.";
        std::cout << "[          ] Early termination of test because vendor service cannot "
                     "execute a model that it does not support."
                  << std::endl;
        GTEST_SKIP();
    }
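    // Validate timing: when timing was not requested, both fields must be
    // UINT64_MAX; when it was requested and both values are reported, time
    // spent on the device must not exceed time spent in the driver.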
    if (testConfig.measureTiming == MeasureTiming::NO) {
        EXPECT_EQ(UINT64_MAX, timing.timeOnDevice);
        EXPECT_EQ(UINT64_MAX, timing.timeInDriver);
    } else {
        if (timing.timeOnDevice != UINT64_MAX && timing.timeInDriver != UINT64_MAX) {
            EXPECT_LE(timing.timeOnDevice, timing.timeInDriver);
        }
    }

    switch (testConfig.outputType) {
        case OutputType::FULLY_SPECIFIED:
            // If the model output operands are fully specified, outputShapes must be
            // either empty or have the same number of elements as the number of outputs.
            ASSERT_EQ(ErrorStatus::NONE, executionStatus);
            ASSERT_TRUE(outputShapes.size() == 0 ||
                        outputShapes.size() == testModel.outputIndexes.size());
            break;
        case OutputType::UNSPECIFIED:
            // If the model output operands are not fully specified, outputShapes must have
            // the same number of elements as the number of outputs.
            ASSERT_EQ(ErrorStatus::NONE, executionStatus);
            ASSERT_EQ(outputShapes.size(), testModel.outputIndexes.size());
            break;
        case OutputType::INSUFFICIENT:
            ASSERT_EQ(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, executionStatus);
            ASSERT_EQ(outputShapes.size(), testModel.outputIndexes.size());
            ASSERT_FALSE(outputShapes[0].isSufficient);
            return;
    }

    // Go through all outputs, check returned output shapes.
    for (uint32_t i = 0; i < outputShapes.size(); i++) {
        EXPECT_TRUE(outputShapes[i].isSufficient);
        const auto& expect = testModel.operands[testModel.outputIndexes[i]].dimensions;
        const std::vector<uint32_t> actual = outputShapes[i].dimensions;
        EXPECT_EQ(expect, actual);
    }

    // Retrieve execution results.
    const std::vector<TestBuffer> outputs = getOutputBuffers(request10);

    // We want "close-enough" results.
    checkResults(testModel, outputs);
}

void EvaluatePreparedModel(const sp<IPreparedModel>& preparedModel, const TestModel& testModel,
                           TestKind testKind) {
    std::vector<OutputType> outputTypesList;
    std::vector<MeasureTiming> measureTimingList;
    std::vector<Executor> executorList;

    switch (testKind) {
        case TestKind::GENERAL: {
            outputTypesList = {OutputType::FULLY_SPECIFIED};
            measureTimingList = {MeasureTiming::NO, MeasureTiming::YES};
            executorList = {Executor::ASYNC, Executor::SYNC, Executor::BURST};
        } break;
        case TestKind::DYNAMIC_SHAPE: {
            outputTypesList = {OutputType::UNSPECIFIED, OutputType::INSUFFICIENT};
            measureTimingList = {MeasureTiming::NO, MeasureTiming::YES};
            executorList = {Executor::ASYNC, Executor::SYNC, Executor::BURST};
        } break;
        case TestKind::QUANTIZATION_COUPLING: {
            LOG(FATAL) << "Wrong TestKind for EvaluatePreparedModel";
            return;
        } break;
    }

    for (const OutputType outputType : outputTypesList) {
        for (const MeasureTiming measureTiming : measureTimingList) {
            for (const Executor executor : executorList) {
                const TestConfig testConfig(executor, measureTiming, outputType);
                EvaluatePreparedModel(preparedModel, testModel, testConfig);
            }
        }
    }
}

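// Runs the base model and its signed-quantized counterpart with identical test
// configurations and requires that either both executions succeed or both are
// skipped by the driver.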
void EvaluatePreparedCoupledModels(const sp<IPreparedModel>& preparedModel,
                                   const TestModel& testModel,
                                   const sp<IPreparedModel>& preparedCoupledModel,
                                   const TestModel& coupledModel) {
    const std::vector<OutputType> outputTypesList = {OutputType::FULLY_SPECIFIED};
    const std::vector<MeasureTiming> measureTimingList = {MeasureTiming::NO, MeasureTiming::YES};
    const std::vector<Executor> executorList = {Executor::ASYNC, Executor::SYNC, Executor::BURST};

    for (const OutputType outputType : outputTypesList) {
        for (const MeasureTiming measureTiming : measureTimingList) {
            for (const Executor executor : executorList) {
                const TestConfig testConfig(executor, measureTiming, outputType,
                                            /*reportSkipping=*/false);
                bool baseSkipped = false;
                EvaluatePreparedModel(preparedModel, testModel, testConfig, &baseSkipped);
                bool coupledSkipped = false;
                EvaluatePreparedModel(preparedCoupledModel, coupledModel, testConfig,
                                      &coupledSkipped);
                ASSERT_EQ(baseSkipped, coupledSkipped);
                if (baseSkipped) {
                    LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                                 "execute a model that it does not support.";
                    std::cout << "[          ] Early termination of test because vendor service "
                                 "cannot execute a model that it does not support."
                              << std::endl;
                    GTEST_SKIP();
                }
            }
        }
    }
}

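// Entry point for a generated test case: builds the HIDL model, prepares it on
// the device, and evaluates it according to the requested test kind. For
// QUANTIZATION_COUPLING, a signed-quantized copy of the model is prepared and
// evaluated alongside the original.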
void Execute(const sp<IDevice>& device, const TestModel& testModel, TestKind testKind) {
    Model model = createModel(testModel);
    if (testKind == TestKind::DYNAMIC_SHAPE) {
        makeOutputDimensionsUnspecified(&model);
    }

    sp<IPreparedModel> preparedModel;
    switch (testKind) {
        case TestKind::GENERAL: {
            createPreparedModel(device, model, &preparedModel);
            if (preparedModel == nullptr) return;
            EvaluatePreparedModel(preparedModel, testModel, TestKind::GENERAL);
        } break;
        case TestKind::DYNAMIC_SHAPE: {
            createPreparedModel(device, model, &preparedModel);
            if (preparedModel == nullptr) return;
            EvaluatePreparedModel(preparedModel, testModel, TestKind::DYNAMIC_SHAPE);
        } break;
        case TestKind::QUANTIZATION_COUPLING: {
            ASSERT_TRUE(testModel.hasQuant8CoupledOperands());
            createPreparedModel(device, model, &preparedModel, /*reportSkipping=*/false);
            TestModel signedQuantizedModel = convertQuant8AsymmOperandsToSigned(testModel);
            sp<IPreparedModel> preparedCoupledModel;
            createPreparedModel(device, createModel(signedQuantizedModel), &preparedCoupledModel,
                                /*reportSkipping=*/false);
            // If we couldn't prepare a model with unsigned quantization, we must
            // fail to prepare a model with signed quantization as well.
            if (preparedModel == nullptr) {
                ASSERT_EQ(preparedCoupledModel, nullptr);
                // If we failed to prepare both of the models, we can safely skip
                // the test.
                LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                             "prepare a model that it does not support.";
                std::cout << "[          ] Early termination of test because vendor service "
                             "cannot prepare a model that it does not support."
                          << std::endl;
                GTEST_SKIP();
            }
            ASSERT_NE(preparedCoupledModel, nullptr);
            EvaluatePreparedCoupledModels(preparedModel, testModel, preparedCoupledModel,
                                          signedQuantizedModel);
        } break;
    }
}

void GeneratedTestBase::SetUp() {
    testing::TestWithParam<GeneratedTestParam>::SetUp();
    ASSERT_NE(kDevice, nullptr);
}

std::vector<NamedModel> getNamedModels(const FilterFn& filter) {
    return TestModelManager::get().getTestModels(filter);
}

std::string printGeneratedTest(const testing::TestParamInfo<GeneratedTestParam>& info) {
    const auto& [namedDevice, namedModel] = info.param;
    return gtestCompliantName(getName(namedDevice) + "_" + getName(namedModel));
}

// Tag for the generated tests
class GeneratedTest : public GeneratedTestBase {};

// Tag for the dynamic output shape tests
class DynamicOutputShapeTest : public GeneratedTest {};

// Tag for the quantization coupling tests
class QuantizationCouplingTest : public GeneratedTest {};

TEST_P(GeneratedTest, Test) {
    Execute(kDevice, kTestModel, /*testKind=*/TestKind::GENERAL);
}

TEST_P(DynamicOutputShapeTest, Test) {
    Execute(kDevice, kTestModel, /*testKind=*/TestKind::DYNAMIC_SHAPE);
}

TEST_P(QuantizationCouplingTest, Test) {
    Execute(kDevice, kTestModel, /*testKind=*/TestKind::QUANTIZATION_COUPLING);
}

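// Instantiate each test suite over all registered test models. General and
// dynamic shape tests run on every model that is not expected to fail;
// quantization coupling tests run only on single-operation models that
// contain coupled quant8 operands.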
INSTANTIATE_GENERATED_TEST(GeneratedTest,
                           [](const TestModel& testModel) { return !testModel.expectFailure; });

INSTANTIATE_GENERATED_TEST(DynamicOutputShapeTest,
                           [](const TestModel& testModel) { return !testModel.expectFailure; });

INSTANTIATE_GENERATED_TEST(QuantizationCouplingTest, [](const TestModel& testModel) {
    return testModel.hasQuant8CoupledOperands() && testModel.operations.size() == 1;
});

}  // namespace android::hardware::neuralnetworks::V1_3::vts::functional