blob: c02d0206e20fbfaf00a425e631ab2b0248bdbc6f [file] [log] [blame]
Michael Butler20f28a22019-04-26 17:46:08 -07001/*
Michael Butler0a1ad962019-04-30 13:51:24 -07002 * Copyright (C) 2019 The Android Open Source Project
Michael Butler20f28a22019-04-26 17:46:08 -07003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19#include "VtsHalNeuralnetworks.h"
20
Slava Shklyaev73ee79d2019-05-14 14:15:14 +010021#include "1.2/Callbacks.h"
Michael Butler20f28a22019-04-26 17:46:08 -070022#include "ExecutionBurstController.h"
23#include "ExecutionBurstServer.h"
Xusong Wangbcaa7822019-08-23 16:10:54 -070024#include "GeneratedTestHarness.h"
Michael Butler20f28a22019-04-26 17:46:08 -070025#include "TestHarness.h"
26#include "Utils.h"
27
28#include <android-base/logging.h>
Michael Butlerddb770f2019-05-02 18:16:13 -070029#include <cstring>
Michael Butler20f28a22019-04-26 17:46:08 -070030
Michael Butler62749b92019-08-26 23:55:47 -070031namespace android::hardware::neuralnetworks::V1_2::vts::functional {
Michael Butler20f28a22019-04-26 17:46:08 -070032
Michael Butler62749b92019-08-26 23:55:47 -070033using nn::ExecutionBurstController;
34using nn::RequestChannelSender;
35using nn::ResultChannelReceiver;
36using V1_0::ErrorStatus;
37using V1_0::Request;
38using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback;
Michael Butler20f28a22019-04-26 17:46:08 -070039
// Number of datum entries in an FMQ that is big enough to carry the result of
// a burst execution for every one of the generated test cases.
constexpr size_t kExecutionBurstChannelLength{1024};

// Number of datum entries in an FMQ that is deliberately too short to carry
// the result of a burst execution for some of the generated test cases.
constexpr size_t kExecutionBurstChannelSmallLength{8};
48
49///////////////////////// UTILITY FUNCTIONS /////////////////////////
50
51static bool badTiming(Timing timing) {
52 return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
53}
54
55static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
56 std::unique_ptr<RequestChannelSender>* sender,
57 std::unique_ptr<ResultChannelReceiver>* receiver,
Michael Butler0a1ad962019-04-30 13:51:24 -070058 sp<IBurstContext>* context,
59 size_t resultChannelLength = kExecutionBurstChannelLength) {
Michael Butler20f28a22019-04-26 17:46:08 -070060 ASSERT_NE(nullptr, preparedModel.get());
61 ASSERT_NE(nullptr, sender);
62 ASSERT_NE(nullptr, receiver);
63 ASSERT_NE(nullptr, context);
64
65 // create FMQ objects
66 auto [fmqRequestChannel, fmqRequestDescriptor] =
67 RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true);
68 auto [fmqResultChannel, fmqResultDescriptor] =
Michael Butler0a1ad962019-04-30 13:51:24 -070069 ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true);
Michael Butler20f28a22019-04-26 17:46:08 -070070 ASSERT_NE(nullptr, fmqRequestChannel.get());
71 ASSERT_NE(nullptr, fmqResultChannel.get());
72 ASSERT_NE(nullptr, fmqRequestDescriptor);
73 ASSERT_NE(nullptr, fmqResultDescriptor);
74
75 // configure burst
76 ErrorStatus errorStatus;
77 sp<IBurstContext> burstContext;
78 const Return<void> ret = preparedModel->configureExecutionBurst(
79 callback, *fmqRequestDescriptor, *fmqResultDescriptor,
80 [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
81 errorStatus = status;
82 burstContext = context;
83 });
84 ASSERT_TRUE(ret.isOk());
85 ASSERT_EQ(ErrorStatus::NONE, errorStatus);
86 ASSERT_NE(nullptr, burstContext.get());
87
88 // return values
89 *sender = std::move(fmqRequestChannel);
90 *receiver = std::move(fmqResultChannel);
91 *context = burstContext;
92}
93
94static void createBurstWithResultChannelLength(
Michael Butler0a1ad962019-04-30 13:51:24 -070095 const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
96 std::shared_ptr<ExecutionBurstController>* controller) {
Michael Butler20f28a22019-04-26 17:46:08 -070097 ASSERT_NE(nullptr, preparedModel.get());
98 ASSERT_NE(nullptr, controller);
99
100 // create FMQ objects
Michael Butler0a1ad962019-04-30 13:51:24 -0700101 std::unique_ptr<RequestChannelSender> sender;
102 std::unique_ptr<ResultChannelReceiver> receiver;
Michael Butler20f28a22019-04-26 17:46:08 -0700103 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
Michael Butler0a1ad962019-04-30 13:51:24 -0700104 sp<IBurstContext> context;
105 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
106 resultChannelLength));
107 ASSERT_NE(nullptr, sender.get());
108 ASSERT_NE(nullptr, receiver.get());
109 ASSERT_NE(nullptr, context.get());
Michael Butler20f28a22019-04-26 17:46:08 -0700110
111 // return values
Michael Butler0a1ad962019-04-30 13:51:24 -0700112 *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
113 context, callback);
Michael Butler20f28a22019-04-26 17:46:08 -0700114}
115
116// Primary validation function. This function will take a valid serialized
117// request, apply a mutation to it to invalidate the serialized request, then
118// pass it to interface calls that use the serialized request. Note that the
119// serialized request here is passed by value, and any mutation to the
120// serialized request does not leave this function.
121static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
122 const std::string& message, std::vector<FmqRequestDatum> serialized,
123 const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) {
124 mutation(&serialized);
125
126 // skip if packet is too large to send
127 if (serialized.size() > kExecutionBurstChannelLength) {
128 return;
129 }
130
131 SCOPED_TRACE(message);
132
133 // send invalid packet
Michael Butler0a1ad962019-04-30 13:51:24 -0700134 ASSERT_TRUE(sender->sendPacket(serialized));
Michael Butler20f28a22019-04-26 17:46:08 -0700135
136 // receive error
137 auto results = receiver->getBlocking();
138 ASSERT_TRUE(results.has_value());
139 const auto [status, outputShapes, timing] = std::move(*results);
140 EXPECT_NE(ErrorStatus::NONE, status);
141 EXPECT_EQ(0u, outputShapes.size());
142 EXPECT_TRUE(badTiming(timing));
143}
144
Michael Butler0a1ad962019-04-30 13:51:24 -0700145// For validation, valid packet entries are mutated to invalid packet entries,
146// or invalid packet entries are inserted into valid packets. This function
147// creates pre-set invalid packet entries for convenience.
148static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
Michael Butler20f28a22019-04-26 17:46:08 -0700149 const FmqRequestDatum::PacketInformation packetInformation = {
150 /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
151 /*.numberOfPools=*/10};
152 const FmqRequestDatum::OperandInformation operandInformation = {
153 /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
154 const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
Michael Butler0a1ad962019-04-30 13:51:24 -0700155 std::vector<FmqRequestDatum> bad(7);
156 bad[0].packetInformation(packetInformation);
157 bad[1].inputOperandInformation(operandInformation);
158 bad[2].inputOperandDimensionValue(0);
159 bad[3].outputOperandInformation(operandInformation);
160 bad[4].outputOperandDimensionValue(0);
161 bad[5].poolIdentifier(invalidPoolIdentifier);
162 bad[6].measureTiming(MeasureTiming::YES);
163 return bad;
Michael Butler20f28a22019-04-26 17:46:08 -0700164}
165
Michael Butler0a1ad962019-04-30 13:51:24 -0700166// For validation, valid packet entries are mutated to invalid packet entries,
167// or invalid packet entries are inserted into valid packets. This function
168// retrieves pre-set invalid packet entries for convenience. This function
169// caches these data so they can be reused on subsequent validation checks.
170static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
171 static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
172 return bad;
Michael Butler20f28a22019-04-26 17:46:08 -0700173}
174
175///////////////////////// REMOVE DATUM ////////////////////////////////////
176
177static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
178 const std::vector<FmqRequestDatum>& serialized) {
179 for (size_t index = 0; index < serialized.size(); ++index) {
180 const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
181 validate(sender, receiver, message, serialized,
182 [index](std::vector<FmqRequestDatum>* serialized) {
183 serialized->erase(serialized->begin() + index);
184 });
185 }
186}
187
188///////////////////////// ADD DATUM ////////////////////////////////////
189
190static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
191 const std::vector<FmqRequestDatum>& serialized) {
Michael Butler0a1ad962019-04-30 13:51:24 -0700192 const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
Michael Butler20f28a22019-04-26 17:46:08 -0700193 for (size_t index = 0; index <= serialized.size(); ++index) {
194 for (size_t type = 0; type < extra.size(); ++type) {
195 const std::string message = "addDatum: added datum type " + std::to_string(type) +
196 " at index " + std::to_string(index);
197 validate(sender, receiver, message, serialized,
198 [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
199 serialized->insert(serialized->begin() + index, extra[type]);
200 });
201 }
202 }
203}
204
205///////////////////////// MUTATE DATUM ////////////////////////////////////
206
207static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
208 using Discriminator = FmqRequestDatum::hidl_discriminator;
209
210 const bool differentValues = (lhs != rhs);
Michael Butler0a1ad962019-04-30 13:51:24 -0700211 const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
Michael Butler20f28a22019-04-26 17:46:08 -0700212 const auto discriminator = rhs.getDiscriminator();
213 const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
214 discriminator == Discriminator::outputOperandDimensionValue);
215
Michael Butler0a1ad962019-04-30 13:51:24 -0700216 return differentValues && !(sameDiscriminator && isDimensionValue);
Michael Butler20f28a22019-04-26 17:46:08 -0700217}
218
219static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
220 const std::vector<FmqRequestDatum>& serialized) {
Michael Butler0a1ad962019-04-30 13:51:24 -0700221 const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
Michael Butler20f28a22019-04-26 17:46:08 -0700222 for (size_t index = 0; index < serialized.size(); ++index) {
223 for (size_t type = 0; type < change.size(); ++type) {
224 if (interestingCase(serialized[index], change[type])) {
225 const std::string message = "mutateDatum: changed datum at index " +
226 std::to_string(index) + " to datum type " +
227 std::to_string(type);
228 validate(sender, receiver, message, serialized,
229 [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
230 (*serialized)[index] = change[type];
231 });
232 }
233 }
234 }
235}
236
///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
238
239static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
Xusong Wang6d0270b2019-08-09 16:45:24 -0700240 const Request& request) {
Michael Butler20f28a22019-04-26 17:46:08 -0700241 // create burst
242 std::unique_ptr<RequestChannelSender> sender;
243 std::unique_ptr<ResultChannelReceiver> receiver;
244 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
245 sp<IBurstContext> context;
246 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
247 ASSERT_NE(nullptr, sender.get());
248 ASSERT_NE(nullptr, receiver.get());
249 ASSERT_NE(nullptr, context.get());
250
Xusong Wang6d0270b2019-08-09 16:45:24 -0700251 // load memory into callback slots
252 std::vector<intptr_t> keys;
253 keys.reserve(request.pools.size());
254 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
255 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
256 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
Michael Butler20f28a22019-04-26 17:46:08 -0700257
Xusong Wang6d0270b2019-08-09 16:45:24 -0700258 // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
259 // subsequent slot validation testing)
260 ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
261 return slot != std::numeric_limits<int32_t>::max();
262 }));
Michael Butler20f28a22019-04-26 17:46:08 -0700263
Xusong Wang6d0270b2019-08-09 16:45:24 -0700264 // serialize the request
265 const auto serialized = ::android::nn::serialize(request, MeasureTiming::YES, slots);
Michael Butler20f28a22019-04-26 17:46:08 -0700266
Xusong Wang6d0270b2019-08-09 16:45:24 -0700267 // validations
268 removeDatumTest(sender.get(), receiver.get(), serialized);
269 addDatumTest(sender.get(), receiver.get(), serialized);
270 mutateDatumTest(sender.get(), receiver.get(), serialized);
Michael Butler20f28a22019-04-26 17:46:08 -0700271}
272
Michael Butler0a1ad962019-04-30 13:51:24 -0700273// This test validates that when the Result message size exceeds length of the
274// result FMQ, the service instance gracefully fails and returns an error.
Michael Butler20f28a22019-04-26 17:46:08 -0700275static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
Xusong Wang6d0270b2019-08-09 16:45:24 -0700276 const Request& request) {
Michael Butler20f28a22019-04-26 17:46:08 -0700277 // create regular burst
278 std::shared_ptr<ExecutionBurstController> controllerRegular;
Michael Butler0a1ad962019-04-30 13:51:24 -0700279 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
280 preparedModel, kExecutionBurstChannelLength, &controllerRegular));
Michael Butler20f28a22019-04-26 17:46:08 -0700281 ASSERT_NE(nullptr, controllerRegular.get());
282
283 // create burst with small output channel
284 std::shared_ptr<ExecutionBurstController> controllerSmall;
Michael Butler0a1ad962019-04-30 13:51:24 -0700285 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
286 preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
Michael Butler20f28a22019-04-26 17:46:08 -0700287 ASSERT_NE(nullptr, controllerSmall.get());
288
Xusong Wang6d0270b2019-08-09 16:45:24 -0700289 // load memory into callback slots
290 std::vector<intptr_t> keys(request.pools.size());
291 for (size_t i = 0; i < keys.size(); ++i) {
292 keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
Michael Butler20f28a22019-04-26 17:46:08 -0700293 }
Xusong Wang6d0270b2019-08-09 16:45:24 -0700294
295 // collect serialized result by running regular burst
296 const auto [statusRegular, outputShapesRegular, timingRegular] =
297 controllerRegular->compute(request, MeasureTiming::NO, keys);
298
299 // skip test if regular burst output isn't useful for testing a failure
300 // caused by having too small of a length for the result FMQ
301 const std::vector<FmqResultDatum> serialized =
302 ::android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
303 if (statusRegular != ErrorStatus::NONE ||
304 serialized.size() <= kExecutionBurstChannelSmallLength) {
305 return;
306 }
307
308 // by this point, execution should fail because the result channel isn't
309 // large enough to return the serialized result
310 const auto [statusSmall, outputShapesSmall, timingSmall] =
311 controllerSmall->compute(request, MeasureTiming::NO, keys);
312 EXPECT_NE(ErrorStatus::NONE, statusSmall);
313 EXPECT_EQ(0u, outputShapesSmall.size());
314 EXPECT_TRUE(badTiming(timingSmall));
Michael Butler20f28a22019-04-26 17:46:08 -0700315}
316
Michael Butlerddb770f2019-05-02 18:16:13 -0700317static bool isSanitized(const FmqResultDatum& datum) {
318 using Discriminator = FmqResultDatum::hidl_discriminator;
319
320 // check to ensure the padding values in the returned
321 // FmqResultDatum::OperandInformation are initialized to 0
322 if (datum.getDiscriminator() == Discriminator::operandInformation) {
323 static_assert(
324 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
325 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
326 static_assert(
327 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
328 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
329 static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
330 "unexpected value for offset of "
331 "FmqResultDatum::OperandInformation::numberOfDimensions");
332 static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
333 "unexpected value for size of "
334 "FmqResultDatum::OperandInformation::numberOfDimensions");
335 static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
336 "unexpected value for size of "
337 "FmqResultDatum::OperandInformation");
338
339 constexpr size_t paddingOffset =
340 offsetof(FmqResultDatum::OperandInformation, isSufficient) +
341 sizeof(FmqResultDatum::OperandInformation::isSufficient);
342 constexpr size_t paddingSize =
343 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;
344
345 FmqResultDatum::OperandInformation initialized{};
346 std::memset(&initialized, 0, sizeof(initialized));
347
348 const char* initializedPaddingStart =
349 reinterpret_cast<const char*>(&initialized) + paddingOffset;
350 const char* datumPaddingStart =
351 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;
352
353 return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
354 }
355
356 // there are no other padding initialization checks required, so return true
357 // for any sum-type that isn't FmqResultDatum::OperandInformation
358 return true;
359}
360
361static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
Xusong Wang323ba2e2019-08-19 10:37:18 -0700362 const Request& request) {
Michael Butlerddb770f2019-05-02 18:16:13 -0700363 // create burst
364 std::unique_ptr<RequestChannelSender> sender;
365 std::unique_ptr<ResultChannelReceiver> receiver;
366 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
367 sp<IBurstContext> context;
368 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
369 ASSERT_NE(nullptr, sender.get());
370 ASSERT_NE(nullptr, receiver.get());
371 ASSERT_NE(nullptr, context.get());
372
Xusong Wang323ba2e2019-08-19 10:37:18 -0700373 // load memory into callback slots
374 std::vector<intptr_t> keys;
375 keys.reserve(request.pools.size());
376 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
377 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
378 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
Michael Butlerddb770f2019-05-02 18:16:13 -0700379
Xusong Wang323ba2e2019-08-19 10:37:18 -0700380 // send valid request
381 ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));
Michael Butlerddb770f2019-05-02 18:16:13 -0700382
Xusong Wang323ba2e2019-08-19 10:37:18 -0700383 // receive valid result
384 auto serialized = receiver->getPacketBlocking();
385 ASSERT_TRUE(serialized.has_value());
Michael Butlerddb770f2019-05-02 18:16:13 -0700386
Xusong Wang323ba2e2019-08-19 10:37:18 -0700387 // sanitize result
388 ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
389 << "The result serialized data is not properly sanitized";
Michael Butlerddb770f2019-05-02 18:16:13 -0700390}
391
Michael Butler20f28a22019-04-26 17:46:08 -0700392///////////////////////////// ENTRY POINT //////////////////////////////////
393
Michael Butler13b05162019-08-29 22:17:24 -0700394void validateBurst(const sp<IPreparedModel>& preparedModel, const Request& request) {
Xusong Wang6d0270b2019-08-09 16:45:24 -0700395 ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request));
396 ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request));
Xusong Wang323ba2e2019-08-19 10:37:18 -0700397 ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, request));
Michael Butler20f28a22019-04-26 17:46:08 -0700398}
399
Michael Butler62749b92019-08-26 23:55:47 -0700400} // namespace android::hardware::neuralnetworks::V1_2::vts::functional