blob: cb801eb4ad9577c0eb55b084afae6e0f62476d2b [file] [log] [blame]
Michael Butler20f28a22019-04-26 17:46:08 -07001/*
Michael Butler0a1ad962019-04-30 13:51:24 -07002 * Copyright (C) 2019 The Android Open Source Project
Michael Butler20f28a22019-04-26 17:46:08 -07003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19#include "VtsHalNeuralnetworks.h"
20
Slava Shklyaev1d6b4652019-05-14 14:15:14 +010021#include "1.2/Callbacks.h"
Michael Butler20f28a22019-04-26 17:46:08 -070022#include "ExecutionBurstController.h"
23#include "ExecutionBurstServer.h"
Xusong Wang9e2b97b2019-08-23 16:10:54 -070024#include "GeneratedTestHarness.h"
Michael Butler20f28a22019-04-26 17:46:08 -070025#include "TestHarness.h"
26#include "Utils.h"
27
#include <android-base/logging.h>

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <string>
#include <vector>
Michael Butler20f28a22019-04-26 17:46:08 -070030
31namespace android {
32namespace hardware {
33namespace neuralnetworks {
34namespace V1_2 {
35namespace vts {
36namespace functional {
37
38using ::android::nn::ExecutionBurstController;
39using ::android::nn::RequestChannelSender;
40using ::android::nn::ResultChannelReceiver;
41using ExecutionBurstCallback = ::android::nn::ExecutionBurstController::ExecutionBurstCallback;
42
// Length of an FMQ that is long enough to carry the result of a burst
// execution for every one of the generated test cases.
constexpr size_t kExecutionBurstChannelLength = 1024;

// Length of an FMQ that is too short to carry the result of a burst execution
// for some of the generated test cases. It is used as the result-channel
// length when testing that oversized results fail gracefully.
constexpr size_t kExecutionBurstChannelSmallLength = 8;
51
52///////////////////////// UTILITY FUNCTIONS /////////////////////////
53
54static bool badTiming(Timing timing) {
55 return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
56}
57
58static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
59 std::unique_ptr<RequestChannelSender>* sender,
60 std::unique_ptr<ResultChannelReceiver>* receiver,
Michael Butler0a1ad962019-04-30 13:51:24 -070061 sp<IBurstContext>* context,
62 size_t resultChannelLength = kExecutionBurstChannelLength) {
Michael Butler20f28a22019-04-26 17:46:08 -070063 ASSERT_NE(nullptr, preparedModel.get());
64 ASSERT_NE(nullptr, sender);
65 ASSERT_NE(nullptr, receiver);
66 ASSERT_NE(nullptr, context);
67
68 // create FMQ objects
69 auto [fmqRequestChannel, fmqRequestDescriptor] =
70 RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true);
71 auto [fmqResultChannel, fmqResultDescriptor] =
Michael Butler0a1ad962019-04-30 13:51:24 -070072 ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true);
Michael Butler20f28a22019-04-26 17:46:08 -070073 ASSERT_NE(nullptr, fmqRequestChannel.get());
74 ASSERT_NE(nullptr, fmqResultChannel.get());
75 ASSERT_NE(nullptr, fmqRequestDescriptor);
76 ASSERT_NE(nullptr, fmqResultDescriptor);
77
78 // configure burst
79 ErrorStatus errorStatus;
80 sp<IBurstContext> burstContext;
81 const Return<void> ret = preparedModel->configureExecutionBurst(
82 callback, *fmqRequestDescriptor, *fmqResultDescriptor,
83 [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
84 errorStatus = status;
85 burstContext = context;
86 });
87 ASSERT_TRUE(ret.isOk());
88 ASSERT_EQ(ErrorStatus::NONE, errorStatus);
89 ASSERT_NE(nullptr, burstContext.get());
90
91 // return values
92 *sender = std::move(fmqRequestChannel);
93 *receiver = std::move(fmqResultChannel);
94 *context = burstContext;
95}
96
97static void createBurstWithResultChannelLength(
Michael Butler0a1ad962019-04-30 13:51:24 -070098 const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
99 std::shared_ptr<ExecutionBurstController>* controller) {
Michael Butler20f28a22019-04-26 17:46:08 -0700100 ASSERT_NE(nullptr, preparedModel.get());
101 ASSERT_NE(nullptr, controller);
102
103 // create FMQ objects
Michael Butler0a1ad962019-04-30 13:51:24 -0700104 std::unique_ptr<RequestChannelSender> sender;
105 std::unique_ptr<ResultChannelReceiver> receiver;
Michael Butler20f28a22019-04-26 17:46:08 -0700106 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
Michael Butler0a1ad962019-04-30 13:51:24 -0700107 sp<IBurstContext> context;
108 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
109 resultChannelLength));
110 ASSERT_NE(nullptr, sender.get());
111 ASSERT_NE(nullptr, receiver.get());
112 ASSERT_NE(nullptr, context.get());
Michael Butler20f28a22019-04-26 17:46:08 -0700113
114 // return values
Michael Butler0a1ad962019-04-30 13:51:24 -0700115 *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
116 context, callback);
Michael Butler20f28a22019-04-26 17:46:08 -0700117}
118
119// Primary validation function. This function will take a valid serialized
120// request, apply a mutation to it to invalidate the serialized request, then
121// pass it to interface calls that use the serialized request. Note that the
122// serialized request here is passed by value, and any mutation to the
123// serialized request does not leave this function.
124static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
125 const std::string& message, std::vector<FmqRequestDatum> serialized,
126 const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) {
127 mutation(&serialized);
128
129 // skip if packet is too large to send
130 if (serialized.size() > kExecutionBurstChannelLength) {
131 return;
132 }
133
134 SCOPED_TRACE(message);
135
136 // send invalid packet
Michael Butler0a1ad962019-04-30 13:51:24 -0700137 ASSERT_TRUE(sender->sendPacket(serialized));
Michael Butler20f28a22019-04-26 17:46:08 -0700138
139 // receive error
140 auto results = receiver->getBlocking();
141 ASSERT_TRUE(results.has_value());
142 const auto [status, outputShapes, timing] = std::move(*results);
143 EXPECT_NE(ErrorStatus::NONE, status);
144 EXPECT_EQ(0u, outputShapes.size());
145 EXPECT_TRUE(badTiming(timing));
146}
147
Michael Butler0a1ad962019-04-30 13:51:24 -0700148// For validation, valid packet entries are mutated to invalid packet entries,
149// or invalid packet entries are inserted into valid packets. This function
150// creates pre-set invalid packet entries for convenience.
151static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
Michael Butler20f28a22019-04-26 17:46:08 -0700152 const FmqRequestDatum::PacketInformation packetInformation = {
153 /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
154 /*.numberOfPools=*/10};
155 const FmqRequestDatum::OperandInformation operandInformation = {
156 /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
157 const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
Michael Butler0a1ad962019-04-30 13:51:24 -0700158 std::vector<FmqRequestDatum> bad(7);
159 bad[0].packetInformation(packetInformation);
160 bad[1].inputOperandInformation(operandInformation);
161 bad[2].inputOperandDimensionValue(0);
162 bad[3].outputOperandInformation(operandInformation);
163 bad[4].outputOperandDimensionValue(0);
164 bad[5].poolIdentifier(invalidPoolIdentifier);
165 bad[6].measureTiming(MeasureTiming::YES);
166 return bad;
Michael Butler20f28a22019-04-26 17:46:08 -0700167}
168
Michael Butler0a1ad962019-04-30 13:51:24 -0700169// For validation, valid packet entries are mutated to invalid packet entries,
170// or invalid packet entries are inserted into valid packets. This function
171// retrieves pre-set invalid packet entries for convenience. This function
172// caches these data so they can be reused on subsequent validation checks.
173static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
174 static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
175 return bad;
Michael Butler20f28a22019-04-26 17:46:08 -0700176}
177
178///////////////////////// REMOVE DATUM ////////////////////////////////////
179
180static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
181 const std::vector<FmqRequestDatum>& serialized) {
182 for (size_t index = 0; index < serialized.size(); ++index) {
183 const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
184 validate(sender, receiver, message, serialized,
185 [index](std::vector<FmqRequestDatum>* serialized) {
186 serialized->erase(serialized->begin() + index);
187 });
188 }
189}
190
191///////////////////////// ADD DATUM ////////////////////////////////////
192
193static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
194 const std::vector<FmqRequestDatum>& serialized) {
Michael Butler0a1ad962019-04-30 13:51:24 -0700195 const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
Michael Butler20f28a22019-04-26 17:46:08 -0700196 for (size_t index = 0; index <= serialized.size(); ++index) {
197 for (size_t type = 0; type < extra.size(); ++type) {
198 const std::string message = "addDatum: added datum type " + std::to_string(type) +
199 " at index " + std::to_string(index);
200 validate(sender, receiver, message, serialized,
201 [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
202 serialized->insert(serialized->begin() + index, extra[type]);
203 });
204 }
205 }
206}
207
208///////////////////////// MUTATE DATUM ////////////////////////////////////
209
210static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
211 using Discriminator = FmqRequestDatum::hidl_discriminator;
212
213 const bool differentValues = (lhs != rhs);
Michael Butler0a1ad962019-04-30 13:51:24 -0700214 const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
Michael Butler20f28a22019-04-26 17:46:08 -0700215 const auto discriminator = rhs.getDiscriminator();
216 const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
217 discriminator == Discriminator::outputOperandDimensionValue);
218
Michael Butler0a1ad962019-04-30 13:51:24 -0700219 return differentValues && !(sameDiscriminator && isDimensionValue);
Michael Butler20f28a22019-04-26 17:46:08 -0700220}
221
222static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
223 const std::vector<FmqRequestDatum>& serialized) {
Michael Butler0a1ad962019-04-30 13:51:24 -0700224 const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
Michael Butler20f28a22019-04-26 17:46:08 -0700225 for (size_t index = 0; index < serialized.size(); ++index) {
226 for (size_t type = 0; type < change.size(); ++type) {
227 if (interestingCase(serialized[index], change[type])) {
228 const std::string message = "mutateDatum: changed datum at index " +
229 std::to_string(index) + " to datum type " +
230 std::to_string(type);
231 validate(sender, receiver, message, serialized,
232 [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
233 (*serialized)[index] = change[type];
234 });
235 }
236 }
237 }
238}
239
///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
241
242static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
Xusong Wang491b0a82019-08-09 16:45:24 -0700243 const Request& request) {
Michael Butler20f28a22019-04-26 17:46:08 -0700244 // create burst
245 std::unique_ptr<RequestChannelSender> sender;
246 std::unique_ptr<ResultChannelReceiver> receiver;
247 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
248 sp<IBurstContext> context;
249 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
250 ASSERT_NE(nullptr, sender.get());
251 ASSERT_NE(nullptr, receiver.get());
252 ASSERT_NE(nullptr, context.get());
253
Xusong Wang491b0a82019-08-09 16:45:24 -0700254 // load memory into callback slots
255 std::vector<intptr_t> keys;
256 keys.reserve(request.pools.size());
257 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
258 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
259 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
Michael Butler20f28a22019-04-26 17:46:08 -0700260
Xusong Wang491b0a82019-08-09 16:45:24 -0700261 // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
262 // subsequent slot validation testing)
263 ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
264 return slot != std::numeric_limits<int32_t>::max();
265 }));
Michael Butler20f28a22019-04-26 17:46:08 -0700266
Xusong Wang491b0a82019-08-09 16:45:24 -0700267 // serialize the request
268 const auto serialized = ::android::nn::serialize(request, MeasureTiming::YES, slots);
Michael Butler20f28a22019-04-26 17:46:08 -0700269
Xusong Wang491b0a82019-08-09 16:45:24 -0700270 // validations
271 removeDatumTest(sender.get(), receiver.get(), serialized);
272 addDatumTest(sender.get(), receiver.get(), serialized);
273 mutateDatumTest(sender.get(), receiver.get(), serialized);
Michael Butler20f28a22019-04-26 17:46:08 -0700274}
275
Michael Butler0a1ad962019-04-30 13:51:24 -0700276// This test validates that when the Result message size exceeds length of the
277// result FMQ, the service instance gracefully fails and returns an error.
Michael Butler20f28a22019-04-26 17:46:08 -0700278static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
Xusong Wang491b0a82019-08-09 16:45:24 -0700279 const Request& request) {
Michael Butler20f28a22019-04-26 17:46:08 -0700280 // create regular burst
281 std::shared_ptr<ExecutionBurstController> controllerRegular;
Michael Butler0a1ad962019-04-30 13:51:24 -0700282 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
283 preparedModel, kExecutionBurstChannelLength, &controllerRegular));
Michael Butler20f28a22019-04-26 17:46:08 -0700284 ASSERT_NE(nullptr, controllerRegular.get());
285
286 // create burst with small output channel
287 std::shared_ptr<ExecutionBurstController> controllerSmall;
Michael Butler0a1ad962019-04-30 13:51:24 -0700288 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
289 preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
Michael Butler20f28a22019-04-26 17:46:08 -0700290 ASSERT_NE(nullptr, controllerSmall.get());
291
Xusong Wang491b0a82019-08-09 16:45:24 -0700292 // load memory into callback slots
293 std::vector<intptr_t> keys(request.pools.size());
294 for (size_t i = 0; i < keys.size(); ++i) {
295 keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
Michael Butler20f28a22019-04-26 17:46:08 -0700296 }
Xusong Wang491b0a82019-08-09 16:45:24 -0700297
298 // collect serialized result by running regular burst
299 const auto [statusRegular, outputShapesRegular, timingRegular] =
300 controllerRegular->compute(request, MeasureTiming::NO, keys);
301
302 // skip test if regular burst output isn't useful for testing a failure
303 // caused by having too small of a length for the result FMQ
304 const std::vector<FmqResultDatum> serialized =
305 ::android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
306 if (statusRegular != ErrorStatus::NONE ||
307 serialized.size() <= kExecutionBurstChannelSmallLength) {
308 return;
309 }
310
311 // by this point, execution should fail because the result channel isn't
312 // large enough to return the serialized result
313 const auto [statusSmall, outputShapesSmall, timingSmall] =
314 controllerSmall->compute(request, MeasureTiming::NO, keys);
315 EXPECT_NE(ErrorStatus::NONE, statusSmall);
316 EXPECT_EQ(0u, outputShapesSmall.size());
317 EXPECT_TRUE(badTiming(timingSmall));
Michael Butler20f28a22019-04-26 17:46:08 -0700318}
319
Michael Butlerddb770f2019-05-02 18:16:13 -0700320static bool isSanitized(const FmqResultDatum& datum) {
321 using Discriminator = FmqResultDatum::hidl_discriminator;
322
323 // check to ensure the padding values in the returned
324 // FmqResultDatum::OperandInformation are initialized to 0
325 if (datum.getDiscriminator() == Discriminator::operandInformation) {
326 static_assert(
327 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
328 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
329 static_assert(
330 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
331 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
332 static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
333 "unexpected value for offset of "
334 "FmqResultDatum::OperandInformation::numberOfDimensions");
335 static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
336 "unexpected value for size of "
337 "FmqResultDatum::OperandInformation::numberOfDimensions");
338 static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
339 "unexpected value for size of "
340 "FmqResultDatum::OperandInformation");
341
342 constexpr size_t paddingOffset =
343 offsetof(FmqResultDatum::OperandInformation, isSufficient) +
344 sizeof(FmqResultDatum::OperandInformation::isSufficient);
345 constexpr size_t paddingSize =
346 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;
347
348 FmqResultDatum::OperandInformation initialized{};
349 std::memset(&initialized, 0, sizeof(initialized));
350
351 const char* initializedPaddingStart =
352 reinterpret_cast<const char*>(&initialized) + paddingOffset;
353 const char* datumPaddingStart =
354 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;
355
356 return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
357 }
358
359 // there are no other padding initialization checks required, so return true
360 // for any sum-type that isn't FmqResultDatum::OperandInformation
361 return true;
362}
363
364static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
Xusong Wang8fc46222019-08-19 10:37:18 -0700365 const Request& request) {
Michael Butlerddb770f2019-05-02 18:16:13 -0700366 // create burst
367 std::unique_ptr<RequestChannelSender> sender;
368 std::unique_ptr<ResultChannelReceiver> receiver;
369 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
370 sp<IBurstContext> context;
371 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
372 ASSERT_NE(nullptr, sender.get());
373 ASSERT_NE(nullptr, receiver.get());
374 ASSERT_NE(nullptr, context.get());
375
Xusong Wang8fc46222019-08-19 10:37:18 -0700376 // load memory into callback slots
377 std::vector<intptr_t> keys;
378 keys.reserve(request.pools.size());
379 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
380 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
381 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
Michael Butlerddb770f2019-05-02 18:16:13 -0700382
Xusong Wang8fc46222019-08-19 10:37:18 -0700383 // send valid request
384 ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));
Michael Butlerddb770f2019-05-02 18:16:13 -0700385
Xusong Wang8fc46222019-08-19 10:37:18 -0700386 // receive valid result
387 auto serialized = receiver->getPacketBlocking();
388 ASSERT_TRUE(serialized.has_value());
Michael Butlerddb770f2019-05-02 18:16:13 -0700389
Xusong Wang8fc46222019-08-19 10:37:18 -0700390 // sanitize result
391 ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
392 << "The result serialized data is not properly sanitized";
Michael Butlerddb770f2019-05-02 18:16:13 -0700393}
394
Michael Butler20f28a22019-04-26 17:46:08 -0700395///////////////////////////// ENTRY POINT //////////////////////////////////
396
397void ValidationTest::validateBurst(const sp<IPreparedModel>& preparedModel,
Xusong Wang491b0a82019-08-09 16:45:24 -0700398 const Request& request) {
399 ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request));
400 ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request));
Xusong Wang8fc46222019-08-19 10:37:18 -0700401 ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, request));
Michael Butler20f28a22019-04-26 17:46:08 -0700402}
403
404} // namespace functional
405} // namespace vts
406} // namespace V1_2
407} // namespace neuralnetworks
408} // namespace hardware
409} // namespace android