blob: 6ff9dfd3a8e86bb5538c7a1fd960f9752f4e841d [file] [log] [blame]
Lev Proleev13fdfcd2019-08-30 11:35:34 +01001/*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19#include "VtsHalNeuralnetworks.h"
20
21#include "1.2/Callbacks.h"
22#include "ExecutionBurstController.h"
23#include "ExecutionBurstServer.h"
24#include "GeneratedTestHarness.h"
25#include "TestHarness.h"
26#include "Utils.h"
27
28#include <android-base/logging.h>
Michael Butler648ada52019-07-25 17:22:11 -070029#include <chrono>
Lev Proleev13fdfcd2019-08-30 11:35:34 +010030#include <cstring>
31
Lev Proleev26d1bc82019-08-30 11:57:18 +010032namespace android::hardware::neuralnetworks::V1_3::vts::functional {
Lev Proleev13fdfcd2019-08-30 11:35:34 +010033
34using nn::ExecutionBurstController;
35using nn::RequestChannelSender;
36using nn::ResultChannelReceiver;
Lev Proleev13fdfcd2019-08-30 11:35:34 +010037using V1_0::Request;
Lev Proleev26d1bc82019-08-30 11:57:18 +010038using V1_2::FmqRequestDatum;
39using V1_2::FmqResultDatum;
40using V1_2::IBurstCallback;
41using V1_2::IBurstContext;
Lev Proleev26d1bc82019-08-30 11:57:18 +010042using V1_2::MeasureTiming;
43using V1_2::Timing;
Lev Proleev13fdfcd2019-08-30 11:35:34 +010044using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback;
45
// Number of elements used for the burst FMQs by default (both the request
// channel and the regular result channel). It is large enough to return the
// result of a burst execution for every one of the generated test cases.
constexpr size_t kExecutionBurstChannelLength{1024};

// Number of elements for a deliberately undersized result FMQ. It is too small
// to return the result of a burst execution for some of the generated test
// cases, which is what the result-FMQ-length validation relies on.
constexpr size_t kExecutionBurstChannelSmallLength{8};
54
55///////////////////////// UTILITY FUNCTIONS /////////////////////////
56
57static bool badTiming(Timing timing) {
58 return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
59}
60
61static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
62 std::unique_ptr<RequestChannelSender>* sender,
63 std::unique_ptr<ResultChannelReceiver>* receiver,
64 sp<IBurstContext>* context,
65 size_t resultChannelLength = kExecutionBurstChannelLength) {
66 ASSERT_NE(nullptr, preparedModel.get());
67 ASSERT_NE(nullptr, sender);
68 ASSERT_NE(nullptr, receiver);
69 ASSERT_NE(nullptr, context);
70
71 // create FMQ objects
72 auto [fmqRequestChannel, fmqRequestDescriptor] =
Michael Butler648ada52019-07-25 17:22:11 -070073 RequestChannelSender::create(kExecutionBurstChannelLength);
Lev Proleev13fdfcd2019-08-30 11:35:34 +010074 auto [fmqResultChannel, fmqResultDescriptor] =
Michael Butler648ada52019-07-25 17:22:11 -070075 ResultChannelReceiver::create(resultChannelLength, std::chrono::microseconds{0});
Lev Proleev13fdfcd2019-08-30 11:35:34 +010076 ASSERT_NE(nullptr, fmqRequestChannel.get());
77 ASSERT_NE(nullptr, fmqResultChannel.get());
78 ASSERT_NE(nullptr, fmqRequestDescriptor);
79 ASSERT_NE(nullptr, fmqResultDescriptor);
80
81 // configure burst
Michael Butler79a41d72019-12-11 19:08:08 -080082 V1_0::ErrorStatus errorStatus;
Lev Proleev13fdfcd2019-08-30 11:35:34 +010083 sp<IBurstContext> burstContext;
84 const Return<void> ret = preparedModel->configureExecutionBurst(
85 callback, *fmqRequestDescriptor, *fmqResultDescriptor,
Michael Butler79a41d72019-12-11 19:08:08 -080086 [&errorStatus, &burstContext](V1_0::ErrorStatus status,
87 const sp<IBurstContext>& context) {
Lev Proleev13fdfcd2019-08-30 11:35:34 +010088 errorStatus = status;
89 burstContext = context;
90 });
91 ASSERT_TRUE(ret.isOk());
Michael Butler79a41d72019-12-11 19:08:08 -080092 ASSERT_EQ(V1_0::ErrorStatus::NONE, errorStatus);
Lev Proleev13fdfcd2019-08-30 11:35:34 +010093 ASSERT_NE(nullptr, burstContext.get());
94
95 // return values
96 *sender = std::move(fmqRequestChannel);
97 *receiver = std::move(fmqResultChannel);
98 *context = burstContext;
99}
100
101static void createBurstWithResultChannelLength(
102 const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
103 std::shared_ptr<ExecutionBurstController>* controller) {
104 ASSERT_NE(nullptr, preparedModel.get());
105 ASSERT_NE(nullptr, controller);
106
107 // create FMQ objects
108 std::unique_ptr<RequestChannelSender> sender;
109 std::unique_ptr<ResultChannelReceiver> receiver;
110 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
111 sp<IBurstContext> context;
112 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
113 resultChannelLength));
114 ASSERT_NE(nullptr, sender.get());
115 ASSERT_NE(nullptr, receiver.get());
116 ASSERT_NE(nullptr, context.get());
117
118 // return values
119 *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
120 context, callback);
121}
122
123// Primary validation function. This function will take a valid serialized
124// request, apply a mutation to it to invalidate the serialized request, then
125// pass it to interface calls that use the serialized request. Note that the
126// serialized request here is passed by value, and any mutation to the
127// serialized request does not leave this function.
128static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
129 const std::string& message, std::vector<FmqRequestDatum> serialized,
130 const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) {
131 mutation(&serialized);
132
133 // skip if packet is too large to send
134 if (serialized.size() > kExecutionBurstChannelLength) {
135 return;
136 }
137
138 SCOPED_TRACE(message);
139
140 // send invalid packet
141 ASSERT_TRUE(sender->sendPacket(serialized));
142
143 // receive error
144 auto results = receiver->getBlocking();
145 ASSERT_TRUE(results.has_value());
146 const auto [status, outputShapes, timing] = std::move(*results);
Michael Butler79a41d72019-12-11 19:08:08 -0800147 EXPECT_NE(V1_0::ErrorStatus::NONE, status);
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100148 EXPECT_EQ(0u, outputShapes.size());
149 EXPECT_TRUE(badTiming(timing));
150}
151
152// For validation, valid packet entries are mutated to invalid packet entries,
153// or invalid packet entries are inserted into valid packets. This function
154// creates pre-set invalid packet entries for convenience.
155static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
156 const FmqRequestDatum::PacketInformation packetInformation = {
157 /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
158 /*.numberOfPools=*/10};
159 const FmqRequestDatum::OperandInformation operandInformation = {
160 /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
161 const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
162 std::vector<FmqRequestDatum> bad(7);
163 bad[0].packetInformation(packetInformation);
164 bad[1].inputOperandInformation(operandInformation);
165 bad[2].inputOperandDimensionValue(0);
166 bad[3].outputOperandInformation(operandInformation);
167 bad[4].outputOperandDimensionValue(0);
168 bad[5].poolIdentifier(invalidPoolIdentifier);
169 bad[6].measureTiming(MeasureTiming::YES);
170 return bad;
171}
172
173// For validation, valid packet entries are mutated to invalid packet entries,
174// or invalid packet entries are inserted into valid packets. This function
175// retrieves pre-set invalid packet entries for convenience. This function
176// caches these data so they can be reused on subsequent validation checks.
177static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
178 static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
179 return bad;
180}
181
182///////////////////////// REMOVE DATUM ////////////////////////////////////
183
184static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
185 const std::vector<FmqRequestDatum>& serialized) {
186 for (size_t index = 0; index < serialized.size(); ++index) {
187 const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
188 validate(sender, receiver, message, serialized,
189 [index](std::vector<FmqRequestDatum>* serialized) {
190 serialized->erase(serialized->begin() + index);
191 });
192 }
193}
194
195///////////////////////// ADD DATUM ////////////////////////////////////
196
197static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
198 const std::vector<FmqRequestDatum>& serialized) {
199 const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
200 for (size_t index = 0; index <= serialized.size(); ++index) {
201 for (size_t type = 0; type < extra.size(); ++type) {
202 const std::string message = "addDatum: added datum type " + std::to_string(type) +
203 " at index " + std::to_string(index);
204 validate(sender, receiver, message, serialized,
205 [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
206 serialized->insert(serialized->begin() + index, extra[type]);
207 });
208 }
209 }
210}
211
212///////////////////////// MUTATE DATUM ////////////////////////////////////
213
214static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
215 using Discriminator = FmqRequestDatum::hidl_discriminator;
216
217 const bool differentValues = (lhs != rhs);
218 const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
219 const auto discriminator = rhs.getDiscriminator();
220 const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
221 discriminator == Discriminator::outputOperandDimensionValue);
222
223 return differentValues && !(sameDiscriminator && isDimensionValue);
224}
225
226static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
227 const std::vector<FmqRequestDatum>& serialized) {
228 const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
229 for (size_t index = 0; index < serialized.size(); ++index) {
230 for (size_t type = 0; type < change.size(); ++type) {
231 if (interestingCase(serialized[index], change[type])) {
232 const std::string message = "mutateDatum: changed datum at index " +
233 std::to_string(index) + " to datum type " +
234 std::to_string(type);
235 validate(sender, receiver, message, serialized,
236 [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
237 (*serialized)[index] = change[type];
238 });
239 }
240 }
241 }
242}
243
///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
245
246static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
247 const Request& request) {
248 // create burst
249 std::unique_ptr<RequestChannelSender> sender;
250 std::unique_ptr<ResultChannelReceiver> receiver;
251 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
252 sp<IBurstContext> context;
253 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
254 ASSERT_NE(nullptr, sender.get());
255 ASSERT_NE(nullptr, receiver.get());
256 ASSERT_NE(nullptr, context.get());
257
258 // load memory into callback slots
259 std::vector<intptr_t> keys;
260 keys.reserve(request.pools.size());
261 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
262 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
263 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
264
265 // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
266 // subsequent slot validation testing)
267 ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
268 return slot != std::numeric_limits<int32_t>::max();
269 }));
270
271 // serialize the request
272 const auto serialized = android::nn::serialize(request, MeasureTiming::YES, slots);
273
274 // validations
275 removeDatumTest(sender.get(), receiver.get(), serialized);
276 addDatumTest(sender.get(), receiver.get(), serialized);
277 mutateDatumTest(sender.get(), receiver.get(), serialized);
278}
279
280// This test validates that when the Result message size exceeds length of the
281// result FMQ, the service instance gracefully fails and returns an error.
282static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
283 const Request& request) {
284 // create regular burst
285 std::shared_ptr<ExecutionBurstController> controllerRegular;
286 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
287 preparedModel, kExecutionBurstChannelLength, &controllerRegular));
288 ASSERT_NE(nullptr, controllerRegular.get());
289
290 // create burst with small output channel
291 std::shared_ptr<ExecutionBurstController> controllerSmall;
292 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
293 preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
294 ASSERT_NE(nullptr, controllerSmall.get());
295
296 // load memory into callback slots
297 std::vector<intptr_t> keys(request.pools.size());
298 for (size_t i = 0; i < keys.size(); ++i) {
299 keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
300 }
301
302 // collect serialized result by running regular burst
Michael Butler648ada52019-07-25 17:22:11 -0700303 const auto [nRegular, outputShapesRegular, timingRegular, fallbackRegular] =
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100304 controllerRegular->compute(request, MeasureTiming::NO, keys);
Michael Butler79a41d72019-12-11 19:08:08 -0800305 const V1_0::ErrorStatus statusRegular =
306 nn::convertToV1_0(nn::convertResultCodeToErrorStatus(nRegular));
Michael Butler648ada52019-07-25 17:22:11 -0700307 EXPECT_FALSE(fallbackRegular);
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100308
309 // skip test if regular burst output isn't useful for testing a failure
310 // caused by having too small of a length for the result FMQ
311 const std::vector<FmqResultDatum> serialized =
312 android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
Michael Butler79a41d72019-12-11 19:08:08 -0800313 if (statusRegular != V1_0::ErrorStatus::NONE ||
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100314 serialized.size() <= kExecutionBurstChannelSmallLength) {
315 return;
316 }
317
318 // by this point, execution should fail because the result channel isn't
319 // large enough to return the serialized result
Michael Butler648ada52019-07-25 17:22:11 -0700320 const auto [nSmall, outputShapesSmall, timingSmall, fallbackSmall] =
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100321 controllerSmall->compute(request, MeasureTiming::NO, keys);
Michael Butler79a41d72019-12-11 19:08:08 -0800322 const V1_0::ErrorStatus statusSmall =
323 nn::convertToV1_0(nn::convertResultCodeToErrorStatus(nSmall));
324 EXPECT_NE(V1_0::ErrorStatus::NONE, statusSmall);
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100325 EXPECT_EQ(0u, outputShapesSmall.size());
326 EXPECT_TRUE(badTiming(timingSmall));
Michael Butler648ada52019-07-25 17:22:11 -0700327 EXPECT_FALSE(fallbackSmall);
Lev Proleev13fdfcd2019-08-30 11:35:34 +0100328}
329
330static bool isSanitized(const FmqResultDatum& datum) {
331 using Discriminator = FmqResultDatum::hidl_discriminator;
332
333 // check to ensure the padding values in the returned
334 // FmqResultDatum::OperandInformation are initialized to 0
335 if (datum.getDiscriminator() == Discriminator::operandInformation) {
336 static_assert(
337 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
338 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
339 static_assert(
340 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
341 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
342 static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
343 "unexpected value for offset of "
344 "FmqResultDatum::OperandInformation::numberOfDimensions");
345 static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
346 "unexpected value for size of "
347 "FmqResultDatum::OperandInformation::numberOfDimensions");
348 static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
349 "unexpected value for size of "
350 "FmqResultDatum::OperandInformation");
351
352 constexpr size_t paddingOffset =
353 offsetof(FmqResultDatum::OperandInformation, isSufficient) +
354 sizeof(FmqResultDatum::OperandInformation::isSufficient);
355 constexpr size_t paddingSize =
356 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;
357
358 FmqResultDatum::OperandInformation initialized{};
359 std::memset(&initialized, 0, sizeof(initialized));
360
361 const char* initializedPaddingStart =
362 reinterpret_cast<const char*>(&initialized) + paddingOffset;
363 const char* datumPaddingStart =
364 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;
365
366 return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
367 }
368
369 // there are no other padding initialization checks required, so return true
370 // for any sum-type that isn't FmqResultDatum::OperandInformation
371 return true;
372}
373
374static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
375 const Request& request) {
376 // create burst
377 std::unique_ptr<RequestChannelSender> sender;
378 std::unique_ptr<ResultChannelReceiver> receiver;
379 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
380 sp<IBurstContext> context;
381 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
382 ASSERT_NE(nullptr, sender.get());
383 ASSERT_NE(nullptr, receiver.get());
384 ASSERT_NE(nullptr, context.get());
385
386 // load memory into callback slots
387 std::vector<intptr_t> keys;
388 keys.reserve(request.pools.size());
389 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
390 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
391 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
392
393 // send valid request
394 ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));
395
396 // receive valid result
397 auto serialized = receiver->getPacketBlocking();
398 ASSERT_TRUE(serialized.has_value());
399
400 // sanitize result
401 ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
402 << "The result serialized data is not properly sanitized";
403}
404
405///////////////////////////// ENTRY POINT //////////////////////////////////
406
// Entry point for burst validation of |request| on |preparedModel|. It runs
// three checks in order: serialization mutations (validateBurstSerialization),
// an undersized result FMQ (validateBurstFmqLength), and result padding
// sanitization (validateBurstSanitized). Each check is wrapped in
// ASSERT_NO_FATAL_FAILURE, so a fatal failure in one check skips the rest.
void validateBurst(const sp<IPreparedModel>& preparedModel, const Request& request) {
    ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request));
    ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request));
    ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, request));
}
412
Lev Proleev26d1bc82019-08-30 11:57:18 +0100413} // namespace android::hardware::neuralnetworks::V1_3::vts::functional