blob: 6ee7b1c5344a09e9eb5d237dec8059cd4baa0795 [file] [log] [blame]
Lev Proleev3b13b552019-08-30 11:35:34 +01001/*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19#include "VtsHalNeuralnetworks.h"
20
21#include "1.2/Callbacks.h"
22#include "ExecutionBurstController.h"
23#include "ExecutionBurstServer.h"
24#include "GeneratedTestHarness.h"
25#include "TestHarness.h"
26#include "Utils.h"
27
28#include <android-base/logging.h>
Michael Butler57568872019-07-25 17:22:11 -070029#include <chrono>
Lev Proleev3b13b552019-08-30 11:35:34 +010030#include <cstring>
31
Lev Proleevb49dadf2019-08-30 11:57:18 +010032namespace android::hardware::neuralnetworks::V1_3::vts::functional {
Lev Proleev3b13b552019-08-30 11:35:34 +010033
34using nn::ExecutionBurstController;
35using nn::RequestChannelSender;
36using nn::ResultChannelReceiver;
Lev Proleev3b13b552019-08-30 11:35:34 +010037using V1_0::Request;
Lev Proleevb49dadf2019-08-30 11:57:18 +010038using V1_2::FmqRequestDatum;
39using V1_2::FmqResultDatum;
40using V1_2::IBurstCallback;
41using V1_2::IBurstContext;
Lev Proleevb49dadf2019-08-30 11:57:18 +010042using V1_2::MeasureTiming;
43using V1_2::Timing;
Lev Proleev3b13b552019-08-30 11:35:34 +010044using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback;
45
Michael Butlerda1a6922020-03-11 18:45:45 -070046using BurstExecutionMutation = std::function<void(std::vector<FmqRequestDatum>*)>;
47
// FMQ length used for the request channel (and, by default, the result channel) of the bursts
// created in this file. It is large enough to return the result of a burst execution for every
// generated test case.
constexpr size_t kExecutionBurstChannelLength{1024};

// Deliberately undersized result FMQ length: the results of some generated test cases do not fit,
// which lets validateBurstFmqLength() exercise the "result too large" error path.
constexpr size_t kExecutionBurstChannelSmallLength{8};
56
57///////////////////////// UTILITY FUNCTIONS /////////////////////////
58
59static bool badTiming(Timing timing) {
60 return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
61}
62
63static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
64 std::unique_ptr<RequestChannelSender>* sender,
65 std::unique_ptr<ResultChannelReceiver>* receiver,
66 sp<IBurstContext>* context,
67 size_t resultChannelLength = kExecutionBurstChannelLength) {
68 ASSERT_NE(nullptr, preparedModel.get());
69 ASSERT_NE(nullptr, sender);
70 ASSERT_NE(nullptr, receiver);
71 ASSERT_NE(nullptr, context);
72
73 // create FMQ objects
74 auto [fmqRequestChannel, fmqRequestDescriptor] =
Michael Butler57568872019-07-25 17:22:11 -070075 RequestChannelSender::create(kExecutionBurstChannelLength);
Lev Proleev3b13b552019-08-30 11:35:34 +010076 auto [fmqResultChannel, fmqResultDescriptor] =
Michael Butler57568872019-07-25 17:22:11 -070077 ResultChannelReceiver::create(resultChannelLength, std::chrono::microseconds{0});
Lev Proleev3b13b552019-08-30 11:35:34 +010078 ASSERT_NE(nullptr, fmqRequestChannel.get());
79 ASSERT_NE(nullptr, fmqResultChannel.get());
80 ASSERT_NE(nullptr, fmqRequestDescriptor);
81 ASSERT_NE(nullptr, fmqResultDescriptor);
82
83 // configure burst
Michael Butler9449a282019-12-11 19:08:08 -080084 V1_0::ErrorStatus errorStatus;
Lev Proleev3b13b552019-08-30 11:35:34 +010085 sp<IBurstContext> burstContext;
86 const Return<void> ret = preparedModel->configureExecutionBurst(
87 callback, *fmqRequestDescriptor, *fmqResultDescriptor,
Michael Butler9449a282019-12-11 19:08:08 -080088 [&errorStatus, &burstContext](V1_0::ErrorStatus status,
89 const sp<IBurstContext>& context) {
Lev Proleev3b13b552019-08-30 11:35:34 +010090 errorStatus = status;
91 burstContext = context;
92 });
93 ASSERT_TRUE(ret.isOk());
Michael Butler9449a282019-12-11 19:08:08 -080094 ASSERT_EQ(V1_0::ErrorStatus::NONE, errorStatus);
Lev Proleev3b13b552019-08-30 11:35:34 +010095 ASSERT_NE(nullptr, burstContext.get());
96
97 // return values
98 *sender = std::move(fmqRequestChannel);
99 *receiver = std::move(fmqResultChannel);
100 *context = burstContext;
101}
102
103static void createBurstWithResultChannelLength(
104 const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
105 std::shared_ptr<ExecutionBurstController>* controller) {
106 ASSERT_NE(nullptr, preparedModel.get());
107 ASSERT_NE(nullptr, controller);
108
109 // create FMQ objects
110 std::unique_ptr<RequestChannelSender> sender;
111 std::unique_ptr<ResultChannelReceiver> receiver;
112 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
113 sp<IBurstContext> context;
114 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
115 resultChannelLength));
116 ASSERT_NE(nullptr, sender.get());
117 ASSERT_NE(nullptr, receiver.get());
118 ASSERT_NE(nullptr, context.get());
119
120 // return values
121 *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
122 context, callback);
123}
124
125// Primary validation function. This function will take a valid serialized
126// request, apply a mutation to it to invalidate the serialized request, then
Michael Butlerda1a6922020-03-11 18:45:45 -0700127// pass it to interface calls that use the serialized request.
Lev Proleev3b13b552019-08-30 11:35:34 +0100128static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
Michael Butlerda1a6922020-03-11 18:45:45 -0700129 const std::string& message,
130 const std::vector<FmqRequestDatum>& originalSerialized,
131 const BurstExecutionMutation& mutate) {
132 std::vector<FmqRequestDatum> serialized = originalSerialized;
133 mutate(&serialized);
Lev Proleev3b13b552019-08-30 11:35:34 +0100134
135 // skip if packet is too large to send
136 if (serialized.size() > kExecutionBurstChannelLength) {
137 return;
138 }
139
140 SCOPED_TRACE(message);
141
142 // send invalid packet
143 ASSERT_TRUE(sender->sendPacket(serialized));
144
145 // receive error
146 auto results = receiver->getBlocking();
147 ASSERT_TRUE(results.has_value());
148 const auto [status, outputShapes, timing] = std::move(*results);
Michael Butler9449a282019-12-11 19:08:08 -0800149 EXPECT_NE(V1_0::ErrorStatus::NONE, status);
Lev Proleev3b13b552019-08-30 11:35:34 +0100150 EXPECT_EQ(0u, outputShapes.size());
151 EXPECT_TRUE(badTiming(timing));
152}
153
154// For validation, valid packet entries are mutated to invalid packet entries,
155// or invalid packet entries are inserted into valid packets. This function
156// creates pre-set invalid packet entries for convenience.
157static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
158 const FmqRequestDatum::PacketInformation packetInformation = {
159 /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
160 /*.numberOfPools=*/10};
161 const FmqRequestDatum::OperandInformation operandInformation = {
162 /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
163 const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
164 std::vector<FmqRequestDatum> bad(7);
165 bad[0].packetInformation(packetInformation);
166 bad[1].inputOperandInformation(operandInformation);
167 bad[2].inputOperandDimensionValue(0);
168 bad[3].outputOperandInformation(operandInformation);
169 bad[4].outputOperandDimensionValue(0);
170 bad[5].poolIdentifier(invalidPoolIdentifier);
171 bad[6].measureTiming(MeasureTiming::YES);
172 return bad;
173}
174
175// For validation, valid packet entries are mutated to invalid packet entries,
176// or invalid packet entries are inserted into valid packets. This function
177// retrieves pre-set invalid packet entries for convenience. This function
178// caches these data so they can be reused on subsequent validation checks.
179static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
180 static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
181 return bad;
182}
183
184///////////////////////// REMOVE DATUM ////////////////////////////////////
185
186static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
187 const std::vector<FmqRequestDatum>& serialized) {
188 for (size_t index = 0; index < serialized.size(); ++index) {
189 const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
190 validate(sender, receiver, message, serialized,
191 [index](std::vector<FmqRequestDatum>* serialized) {
192 serialized->erase(serialized->begin() + index);
193 });
194 }
195}
196
197///////////////////////// ADD DATUM ////////////////////////////////////
198
199static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
200 const std::vector<FmqRequestDatum>& serialized) {
201 const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
202 for (size_t index = 0; index <= serialized.size(); ++index) {
203 for (size_t type = 0; type < extra.size(); ++type) {
204 const std::string message = "addDatum: added datum type " + std::to_string(type) +
205 " at index " + std::to_string(index);
206 validate(sender, receiver, message, serialized,
207 [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
208 serialized->insert(serialized->begin() + index, extra[type]);
209 });
210 }
211 }
212}
213
214///////////////////////// MUTATE DATUM ////////////////////////////////////
215
216static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
217 using Discriminator = FmqRequestDatum::hidl_discriminator;
218
219 const bool differentValues = (lhs != rhs);
220 const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
221 const auto discriminator = rhs.getDiscriminator();
222 const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
223 discriminator == Discriminator::outputOperandDimensionValue);
224
225 return differentValues && !(sameDiscriminator && isDimensionValue);
226}
227
228static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
229 const std::vector<FmqRequestDatum>& serialized) {
230 const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
231 for (size_t index = 0; index < serialized.size(); ++index) {
232 for (size_t type = 0; type < change.size(); ++type) {
233 if (interestingCase(serialized[index], change[type])) {
234 const std::string message = "mutateDatum: changed datum at index " +
235 std::to_string(index) + " to datum type " +
236 std::to_string(type);
237 validate(sender, receiver, message, serialized,
238 [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
239 (*serialized)[index] = change[type];
240 });
241 }
242 }
243 }
244}
245
///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
247
248static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
249 const Request& request) {
250 // create burst
251 std::unique_ptr<RequestChannelSender> sender;
252 std::unique_ptr<ResultChannelReceiver> receiver;
253 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
254 sp<IBurstContext> context;
255 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
256 ASSERT_NE(nullptr, sender.get());
257 ASSERT_NE(nullptr, receiver.get());
258 ASSERT_NE(nullptr, context.get());
259
260 // load memory into callback slots
261 std::vector<intptr_t> keys;
262 keys.reserve(request.pools.size());
263 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
264 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
265 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
266
267 // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
268 // subsequent slot validation testing)
269 ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
270 return slot != std::numeric_limits<int32_t>::max();
271 }));
272
273 // serialize the request
274 const auto serialized = android::nn::serialize(request, MeasureTiming::YES, slots);
275
276 // validations
277 removeDatumTest(sender.get(), receiver.get(), serialized);
278 addDatumTest(sender.get(), receiver.get(), serialized);
279 mutateDatumTest(sender.get(), receiver.get(), serialized);
280}
281
282// This test validates that when the Result message size exceeds length of the
283// result FMQ, the service instance gracefully fails and returns an error.
284static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
285 const Request& request) {
286 // create regular burst
287 std::shared_ptr<ExecutionBurstController> controllerRegular;
288 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
289 preparedModel, kExecutionBurstChannelLength, &controllerRegular));
290 ASSERT_NE(nullptr, controllerRegular.get());
291
292 // create burst with small output channel
293 std::shared_ptr<ExecutionBurstController> controllerSmall;
294 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
295 preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
296 ASSERT_NE(nullptr, controllerSmall.get());
297
298 // load memory into callback slots
299 std::vector<intptr_t> keys(request.pools.size());
300 for (size_t i = 0; i < keys.size(); ++i) {
301 keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
302 }
303
304 // collect serialized result by running regular burst
Michael Butler57568872019-07-25 17:22:11 -0700305 const auto [nRegular, outputShapesRegular, timingRegular, fallbackRegular] =
Lev Proleev3b13b552019-08-30 11:35:34 +0100306 controllerRegular->compute(request, MeasureTiming::NO, keys);
Michael Butler9449a282019-12-11 19:08:08 -0800307 const V1_0::ErrorStatus statusRegular =
308 nn::convertToV1_0(nn::convertResultCodeToErrorStatus(nRegular));
Michael Butler57568872019-07-25 17:22:11 -0700309 EXPECT_FALSE(fallbackRegular);
Lev Proleev3b13b552019-08-30 11:35:34 +0100310
311 // skip test if regular burst output isn't useful for testing a failure
312 // caused by having too small of a length for the result FMQ
313 const std::vector<FmqResultDatum> serialized =
314 android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
Michael Butler9449a282019-12-11 19:08:08 -0800315 if (statusRegular != V1_0::ErrorStatus::NONE ||
Lev Proleev3b13b552019-08-30 11:35:34 +0100316 serialized.size() <= kExecutionBurstChannelSmallLength) {
317 return;
318 }
319
320 // by this point, execution should fail because the result channel isn't
321 // large enough to return the serialized result
Michael Butler57568872019-07-25 17:22:11 -0700322 const auto [nSmall, outputShapesSmall, timingSmall, fallbackSmall] =
Lev Proleev3b13b552019-08-30 11:35:34 +0100323 controllerSmall->compute(request, MeasureTiming::NO, keys);
Michael Butler9449a282019-12-11 19:08:08 -0800324 const V1_0::ErrorStatus statusSmall =
325 nn::convertToV1_0(nn::convertResultCodeToErrorStatus(nSmall));
326 EXPECT_NE(V1_0::ErrorStatus::NONE, statusSmall);
Lev Proleev3b13b552019-08-30 11:35:34 +0100327 EXPECT_EQ(0u, outputShapesSmall.size());
328 EXPECT_TRUE(badTiming(timingSmall));
Michael Butler57568872019-07-25 17:22:11 -0700329 EXPECT_FALSE(fallbackSmall);
Lev Proleev3b13b552019-08-30 11:35:34 +0100330}
331
332static bool isSanitized(const FmqResultDatum& datum) {
333 using Discriminator = FmqResultDatum::hidl_discriminator;
334
335 // check to ensure the padding values in the returned
336 // FmqResultDatum::OperandInformation are initialized to 0
337 if (datum.getDiscriminator() == Discriminator::operandInformation) {
338 static_assert(
339 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
340 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
341 static_assert(
342 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
343 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
344 static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
345 "unexpected value for offset of "
346 "FmqResultDatum::OperandInformation::numberOfDimensions");
347 static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
348 "unexpected value for size of "
349 "FmqResultDatum::OperandInformation::numberOfDimensions");
350 static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
351 "unexpected value for size of "
352 "FmqResultDatum::OperandInformation");
353
354 constexpr size_t paddingOffset =
355 offsetof(FmqResultDatum::OperandInformation, isSufficient) +
356 sizeof(FmqResultDatum::OperandInformation::isSufficient);
357 constexpr size_t paddingSize =
358 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;
359
360 FmqResultDatum::OperandInformation initialized{};
361 std::memset(&initialized, 0, sizeof(initialized));
362
363 const char* initializedPaddingStart =
364 reinterpret_cast<const char*>(&initialized) + paddingOffset;
365 const char* datumPaddingStart =
366 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;
367
368 return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
369 }
370
371 // there are no other padding initialization checks required, so return true
372 // for any sum-type that isn't FmqResultDatum::OperandInformation
373 return true;
374}
375
376static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
377 const Request& request) {
378 // create burst
379 std::unique_ptr<RequestChannelSender> sender;
380 std::unique_ptr<ResultChannelReceiver> receiver;
381 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
382 sp<IBurstContext> context;
383 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
384 ASSERT_NE(nullptr, sender.get());
385 ASSERT_NE(nullptr, receiver.get());
386 ASSERT_NE(nullptr, context.get());
387
388 // load memory into callback slots
389 std::vector<intptr_t> keys;
390 keys.reserve(request.pools.size());
391 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
392 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
393 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
394
395 // send valid request
396 ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));
397
398 // receive valid result
399 auto serialized = receiver->getPacketBlocking();
400 ASSERT_TRUE(serialized.has_value());
401
402 // sanitize result
403 ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
404 << "The result serialized data is not properly sanitized";
405}
406
407///////////////////////////// ENTRY POINT //////////////////////////////////
408
409void validateBurst(const sp<IPreparedModel>& preparedModel, const Request& request) {
410 ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request));
411 ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request));
412 ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, request));
413}
414
Lev Proleevb49dadf2019-08-30 11:57:18 +0100415} // namespace android::hardware::neuralnetworks::V1_3::vts::functional