blob: 1d4493d2085f962c1ad31372a97b10797c9f0b09 [file] [log] [blame]
Lev Proleev13fdfcd2019-08-30 11:35:34 +01001/*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19#include "VtsHalNeuralnetworks.h"
20
21#include "1.2/Callbacks.h"
22#include "ExecutionBurstController.h"
23#include "ExecutionBurstServer.h"
24#include "GeneratedTestHarness.h"
25#include "TestHarness.h"
26#include "Utils.h"
27
28#include <android-base/logging.h>
29#include <cstring>
30
31namespace android::hardware::neuralnetworks::V1_2::vts::functional {
32
33using nn::ExecutionBurstController;
34using nn::RequestChannelSender;
35using nn::ResultChannelReceiver;
36using V1_0::ErrorStatus;
37using V1_0::Request;
38using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback;
39
// Number of FMQ elements in the request channel, and in the default result
// channel, of every burst created by this file. It is large enough to return a
// result from a burst execution for all of the generated test cases.
constexpr size_t kExecutionBurstChannelLength{1024};

// Deliberately undersized result-channel length: not large enough to return a
// result from a burst execution for some of the generated test cases, which is
// what validateBurstFmqLength relies on to exercise the overflow error path.
constexpr size_t kExecutionBurstChannelSmallLength{8};
48
49///////////////////////// UTILITY FUNCTIONS /////////////////////////
50
51static bool badTiming(Timing timing) {
52 return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
53}
54
55static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
56 std::unique_ptr<RequestChannelSender>* sender,
57 std::unique_ptr<ResultChannelReceiver>* receiver,
58 sp<IBurstContext>* context,
59 size_t resultChannelLength = kExecutionBurstChannelLength) {
60 ASSERT_NE(nullptr, preparedModel.get());
61 ASSERT_NE(nullptr, sender);
62 ASSERT_NE(nullptr, receiver);
63 ASSERT_NE(nullptr, context);
64
65 // create FMQ objects
66 auto [fmqRequestChannel, fmqRequestDescriptor] =
67 RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true);
68 auto [fmqResultChannel, fmqResultDescriptor] =
69 ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true);
70 ASSERT_NE(nullptr, fmqRequestChannel.get());
71 ASSERT_NE(nullptr, fmqResultChannel.get());
72 ASSERT_NE(nullptr, fmqRequestDescriptor);
73 ASSERT_NE(nullptr, fmqResultDescriptor);
74
75 // configure burst
76 ErrorStatus errorStatus;
77 sp<IBurstContext> burstContext;
78 const Return<void> ret = preparedModel->configureExecutionBurst(
79 callback, *fmqRequestDescriptor, *fmqResultDescriptor,
80 [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
81 errorStatus = status;
82 burstContext = context;
83 });
84 ASSERT_TRUE(ret.isOk());
85 ASSERT_EQ(ErrorStatus::NONE, errorStatus);
86 ASSERT_NE(nullptr, burstContext.get());
87
88 // return values
89 *sender = std::move(fmqRequestChannel);
90 *receiver = std::move(fmqResultChannel);
91 *context = burstContext;
92}
93
94static void createBurstWithResultChannelLength(
95 const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
96 std::shared_ptr<ExecutionBurstController>* controller) {
97 ASSERT_NE(nullptr, preparedModel.get());
98 ASSERT_NE(nullptr, controller);
99
100 // create FMQ objects
101 std::unique_ptr<RequestChannelSender> sender;
102 std::unique_ptr<ResultChannelReceiver> receiver;
103 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
104 sp<IBurstContext> context;
105 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
106 resultChannelLength));
107 ASSERT_NE(nullptr, sender.get());
108 ASSERT_NE(nullptr, receiver.get());
109 ASSERT_NE(nullptr, context.get());
110
111 // return values
112 *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
113 context, callback);
114}
115
116// Primary validation function. This function will take a valid serialized
117// request, apply a mutation to it to invalidate the serialized request, then
118// pass it to interface calls that use the serialized request. Note that the
119// serialized request here is passed by value, and any mutation to the
120// serialized request does not leave this function.
121static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
122 const std::string& message, std::vector<FmqRequestDatum> serialized,
123 const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) {
124 mutation(&serialized);
125
126 // skip if packet is too large to send
127 if (serialized.size() > kExecutionBurstChannelLength) {
128 return;
129 }
130
131 SCOPED_TRACE(message);
132
133 // send invalid packet
134 ASSERT_TRUE(sender->sendPacket(serialized));
135
136 // receive error
137 auto results = receiver->getBlocking();
138 ASSERT_TRUE(results.has_value());
139 const auto [status, outputShapes, timing] = std::move(*results);
140 EXPECT_NE(ErrorStatus::NONE, status);
141 EXPECT_EQ(0u, outputShapes.size());
142 EXPECT_TRUE(badTiming(timing));
143}
144
145// For validation, valid packet entries are mutated to invalid packet entries,
146// or invalid packet entries are inserted into valid packets. This function
147// creates pre-set invalid packet entries for convenience.
148static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
149 const FmqRequestDatum::PacketInformation packetInformation = {
150 /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
151 /*.numberOfPools=*/10};
152 const FmqRequestDatum::OperandInformation operandInformation = {
153 /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
154 const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
155 std::vector<FmqRequestDatum> bad(7);
156 bad[0].packetInformation(packetInformation);
157 bad[1].inputOperandInformation(operandInformation);
158 bad[2].inputOperandDimensionValue(0);
159 bad[3].outputOperandInformation(operandInformation);
160 bad[4].outputOperandDimensionValue(0);
161 bad[5].poolIdentifier(invalidPoolIdentifier);
162 bad[6].measureTiming(MeasureTiming::YES);
163 return bad;
164}
165
166// For validation, valid packet entries are mutated to invalid packet entries,
167// or invalid packet entries are inserted into valid packets. This function
168// retrieves pre-set invalid packet entries for convenience. This function
169// caches these data so they can be reused on subsequent validation checks.
170static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
171 static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
172 return bad;
173}
174
175///////////////////////// REMOVE DATUM ////////////////////////////////////
176
177static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
178 const std::vector<FmqRequestDatum>& serialized) {
179 for (size_t index = 0; index < serialized.size(); ++index) {
180 const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
181 validate(sender, receiver, message, serialized,
182 [index](std::vector<FmqRequestDatum>* serialized) {
183 serialized->erase(serialized->begin() + index);
184 });
185 }
186}
187
188///////////////////////// ADD DATUM ////////////////////////////////////
189
190static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
191 const std::vector<FmqRequestDatum>& serialized) {
192 const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
193 for (size_t index = 0; index <= serialized.size(); ++index) {
194 for (size_t type = 0; type < extra.size(); ++type) {
195 const std::string message = "addDatum: added datum type " + std::to_string(type) +
196 " at index " + std::to_string(index);
197 validate(sender, receiver, message, serialized,
198 [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
199 serialized->insert(serialized->begin() + index, extra[type]);
200 });
201 }
202 }
203}
204
205///////////////////////// MUTATE DATUM ////////////////////////////////////
206
207static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
208 using Discriminator = FmqRequestDatum::hidl_discriminator;
209
210 const bool differentValues = (lhs != rhs);
211 const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
212 const auto discriminator = rhs.getDiscriminator();
213 const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
214 discriminator == Discriminator::outputOperandDimensionValue);
215
216 return differentValues && !(sameDiscriminator && isDimensionValue);
217}
218
219static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
220 const std::vector<FmqRequestDatum>& serialized) {
221 const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
222 for (size_t index = 0; index < serialized.size(); ++index) {
223 for (size_t type = 0; type < change.size(); ++type) {
224 if (interestingCase(serialized[index], change[type])) {
225 const std::string message = "mutateDatum: changed datum at index " +
226 std::to_string(index) + " to datum type " +
227 std::to_string(type);
228 validate(sender, receiver, message, serialized,
229 [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
230 (*serialized)[index] = change[type];
231 });
232 }
233 }
234 }
235}
236
///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
238
239static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
240 const Request& request) {
241 // create burst
242 std::unique_ptr<RequestChannelSender> sender;
243 std::unique_ptr<ResultChannelReceiver> receiver;
244 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
245 sp<IBurstContext> context;
246 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
247 ASSERT_NE(nullptr, sender.get());
248 ASSERT_NE(nullptr, receiver.get());
249 ASSERT_NE(nullptr, context.get());
250
251 // load memory into callback slots
252 std::vector<intptr_t> keys;
253 keys.reserve(request.pools.size());
254 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
255 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
256 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
257
258 // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
259 // subsequent slot validation testing)
260 ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
261 return slot != std::numeric_limits<int32_t>::max();
262 }));
263
264 // serialize the request
265 const auto serialized = android::nn::serialize(request, MeasureTiming::YES, slots);
266
267 // validations
268 removeDatumTest(sender.get(), receiver.get(), serialized);
269 addDatumTest(sender.get(), receiver.get(), serialized);
270 mutateDatumTest(sender.get(), receiver.get(), serialized);
271}
272
273// This test validates that when the Result message size exceeds length of the
274// result FMQ, the service instance gracefully fails and returns an error.
275static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
276 const Request& request) {
277 // create regular burst
278 std::shared_ptr<ExecutionBurstController> controllerRegular;
279 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
280 preparedModel, kExecutionBurstChannelLength, &controllerRegular));
281 ASSERT_NE(nullptr, controllerRegular.get());
282
283 // create burst with small output channel
284 std::shared_ptr<ExecutionBurstController> controllerSmall;
285 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
286 preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
287 ASSERT_NE(nullptr, controllerSmall.get());
288
289 // load memory into callback slots
290 std::vector<intptr_t> keys(request.pools.size());
291 for (size_t i = 0; i < keys.size(); ++i) {
292 keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
293 }
294
295 // collect serialized result by running regular burst
296 const auto [statusRegular, outputShapesRegular, timingRegular] =
297 controllerRegular->compute(request, MeasureTiming::NO, keys);
298
299 // skip test if regular burst output isn't useful for testing a failure
300 // caused by having too small of a length for the result FMQ
301 const std::vector<FmqResultDatum> serialized =
302 android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
303 if (statusRegular != ErrorStatus::NONE ||
304 serialized.size() <= kExecutionBurstChannelSmallLength) {
305 return;
306 }
307
308 // by this point, execution should fail because the result channel isn't
309 // large enough to return the serialized result
310 const auto [statusSmall, outputShapesSmall, timingSmall] =
311 controllerSmall->compute(request, MeasureTiming::NO, keys);
312 EXPECT_NE(ErrorStatus::NONE, statusSmall);
313 EXPECT_EQ(0u, outputShapesSmall.size());
314 EXPECT_TRUE(badTiming(timingSmall));
315}
316
317static bool isSanitized(const FmqResultDatum& datum) {
318 using Discriminator = FmqResultDatum::hidl_discriminator;
319
320 // check to ensure the padding values in the returned
321 // FmqResultDatum::OperandInformation are initialized to 0
322 if (datum.getDiscriminator() == Discriminator::operandInformation) {
323 static_assert(
324 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
325 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
326 static_assert(
327 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
328 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
329 static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
330 "unexpected value for offset of "
331 "FmqResultDatum::OperandInformation::numberOfDimensions");
332 static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
333 "unexpected value for size of "
334 "FmqResultDatum::OperandInformation::numberOfDimensions");
335 static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
336 "unexpected value for size of "
337 "FmqResultDatum::OperandInformation");
338
339 constexpr size_t paddingOffset =
340 offsetof(FmqResultDatum::OperandInformation, isSufficient) +
341 sizeof(FmqResultDatum::OperandInformation::isSufficient);
342 constexpr size_t paddingSize =
343 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;
344
345 FmqResultDatum::OperandInformation initialized{};
346 std::memset(&initialized, 0, sizeof(initialized));
347
348 const char* initializedPaddingStart =
349 reinterpret_cast<const char*>(&initialized) + paddingOffset;
350 const char* datumPaddingStart =
351 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;
352
353 return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
354 }
355
356 // there are no other padding initialization checks required, so return true
357 // for any sum-type that isn't FmqResultDatum::OperandInformation
358 return true;
359}
360
361static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
362 const Request& request) {
363 // create burst
364 std::unique_ptr<RequestChannelSender> sender;
365 std::unique_ptr<ResultChannelReceiver> receiver;
366 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
367 sp<IBurstContext> context;
368 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
369 ASSERT_NE(nullptr, sender.get());
370 ASSERT_NE(nullptr, receiver.get());
371 ASSERT_NE(nullptr, context.get());
372
373 // load memory into callback slots
374 std::vector<intptr_t> keys;
375 keys.reserve(request.pools.size());
376 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
377 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
378 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
379
380 // send valid request
381 ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));
382
383 // receive valid result
384 auto serialized = receiver->getPacketBlocking();
385 ASSERT_TRUE(serialized.has_value());
386
387 // sanitize result
388 ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
389 << "The result serialized data is not properly sanitized";
390}
391
392///////////////////////////// ENTRY POINT //////////////////////////////////
393
394void validateBurst(const sp<IPreparedModel>& preparedModel, const Request& request) {
395 ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request));
396 ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request));
397 ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, request));
398}
399
400} // namespace android::hardware::neuralnetworks::V1_2::vts::functional