blob: 65c425ee1e127dfd62ab0a83d755c7ddefae7f5e [file] [log] [blame]
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -07001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Michael Butlercf22a572017-09-22 13:26:12 -070017#include "Callbacks.h"
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -070018#include "TestHarness.h"
Miao Wanga2d04c82018-02-05 17:26:54 -080019#include "Utils.h"
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -070020
21#include <android-base/logging.h>
Miao Wanga2d04c82018-02-05 17:26:54 -080022#include <android/hardware/neuralnetworks/1.0/IDevice.h>
23#include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
24#include <android/hardware/neuralnetworks/1.0/IPreparedModel.h>
25#include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
26#include <android/hardware/neuralnetworks/1.0/types.h>
Xusong Wangb5cb8f72018-10-31 08:43:12 -070027#include <android/hardware/neuralnetworks/1.1/IDevice.h>
28#include <android/hardware/neuralnetworks/1.2/IDevice.h>
29#include <android/hardware/neuralnetworks/1.2/IExecutionCallback.h>
30#include <android/hardware/neuralnetworks/1.2/IPreparedModel.h>
31#include <android/hardware/neuralnetworks/1.2/IPreparedModelCallback.h>
Miao Wanga2d04c82018-02-05 17:26:54 -080032#include <android/hidl/allocator/1.0/IAllocator.h>
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -070033#include <android/hidl/memory/1.0/IMemory.h>
34#include <hidlmemory/mapping.h>
Michael Butler0897ab32017-10-04 02:38:42 -070035#include <iostream>
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -070036
37namespace android {
38namespace hardware {
39namespace neuralnetworks {
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -070040
41namespace generated_tests {
Xusong Wangb5cb8f72018-10-31 08:43:12 -070042using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback;
43using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
Slava Shklyaev9e3fad12018-11-30 17:55:12 +000044using ::test_helper::bool8;
Michael K. Sanders941d61a2018-10-19 14:39:09 +010045using ::test_helper::compare;
46using ::test_helper::expectMultinomialDistributionWithinTolerance;
Mika Raentode166942018-04-17 16:49:50 +010047using ::test_helper::filter;
48using ::test_helper::for_all;
49using ::test_helper::for_each;
Michael K. Sanders941d61a2018-10-19 14:39:09 +010050using ::test_helper::MixedTyped;
51using ::test_helper::MixedTypedExample;
Michael K. Sanders941d61a2018-10-19 14:39:09 +010052using ::test_helper::resize_accordingly;
I-Jui (Ray) Sungf6b85502017-09-20 13:45:50 -070053
I-Jui (Ray) Sung5bf4edf2017-10-06 13:22:39 -070054template <typename T>
Xusong Wanga3165812018-11-19 18:26:08 -080055void copy_back_(std::map<int, std::vector<T>>* dst, const std::vector<RequestArgument>& ra,
56 char* src) {
57 for_each<T>(*dst, [&ra, src](int index, std::vector<T>& m) {
I-Jui (Ray) Sung5bf4edf2017-10-06 13:22:39 -070058 ASSERT_EQ(m.size(), ra[index].location.length / sizeof(T));
I-Jui (Ray) Sungf6b85502017-09-20 13:45:50 -070059 char* begin = src + ra[index].location.offset;
60 memcpy(m.data(), begin, ra[index].location.length);
61 });
62}
63
64void copy_back(MixedTyped* dst, const std::vector<RequestArgument>& ra, char* src) {
Xusong Wanga3165812018-11-19 18:26:08 -080065 copy_back_(&dst->float32Operands, ra, src);
66 copy_back_(&dst->int32Operands, ra, src);
Xusong Wangd49f6652019-01-16 18:32:24 -080067 copy_back_(&dst->quant8AsymmOperands, ra, src);
68 copy_back_(&dst->quant16SymmOperands, ra, src);
Xusong Wanga3165812018-11-19 18:26:08 -080069 copy_back_(&dst->float16Operands, ra, src);
70 copy_back_(&dst->bool8Operands, ra, src);
71 copy_back_(&dst->quant8ChannelOperands, ra, src);
Xusong Wangd49f6652019-01-16 18:32:24 -080072 copy_back_(&dst->quant16AsymmOperands, ra, src);
73 static_assert(8 == MixedTyped::kNumTypes,
Lev Proleev9b490f42018-11-02 12:44:11 +000074 "Number of types in MixedTyped changed, but copy_back function wasn't updated");
I-Jui (Ray) Sungf6b85502017-09-20 13:45:50 -070075}
76
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -070077// Top level driver for models and examples generated by test_generator.py
78// Test driver for those generated from ml/nn/runtime/test/spec
Xusong Wangb5cb8f72018-10-31 08:43:12 -070079static Return<ErrorStatus> ExecutePreparedModel(sp<V1_0::IPreparedModel>& preparedModel,
David Grosse3013492019-01-23 14:01:52 -080080 const Request& request, MeasureTiming,
Xusong Wangb5cb8f72018-10-31 08:43:12 -070081 sp<ExecutionCallback>& callback) {
82 return preparedModel->execute(request, callback);
83}
84static Return<ErrorStatus> ExecutePreparedModel(sp<V1_2::IPreparedModel>& preparedModel,
David Grosse3013492019-01-23 14:01:52 -080085 const Request& request, MeasureTiming measure,
Xusong Wangb5cb8f72018-10-31 08:43:12 -070086 sp<ExecutionCallback>& callback) {
David Grosse3013492019-01-23 14:01:52 -080087 return preparedModel->execute_1_2(request, measure, callback);
Xusong Wangb5cb8f72018-10-31 08:43:12 -070088}
Xusong Wang187c5972018-11-07 09:33:59 -080089static Return<ErrorStatus> ExecutePreparedModel(sp<V1_0::IPreparedModel>&, const Request&,
David Grosse3013492019-01-23 14:01:52 -080090 MeasureTiming, hidl_vec<OutputShape>*, Timing*) {
David Gross49e41672018-12-21 11:20:26 -080091 ADD_FAILURE() << "asking for synchronous execution at V1_0";
92 return ErrorStatus::GENERAL_FAILURE;
93}
94static Return<ErrorStatus> ExecutePreparedModel(sp<V1_2::IPreparedModel>& preparedModel,
David Grosse3013492019-01-23 14:01:52 -080095 const Request& request, MeasureTiming measure,
96 hidl_vec<OutputShape>* outputShapes,
97 Timing* timing) {
Xusong Wang187c5972018-11-07 09:33:59 -080098 ErrorStatus result;
99 Return<void> ret = preparedModel->executeSynchronously(
David Grosse3013492019-01-23 14:01:52 -0800100 request, measure,
101 [&result, outputShapes, timing](ErrorStatus error, const hidl_vec<OutputShape>& shapes,
102 const Timing& time) {
103 result = error;
104 *outputShapes = shapes;
105 *timing = time;
106 });
Xusong Wang187c5972018-11-07 09:33:59 -0800107 if (!ret.isOk()) {
108 return ErrorStatus::GENERAL_FAILURE;
109 }
110 return result;
David Gross49e41672018-12-21 11:20:26 -0800111}
// Selects whether EvaluatePreparedModel runs the asynchronous (callback-based) execution path
// (NO) or the synchronous one (YES).
enum class Synchronously { NO, YES };

// Default absolute and relative tolerances used when comparing floating-point outputs.
constexpr float kDefaultAtol = 1e-5f;
constexpr float kDefaultRtol = 1e-5f;
Xusong Wangb5cb8f72018-10-31 08:43:12 -0700115template <typename T_IPreparedModel>
116void EvaluatePreparedModel(sp<T_IPreparedModel>& preparedModel, std::function<bool(int)> is_ignored,
Michael K. Sandersefa4c812018-10-30 14:44:48 +0000117 const std::vector<MixedTypedExample>& examples,
David Grosse3013492019-01-23 14:01:52 -0800118 bool hasRelaxedFloat32Model, float fpAtol, float fpRtol,
119 Synchronously sync, MeasureTiming measure, bool testDynamicOutputShape) {
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700120 const uint32_t INPUT = 0;
121 const uint32_t OUTPUT = 1;
122
123 int example_no = 1;
124 for (auto& example : examples) {
125 SCOPED_TRACE(example_no++);
Michael K. Sanders941d61a2018-10-19 14:39:09 +0100126 const MixedTyped& inputs = example.operands.first;
127 const MixedTyped& golden = example.operands.second;
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700128
Xusong Wanga3165812018-11-19 18:26:08 -0800129 const bool hasFloat16Inputs = !inputs.float16Operands.empty();
Michael K. Sandersefa4c812018-10-30 14:44:48 +0000130 if (hasRelaxedFloat32Model || hasFloat16Inputs) {
131 // TODO: Adjust the error limit based on testing.
132 // If in relaxed mode, set the absolute tolerance to be 5ULP of FP16.
133 fpAtol = 5.0f * 0.0009765625f;
134 // Set the relative tolerance to be 5ULP of the corresponding FP precision.
135 fpRtol = 5.0f * 0.0009765625f;
136 }
137
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700138 std::vector<RequestArgument> inputs_info, outputs_info;
139 uint32_t inputSize = 0, outputSize = 0;
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700140 // This function only partially specifies the metadata (vector of RequestArguments).
141 // The contents are copied over below.
142 for_all(inputs, [&inputs_info, &inputSize](int index, auto, auto s) {
143 if (inputs_info.size() <= static_cast<size_t>(index)) inputs_info.resize(index + 1);
144 RequestArgument arg = {
145 .location = {.poolIndex = INPUT, .offset = 0, .length = static_cast<uint32_t>(s)},
146 .dimensions = {},
147 };
I-Jui (Ray) Sung959cd782017-10-04 20:49:57 -0700148 RequestArgument arg_empty = {
149 .hasNoValue = true,
150 };
151 inputs_info[index] = s ? arg : arg_empty;
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700152 inputSize += s;
153 });
154 // Compute offset for inputs 1 and so on
155 {
156 size_t offset = 0;
157 for (auto& i : inputs_info) {
I-Jui (Ray) Sung959cd782017-10-04 20:49:57 -0700158 if (!i.hasNoValue) i.location.offset = offset;
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700159 offset += i.location.length;
160 }
161 }
162
163 MixedTyped test; // holding test results
164
165 // Go through all outputs, initialize RequestArgument descriptors
I-Jui (Ray) Sungf6b85502017-09-20 13:45:50 -0700166 resize_accordingly(golden, test);
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700167 for_all(golden, [&outputs_info, &outputSize](int index, auto, auto s) {
168 if (outputs_info.size() <= static_cast<size_t>(index)) outputs_info.resize(index + 1);
169 RequestArgument arg = {
170 .location = {.poolIndex = OUTPUT, .offset = 0, .length = static_cast<uint32_t>(s)},
171 .dimensions = {},
172 };
173 outputs_info[index] = arg;
174 outputSize += s;
175 });
176 // Compute offset for outputs 1 and so on
177 {
178 size_t offset = 0;
179 for (auto& i : outputs_info) {
180 i.location.offset = offset;
181 offset += i.location.length;
182 }
183 }
Miao Wanga2d04c82018-02-05 17:26:54 -0800184 std::vector<hidl_memory> pools = {nn::allocateSharedMemory(inputSize),
185 nn::allocateSharedMemory(outputSize)};
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700186 ASSERT_NE(0ull, pools[INPUT].size());
187 ASSERT_NE(0ull, pools[OUTPUT].size());
188
189 // load data
190 sp<IMemory> inputMemory = mapMemory(pools[INPUT]);
191 sp<IMemory> outputMemory = mapMemory(pools[OUTPUT]);
192 ASSERT_NE(nullptr, inputMemory.get());
193 ASSERT_NE(nullptr, outputMemory.get());
194 char* inputPtr = reinterpret_cast<char*>(static_cast<void*>(inputMemory->getPointer()));
195 char* outputPtr = reinterpret_cast<char*>(static_cast<void*>(outputMemory->getPointer()));
196 ASSERT_NE(nullptr, inputPtr);
197 ASSERT_NE(nullptr, outputPtr);
198 inputMemory->update();
199 outputMemory->update();
200
201 // Go through all inputs, copy the values
202 for_all(inputs, [&inputs_info, inputPtr](int index, auto p, auto s) {
203 char* begin = (char*)p;
204 char* end = begin + s;
205 // TODO: handle more than one input
206 std::copy(begin, end, inputPtr + inputs_info[index].location.offset);
207 });
208
209 inputMemory->commit();
210 outputMemory->commit();
Michael Butlercf22a572017-09-22 13:26:12 -0700211
Xusong Wang187c5972018-11-07 09:33:59 -0800212 ErrorStatus executionStatus;
213 hidl_vec<OutputShape> outputShapes;
David Grosse3013492019-01-23 14:01:52 -0800214 Timing timing;
David Gross49e41672018-12-21 11:20:26 -0800215 if (sync == Synchronously::NO) {
216 SCOPED_TRACE("asynchronous");
Michael Butlercf22a572017-09-22 13:26:12 -0700217
David Gross49e41672018-12-21 11:20:26 -0800218 // launch execution
219 sp<ExecutionCallback> executionCallback = new ExecutionCallback();
220 ASSERT_NE(nullptr, executionCallback.get());
221 Return<ErrorStatus> executionLaunchStatus = ExecutePreparedModel(
David Grosse3013492019-01-23 14:01:52 -0800222 preparedModel, {.inputs = inputs_info, .outputs = outputs_info, .pools = pools},
223 measure, executionCallback);
David Gross49e41672018-12-21 11:20:26 -0800224 ASSERT_TRUE(executionLaunchStatus.isOk());
225 EXPECT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(executionLaunchStatus));
226
227 // retrieve execution status
228 executionCallback->wait();
Xusong Wang187c5972018-11-07 09:33:59 -0800229 executionStatus = executionCallback->getStatus();
230 outputShapes = executionCallback->getOutputShapes();
David Grosse3013492019-01-23 14:01:52 -0800231 timing = executionCallback->getTiming();
David Gross49e41672018-12-21 11:20:26 -0800232 } else {
233 SCOPED_TRACE("synchronous");
234
235 // execute
Xusong Wang187c5972018-11-07 09:33:59 -0800236 Return<ErrorStatus> executionReturnStatus = ExecutePreparedModel(
David Grosse3013492019-01-23 14:01:52 -0800237 preparedModel, {.inputs = inputs_info, .outputs = outputs_info, .pools = pools},
238 measure, &outputShapes, &timing);
Xusong Wang187c5972018-11-07 09:33:59 -0800239 ASSERT_TRUE(executionReturnStatus.isOk());
240 executionStatus = static_cast<ErrorStatus>(executionReturnStatus);
David Gross49e41672018-12-21 11:20:26 -0800241 }
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700242
Xusong Wanga3165812018-11-19 18:26:08 -0800243 if (testDynamicOutputShape && executionStatus != ErrorStatus::NONE) {
244 LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
245 "execute model that it does not support.";
246 std::cout << "[ ] Early termination of test because vendor service cannot "
247 "execute model that it does not support."
248 << std::endl;
249 return;
250 }
Xusong Wang187c5972018-11-07 09:33:59 -0800251 ASSERT_EQ(ErrorStatus::NONE, executionStatus);
David Grosse3013492019-01-23 14:01:52 -0800252 if (measure == MeasureTiming::NO) {
253 EXPECT_EQ(UINT64_MAX, timing.timeOnDevice);
254 EXPECT_EQ(UINT64_MAX, timing.timeInDriver);
255 } else {
256 if (timing.timeOnDevice != UINT64_MAX && timing.timeInDriver != UINT64_MAX) {
257 EXPECT_LE(timing.timeOnDevice, timing.timeInDriver);
258 }
259 }
Xusong Wanga3165812018-11-19 18:26:08 -0800260
261 // Go through all outputs, overwrite output dimensions with returned output shapes
262 if (testDynamicOutputShape) {
263 ASSERT_NE(outputShapes.size(), 0);
264 for_each<uint32_t>(test.operandDimensions,
265 [&outputShapes](int idx, std::vector<uint32_t>& dim) {
266 dim = outputShapes[idx].dimensions;
267 });
268 }
Xusong Wang187c5972018-11-07 09:33:59 -0800269
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700270 // validate results
271 outputMemory->read();
I-Jui (Ray) Sungf6b85502017-09-20 13:45:50 -0700272 copy_back(&test, outputs_info, outputPtr);
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700273 outputMemory->commit();
I-Jui (Ray) Sung7d765bd2017-09-13 18:47:12 -0700274 // Filter out don't cares
I-Jui (Ray) Sung5bf4edf2017-10-06 13:22:39 -0700275 MixedTyped filtered_golden = filter(golden, is_ignored);
276 MixedTyped filtered_test = filter(test, is_ignored);
I-Jui (Ray) Sung7d765bd2017-09-13 18:47:12 -0700277
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700278 // We want "close-enough" results for float
Xusong Wang10d77e42018-08-28 16:50:01 -0700279 compare(filtered_golden, filtered_test, fpAtol, fpRtol);
Michael K. Sanders941d61a2018-10-19 14:39:09 +0100280
281 if (example.expectedMultinomialDistributionTolerance > 0) {
282 expectMultinomialDistributionWithinTolerance(test, example);
283 }
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700284 }
285}
David Gross49e41672018-12-21 11:20:26 -0800286template <typename T_IPreparedModel>
287void EvaluatePreparedModel(sp<T_IPreparedModel>& preparedModel, std::function<bool(int)> is_ignored,
288 const std::vector<MixedTypedExample>& examples,
David Grosse3013492019-01-23 14:01:52 -0800289 bool hasRelaxedFloat32Model, Synchronously sync, MeasureTiming measure,
Xusong Wanga3165812018-11-19 18:26:08 -0800290 bool testDynamicOutputShape) {
David Gross49e41672018-12-21 11:20:26 -0800291 EvaluatePreparedModel(preparedModel, is_ignored, examples, hasRelaxedFloat32Model, kDefaultAtol,
David Grosse3013492019-01-23 14:01:52 -0800292 kDefaultRtol, sync, measure, testDynamicOutputShape);
David Gross49e41672018-12-21 11:20:26 -0800293}
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700294
Xusong Wangb5cb8f72018-10-31 08:43:12 -0700295static void getPreparedModel(sp<PreparedModelCallback> callback,
296 sp<V1_0::IPreparedModel>* preparedModel) {
297 *preparedModel = callback->getPreparedModel();
298}
299static void getPreparedModel(sp<PreparedModelCallback> callback,
300 sp<V1_2::IPreparedModel>* preparedModel) {
301 sp<V1_0::IPreparedModel> preparedModelV1_0 = callback->getPreparedModel();
302 *preparedModel = V1_2::IPreparedModel::castFrom(preparedModelV1_0).withDefault(nullptr);
303}
304
Michael Butlerf76acd02018-03-22 16:37:57 -0700305void Execute(const sp<V1_0::IDevice>& device, std::function<V1_0::Model(void)> create_model,
Michael K. Sanders941d61a2018-10-19 14:39:09 +0100306 std::function<bool(int)> is_ignored, const std::vector<MixedTypedExample>& examples) {
Miao Wanga2d04c82018-02-05 17:26:54 -0800307 V1_0::Model model = create_model();
308
309 // see if service can handle model
310 bool fullySupportsModel = false;
Miao Wanga2d04c82018-02-05 17:26:54 -0800311 Return<void> supportedCall = device->getSupportedOperations(
Michael Butler4d5bb102018-02-26 15:24:46 -0800312 model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
313 ASSERT_EQ(ErrorStatus::NONE, status);
Miao Wanga2d04c82018-02-05 17:26:54 -0800314 ASSERT_NE(0ul, supported.size());
315 fullySupportsModel =
316 std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; });
317 });
318 ASSERT_TRUE(supportedCall.isOk());
Michael Butler4d5bb102018-02-26 15:24:46 -0800319
320 // launch prepare model
321 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
322 ASSERT_NE(nullptr, preparedModelCallback.get());
Miao Wanga2d04c82018-02-05 17:26:54 -0800323 Return<ErrorStatus> prepareLaunchStatus = device->prepareModel(model, preparedModelCallback);
324 ASSERT_TRUE(prepareLaunchStatus.isOk());
Michael Butler4d5bb102018-02-26 15:24:46 -0800325 ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));
Miao Wanga2d04c82018-02-05 17:26:54 -0800326
327 // retrieve prepared model
328 preparedModelCallback->wait();
329 ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
Xusong Wangb5cb8f72018-10-31 08:43:12 -0700330 sp<V1_0::IPreparedModel> preparedModel;
331 getPreparedModel(preparedModelCallback, &preparedModel);
Miao Wanga2d04c82018-02-05 17:26:54 -0800332
333 // early termination if vendor service cannot fully prepare model
Michael Butler4d5bb102018-02-26 15:24:46 -0800334 if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
Miao Wanga2d04c82018-02-05 17:26:54 -0800335 ASSERT_EQ(nullptr, preparedModel.get());
336 LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
337 "prepare model that it does not support.";
338 std::cout << "[ ] Early termination of test because vendor service cannot "
339 "prepare model that it does not support."
340 << std::endl;
341 return;
342 }
Michael Butler4d5bb102018-02-26 15:24:46 -0800343 EXPECT_EQ(ErrorStatus::NONE, prepareReturnStatus);
Miao Wanga2d04c82018-02-05 17:26:54 -0800344 ASSERT_NE(nullptr, preparedModel.get());
345
Xusong Wang10d77e42018-08-28 16:50:01 -0700346 float fpAtol = 1e-5f, fpRtol = 5.0f * 1.1920928955078125e-7f;
Michael K. Sandersefa4c812018-10-30 14:44:48 +0000347 EvaluatePreparedModel(preparedModel, is_ignored, examples,
Xusong Wanga3165812018-11-19 18:26:08 -0800348 /*hasRelaxedFloat32Model=*/false, fpAtol, fpRtol, Synchronously::NO,
David Grosse3013492019-01-23 14:01:52 -0800349 MeasureTiming::NO, /*testDynamicOutputShape=*/false);
Miao Wanga2d04c82018-02-05 17:26:54 -0800350}
351
Michael Butlerf76acd02018-03-22 16:37:57 -0700352void Execute(const sp<V1_1::IDevice>& device, std::function<V1_1::Model(void)> create_model,
Michael K. Sanders941d61a2018-10-19 14:39:09 +0100353 std::function<bool(int)> is_ignored, const std::vector<MixedTypedExample>& examples) {
Miao Wanga2d04c82018-02-05 17:26:54 -0800354 V1_1::Model model = create_model();
355
356 // see if service can handle model
357 bool fullySupportsModel = false;
Miao Wanga2d04c82018-02-05 17:26:54 -0800358 Return<void> supportedCall = device->getSupportedOperations_1_1(
Michael Butler4d5bb102018-02-26 15:24:46 -0800359 model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
360 ASSERT_EQ(ErrorStatus::NONE, status);
Miao Wanga2d04c82018-02-05 17:26:54 -0800361 ASSERT_NE(0ul, supported.size());
362 fullySupportsModel =
363 std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; });
364 });
365 ASSERT_TRUE(supportedCall.isOk());
Michael Butler4d5bb102018-02-26 15:24:46 -0800366
367 // launch prepare model
368 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
369 ASSERT_NE(nullptr, preparedModelCallback.get());
Michael Butler2504c2f2018-04-11 16:30:09 -0700370 Return<ErrorStatus> prepareLaunchStatus = device->prepareModel_1_1(
371 model, ExecutionPreference::FAST_SINGLE_ANSWER, preparedModelCallback);
Miao Wanga2d04c82018-02-05 17:26:54 -0800372 ASSERT_TRUE(prepareLaunchStatus.isOk());
Michael Butler4d5bb102018-02-26 15:24:46 -0800373 ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));
Miao Wanga2d04c82018-02-05 17:26:54 -0800374
375 // retrieve prepared model
376 preparedModelCallback->wait();
377 ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
Xusong Wangb5cb8f72018-10-31 08:43:12 -0700378 sp<V1_0::IPreparedModel> preparedModel;
379 getPreparedModel(preparedModelCallback, &preparedModel);
Miao Wanga2d04c82018-02-05 17:26:54 -0800380
381 // early termination if vendor service cannot fully prepare model
Michael Butler4d5bb102018-02-26 15:24:46 -0800382 if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
Miao Wanga2d04c82018-02-05 17:26:54 -0800383 ASSERT_EQ(nullptr, preparedModel.get());
384 LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
385 "prepare model that it does not support.";
386 std::cout << "[ ] Early termination of test because vendor service cannot "
387 "prepare model that it does not support."
388 << std::endl;
389 return;
390 }
Michael Butler4d5bb102018-02-26 15:24:46 -0800391 EXPECT_EQ(ErrorStatus::NONE, prepareReturnStatus);
Miao Wanga2d04c82018-02-05 17:26:54 -0800392 ASSERT_NE(nullptr, preparedModel.get());
393
Michael K. Sandersefa4c812018-10-30 14:44:48 +0000394 EvaluatePreparedModel(preparedModel, is_ignored, examples,
Xusong Wanga3165812018-11-19 18:26:08 -0800395 model.relaxComputationFloat32toFloat16, 1e-5f, 1e-5f, Synchronously::NO,
David Grosse3013492019-01-23 14:01:52 -0800396 MeasureTiming::NO, /*testDynamicOutputShape=*/false);
Miao Wanga2d04c82018-02-05 17:26:54 -0800397}
398
Slava Shklyaev871be942018-09-12 14:52:02 +0100399// TODO: Reduce code duplication.
400void Execute(const sp<V1_2::IDevice>& device, std::function<V1_2::Model(void)> create_model,
Xusong Wanga3165812018-11-19 18:26:08 -0800401 std::function<bool(int)> is_ignored, const std::vector<MixedTypedExample>& examples,
402 bool testDynamicOutputShape) {
Slava Shklyaev871be942018-09-12 14:52:02 +0100403 V1_2::Model model = create_model();
404
405 // see if service can handle model
406 bool fullySupportsModel = false;
407 Return<void> supportedCall = device->getSupportedOperations_1_2(
408 model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
409 ASSERT_EQ(ErrorStatus::NONE, status);
410 ASSERT_NE(0ul, supported.size());
411 fullySupportsModel =
412 std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; });
413 });
414 ASSERT_TRUE(supportedCall.isOk());
415
416 // launch prepare model
417 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
418 ASSERT_NE(nullptr, preparedModelCallback.get());
419 Return<ErrorStatus> prepareLaunchStatus = device->prepareModel_1_2(
420 model, ExecutionPreference::FAST_SINGLE_ANSWER, preparedModelCallback);
421 ASSERT_TRUE(prepareLaunchStatus.isOk());
422 ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));
423
424 // retrieve prepared model
425 preparedModelCallback->wait();
426 ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
Xusong Wangb5cb8f72018-10-31 08:43:12 -0700427 sp<V1_2::IPreparedModel> preparedModel;
428 getPreparedModel(preparedModelCallback, &preparedModel);
Slava Shklyaev871be942018-09-12 14:52:02 +0100429
430 // early termination if vendor service cannot fully prepare model
431 if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
432 ASSERT_EQ(nullptr, preparedModel.get());
433 LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
434 "prepare model that it does not support.";
435 std::cout << "[ ] Early termination of test because vendor service cannot "
436 "prepare model that it does not support."
437 << std::endl;
438 return;
439 }
440 EXPECT_EQ(ErrorStatus::NONE, prepareReturnStatus);
441 ASSERT_NE(nullptr, preparedModel.get());
442
Michael K. Sandersefa4c812018-10-30 14:44:48 +0000443 EvaluatePreparedModel(preparedModel, is_ignored, examples,
Xusong Wanga3165812018-11-19 18:26:08 -0800444 model.relaxComputationFloat32toFloat16, Synchronously::NO,
David Grosse3013492019-01-23 14:01:52 -0800445 MeasureTiming::NO, testDynamicOutputShape);
David Gross49e41672018-12-21 11:20:26 -0800446 EvaluatePreparedModel(preparedModel, is_ignored, examples,
Xusong Wanga3165812018-11-19 18:26:08 -0800447 model.relaxComputationFloat32toFloat16, Synchronously::YES,
David Grosse3013492019-01-23 14:01:52 -0800448 MeasureTiming::NO, testDynamicOutputShape);
449 EvaluatePreparedModel(preparedModel, is_ignored, examples,
450 model.relaxComputationFloat32toFloat16, Synchronously::NO,
451 MeasureTiming::YES, testDynamicOutputShape);
452 EvaluatePreparedModel(preparedModel, is_ignored, examples,
453 model.relaxComputationFloat32toFloat16, Synchronously::YES,
454 MeasureTiming::YES, testDynamicOutputShape);
Slava Shklyaev871be942018-09-12 14:52:02 +0100455}
456
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700457} // namespace generated_tests
458
I-Jui (Ray) Sung2c4e1362017-09-06 02:15:54 -0700459} // namespace neuralnetworks
460} // namespace hardware
461} // namespace android