blob: aecb7b79c47f1ebc98daadd85a1b4bd008fcb864 [file] [log] [blame]
Lev Proleev13fdfcd2019-08-30 11:35:34 +01001/*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19#include "VtsHalNeuralnetworks.h"
20
21#include "1.2/Callbacks.h"
22#include "ExecutionBurstController.h"
23#include "ExecutionBurstServer.h"
24#include "GeneratedTestHarness.h"
25#include "TestHarness.h"
Lev Proleev13fdfcd2019-08-30 11:35:34 +010026
#include <android-base/logging.h>

#include <algorithm>
#include <chrono>
#include <cstring>
#include <iterator>
30
Lev Proleev26d1bc82019-08-30 11:57:18 +010031namespace android::hardware::neuralnetworks::V1_3::vts::functional {
Lev Proleev13fdfcd2019-08-30 11:35:34 +010032
33using nn::ExecutionBurstController;
34using nn::RequestChannelSender;
35using nn::ResultChannelReceiver;
Lev Proleev13fdfcd2019-08-30 11:35:34 +010036using V1_0::Request;
Lev Proleev26d1bc82019-08-30 11:57:18 +010037using V1_2::FmqRequestDatum;
38using V1_2::FmqResultDatum;
39using V1_2::IBurstCallback;
40using V1_2::IBurstContext;
Lev Proleev26d1bc82019-08-30 11:57:18 +010041using V1_2::MeasureTiming;
42using V1_2::Timing;
Lev Proleev13fdfcd2019-08-30 11:35:34 +010043using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback;
44
// Default number of entries in the burst FMQs. The request FMQ always uses
// this length, and it is large enough to carry the result of a burst execution
// for every generated test case.
constexpr size_t kExecutionBurstChannelLength{1024};

// Deliberately undersized result FMQ length: for some generated test cases
// the serialized result does not fit, which validateBurstFmqLength relies on
// to check that the service fails gracefully and returns an error.
constexpr size_t kExecutionBurstChannelSmallLength{8};
53
54///////////////////////// UTILITY FUNCTIONS /////////////////////////
55
56static bool badTiming(Timing timing) {
57 return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
58}
59
60static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
61 std::unique_ptr<RequestChannelSender>* sender,
62 std::unique_ptr<ResultChannelReceiver>* receiver,
63 sp<IBurstContext>* context,
64 size_t resultChannelLength = kExecutionBurstChannelLength) {
65 ASSERT_NE(nullptr, preparedModel.get());
66 ASSERT_NE(nullptr, sender);
67 ASSERT_NE(nullptr, receiver);
68 ASSERT_NE(nullptr, context);
69
70 // create FMQ objects
71 auto [fmqRequestChannel, fmqRequestDescriptor] =
Michael Butler648ada52019-07-25 17:22:11 -070072 RequestChannelSender::create(kExecutionBurstChannelLength);
Lev Proleev13fdfcd2019-08-30 11:35:34 +010073 auto [fmqResultChannel, fmqResultDescriptor] =
Michael Butler648ada52019-07-25 17:22:11 -070074 ResultChannelReceiver::create(resultChannelLength, std::chrono::microseconds{0});
Lev Proleev13fdfcd2019-08-30 11:35:34 +010075 ASSERT_NE(nullptr, fmqRequestChannel.get());
76 ASSERT_NE(nullptr, fmqResultChannel.get());
77 ASSERT_NE(nullptr, fmqRequestDescriptor);
78 ASSERT_NE(nullptr, fmqResultDescriptor);
79
80 // configure burst
Michael Butler79a41d72019-12-11 19:08:08 -080081 V1_0::ErrorStatus errorStatus;
Lev Proleev13fdfcd2019-08-30 11:35:34 +010082 sp<IBurstContext> burstContext;
83 const Return<void> ret = preparedModel->configureExecutionBurst(
84 callback, *fmqRequestDescriptor, *fmqResultDescriptor,
Michael Butler79a41d72019-12-11 19:08:08 -080085 [&errorStatus, &burstContext](V1_0::ErrorStatus status,
86 const sp<IBurstContext>& context) {
Lev Proleev13fdfcd2019-08-30 11:35:34 +010087 errorStatus = status;
88 burstContext = context;
89 });
90 ASSERT_TRUE(ret.isOk());
Michael Butler79a41d72019-12-11 19:08:08 -080091 ASSERT_EQ(V1_0::ErrorStatus::NONE, errorStatus);
Lev Proleev13fdfcd2019-08-30 11:35:34 +010092 ASSERT_NE(nullptr, burstContext.get());
93
94 // return values
95 *sender = std::move(fmqRequestChannel);
96 *receiver = std::move(fmqResultChannel);
97 *context = burstContext;
98}
99
100static void createBurstWithResultChannelLength(
101 const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
102 std::shared_ptr<ExecutionBurstController>* controller) {
103 ASSERT_NE(nullptr, preparedModel.get());
104 ASSERT_NE(nullptr, controller);
105
106 // create FMQ objects
107 std::unique_ptr<RequestChannelSender> sender;
108 std::unique_ptr<ResultChannelReceiver> receiver;
109 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
110 sp<IBurstContext> context;
111 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
112 resultChannelLength));
113 ASSERT_NE(nullptr, sender.get());
114 ASSERT_NE(nullptr, receiver.get());
115 ASSERT_NE(nullptr, context.get());
116
117 // return values
118 *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
119 context, callback);
120}
121
122// Primary validation function. This function will take a valid serialized
123// request, apply a mutation to it to invalidate the serialized request, then
124// pass it to interface calls that use the serialized request. Note that the
125// serialized request here is passed by value, and any mutation to the
126// serialized request does not leave this function.
127static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
128 const std::string& message, std::vector<FmqRequestDatum> serialized,
129 const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) {
130 mutation(&serialized);
131
132 // skip if packet is too large to send
133 if (serialized.size() > kExecutionBurstChannelLength) {
134 return;
135 }
136
137 SCOPED_TRACE(message);
138
139 // send invalid packet
140 ASSERT_TRUE(sender->sendPacket(serialized));
141
142 // receive error
143 auto results = receiver->getBlocking();
144 ASSERT_TRUE(results.has_value());
145 const auto [status, outputShapes, timing] = std::move(*results);
Michael Butler79a41d72019-12-11 19:08:08 -0800146 EXPECT_NE(V1_0::ErrorStatus::NONE, status);
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100147 EXPECT_EQ(0u, outputShapes.size());
148 EXPECT_TRUE(badTiming(timing));
149}
150
151// For validation, valid packet entries are mutated to invalid packet entries,
152// or invalid packet entries are inserted into valid packets. This function
153// creates pre-set invalid packet entries for convenience.
154static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
155 const FmqRequestDatum::PacketInformation packetInformation = {
156 /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
157 /*.numberOfPools=*/10};
158 const FmqRequestDatum::OperandInformation operandInformation = {
159 /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
160 const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
161 std::vector<FmqRequestDatum> bad(7);
162 bad[0].packetInformation(packetInformation);
163 bad[1].inputOperandInformation(operandInformation);
164 bad[2].inputOperandDimensionValue(0);
165 bad[3].outputOperandInformation(operandInformation);
166 bad[4].outputOperandDimensionValue(0);
167 bad[5].poolIdentifier(invalidPoolIdentifier);
168 bad[6].measureTiming(MeasureTiming::YES);
169 return bad;
170}
171
172// For validation, valid packet entries are mutated to invalid packet entries,
173// or invalid packet entries are inserted into valid packets. This function
174// retrieves pre-set invalid packet entries for convenience. This function
175// caches these data so they can be reused on subsequent validation checks.
176static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
177 static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
178 return bad;
179}
180
181///////////////////////// REMOVE DATUM ////////////////////////////////////
182
183static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
184 const std::vector<FmqRequestDatum>& serialized) {
185 for (size_t index = 0; index < serialized.size(); ++index) {
186 const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
187 validate(sender, receiver, message, serialized,
188 [index](std::vector<FmqRequestDatum>* serialized) {
189 serialized->erase(serialized->begin() + index);
190 });
191 }
192}
193
194///////////////////////// ADD DATUM ////////////////////////////////////
195
196static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
197 const std::vector<FmqRequestDatum>& serialized) {
198 const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
199 for (size_t index = 0; index <= serialized.size(); ++index) {
200 for (size_t type = 0; type < extra.size(); ++type) {
201 const std::string message = "addDatum: added datum type " + std::to_string(type) +
202 " at index " + std::to_string(index);
203 validate(sender, receiver, message, serialized,
204 [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
205 serialized->insert(serialized->begin() + index, extra[type]);
206 });
207 }
208 }
209}
210
211///////////////////////// MUTATE DATUM ////////////////////////////////////
212
213static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
214 using Discriminator = FmqRequestDatum::hidl_discriminator;
215
216 const bool differentValues = (lhs != rhs);
217 const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
218 const auto discriminator = rhs.getDiscriminator();
219 const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
220 discriminator == Discriminator::outputOperandDimensionValue);
221
222 return differentValues && !(sameDiscriminator && isDimensionValue);
223}
224
225static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
226 const std::vector<FmqRequestDatum>& serialized) {
227 const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
228 for (size_t index = 0; index < serialized.size(); ++index) {
229 for (size_t type = 0; type < change.size(); ++type) {
230 if (interestingCase(serialized[index], change[type])) {
231 const std::string message = "mutateDatum: changed datum at index " +
232 std::to_string(index) + " to datum type " +
233 std::to_string(type);
234 validate(sender, receiver, message, serialized,
235 [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
236 (*serialized)[index] = change[type];
237 });
238 }
239 }
240 }
241}
242
///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
244
245static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
246 const Request& request) {
247 // create burst
248 std::unique_ptr<RequestChannelSender> sender;
249 std::unique_ptr<ResultChannelReceiver> receiver;
250 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
251 sp<IBurstContext> context;
252 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
253 ASSERT_NE(nullptr, sender.get());
254 ASSERT_NE(nullptr, receiver.get());
255 ASSERT_NE(nullptr, context.get());
256
257 // load memory into callback slots
258 std::vector<intptr_t> keys;
259 keys.reserve(request.pools.size());
260 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
261 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
262 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
263
264 // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
265 // subsequent slot validation testing)
266 ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
267 return slot != std::numeric_limits<int32_t>::max();
268 }));
269
270 // serialize the request
271 const auto serialized = android::nn::serialize(request, MeasureTiming::YES, slots);
272
273 // validations
274 removeDatumTest(sender.get(), receiver.get(), serialized);
275 addDatumTest(sender.get(), receiver.get(), serialized);
276 mutateDatumTest(sender.get(), receiver.get(), serialized);
277}
278
279// This test validates that when the Result message size exceeds length of the
280// result FMQ, the service instance gracefully fails and returns an error.
281static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
282 const Request& request) {
283 // create regular burst
284 std::shared_ptr<ExecutionBurstController> controllerRegular;
285 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
286 preparedModel, kExecutionBurstChannelLength, &controllerRegular));
287 ASSERT_NE(nullptr, controllerRegular.get());
288
289 // create burst with small output channel
290 std::shared_ptr<ExecutionBurstController> controllerSmall;
291 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
292 preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
293 ASSERT_NE(nullptr, controllerSmall.get());
294
295 // load memory into callback slots
296 std::vector<intptr_t> keys(request.pools.size());
297 for (size_t i = 0; i < keys.size(); ++i) {
298 keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
299 }
300
301 // collect serialized result by running regular burst
Michael Butler648ada52019-07-25 17:22:11 -0700302 const auto [nRegular, outputShapesRegular, timingRegular, fallbackRegular] =
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100303 controllerRegular->compute(request, MeasureTiming::NO, keys);
Michael Butlerd09c0ee2020-03-12 15:12:23 -0700304 const V1_0::ErrorStatus statusRegular = nn::legacyConvertResultCodeToErrorStatus(nRegular);
Michael Butler648ada52019-07-25 17:22:11 -0700305 EXPECT_FALSE(fallbackRegular);
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100306
307 // skip test if regular burst output isn't useful for testing a failure
308 // caused by having too small of a length for the result FMQ
309 const std::vector<FmqResultDatum> serialized =
310 android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
Michael Butler79a41d72019-12-11 19:08:08 -0800311 if (statusRegular != V1_0::ErrorStatus::NONE ||
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100312 serialized.size() <= kExecutionBurstChannelSmallLength) {
313 return;
314 }
315
316 // by this point, execution should fail because the result channel isn't
317 // large enough to return the serialized result
Michael Butler648ada52019-07-25 17:22:11 -0700318 const auto [nSmall, outputShapesSmall, timingSmall, fallbackSmall] =
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100319 controllerSmall->compute(request, MeasureTiming::NO, keys);
Michael Butlerd09c0ee2020-03-12 15:12:23 -0700320 const V1_0::ErrorStatus statusSmall = nn::legacyConvertResultCodeToErrorStatus(nSmall);
Michael Butler79a41d72019-12-11 19:08:08 -0800321 EXPECT_NE(V1_0::ErrorStatus::NONE, statusSmall);
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100322 EXPECT_EQ(0u, outputShapesSmall.size());
323 EXPECT_TRUE(badTiming(timingSmall));
Michael Butler648ada52019-07-25 17:22:11 -0700324 EXPECT_FALSE(fallbackSmall);
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100325}
326
327static bool isSanitized(const FmqResultDatum& datum) {
328 using Discriminator = FmqResultDatum::hidl_discriminator;
329
330 // check to ensure the padding values in the returned
331 // FmqResultDatum::OperandInformation are initialized to 0
332 if (datum.getDiscriminator() == Discriminator::operandInformation) {
333 static_assert(
334 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
335 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
336 static_assert(
337 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
338 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
339 static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
340 "unexpected value for offset of "
341 "FmqResultDatum::OperandInformation::numberOfDimensions");
342 static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
343 "unexpected value for size of "
344 "FmqResultDatum::OperandInformation::numberOfDimensions");
345 static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
346 "unexpected value for size of "
347 "FmqResultDatum::OperandInformation");
348
349 constexpr size_t paddingOffset =
350 offsetof(FmqResultDatum::OperandInformation, isSufficient) +
351 sizeof(FmqResultDatum::OperandInformation::isSufficient);
352 constexpr size_t paddingSize =
353 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;
354
355 FmqResultDatum::OperandInformation initialized{};
356 std::memset(&initialized, 0, sizeof(initialized));
357
358 const char* initializedPaddingStart =
359 reinterpret_cast<const char*>(&initialized) + paddingOffset;
360 const char* datumPaddingStart =
361 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;
362
363 return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
364 }
365
366 // there are no other padding initialization checks required, so return true
367 // for any sum-type that isn't FmqResultDatum::OperandInformation
368 return true;
369}
370
371static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
372 const Request& request) {
373 // create burst
374 std::unique_ptr<RequestChannelSender> sender;
375 std::unique_ptr<ResultChannelReceiver> receiver;
376 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
377 sp<IBurstContext> context;
378 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
379 ASSERT_NE(nullptr, sender.get());
380 ASSERT_NE(nullptr, receiver.get());
381 ASSERT_NE(nullptr, context.get());
382
383 // load memory into callback slots
384 std::vector<intptr_t> keys;
385 keys.reserve(request.pools.size());
386 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
387 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
388 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
389
390 // send valid request
391 ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));
392
393 // receive valid result
394 auto serialized = receiver->getPacketBlocking();
395 ASSERT_TRUE(serialized.has_value());
396
397 // sanitize result
398 ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
399 << "The result serialized data is not properly sanitized";
400}
401
402///////////////////////////// ENTRY POINT //////////////////////////////////
403
// Entry point for the burst validation tests: runs the serialization, result
// FMQ length and result sanitization checks against |preparedModel| using
// |request|, in that order. Stops at the first check that reports a fatal
// failure.
void validateBurst(const sp<IPreparedModel>& preparedModel, const Request& request) {
    ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request));
    ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request));
    ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, request));
}
409
Lev Proleev26d1bc82019-08-30 11:57:18 +0100410} // namespace android::hardware::neuralnetworks::V1_3::vts::functional