blob: 5eb6c99b4bbe5424ee59caf9e827dfc7decc0eaf [file] [log] [blame]
Michael Butlerd6e38fd2019-04-26 17:46:08 -07001/*
Michael Butler353a6242019-04-30 13:51:24 -07002 * Copyright (C) 2019 The Android Open Source Project
Michael Butlerd6e38fd2019-04-26 17:46:08 -07003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19#include "VtsHalNeuralnetworks.h"
20
Slava Shklyaev73ee79d2019-05-14 14:15:14 +010021#include "1.2/Callbacks.h"
Michael Butlerd6e38fd2019-04-26 17:46:08 -070022#include "ExecutionBurstController.h"
23#include "ExecutionBurstServer.h"
Xusong Wangbcaa7822019-08-23 16:10:54 -070024#include "GeneratedTestHarness.h"
Michael Butlerd6e38fd2019-04-26 17:46:08 -070025#include "TestHarness.h"
26#include "Utils.h"
27
28#include <android-base/logging.h>
29
Michael Butler62749b92019-08-26 23:55:47 -070030namespace android::hardware::neuralnetworks::V1_2::vts::functional {
Michael Butlerd6e38fd2019-04-26 17:46:08 -070031
Michael Butler62749b92019-08-26 23:55:47 -070032using nn::ExecutionBurstController;
33using nn::RequestChannelSender;
34using nn::ResultChannelReceiver;
35using V1_0::ErrorStatus;
36using V1_0::Request;
37using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback;
Michael Butlerd6e38fd2019-04-26 17:46:08 -070038
// Length, in elements, of an FMQ that can hold the result of a burst execution
// for every one of the generated test cases. createBurst also uses this value
// as the length of the request FMQ.
constexpr size_t kExecutionBurstChannelLength{1024};

// Length, in elements, of an FMQ that is too short to hold the result of a
// burst execution for some of the generated test cases; used to provoke a
// result-channel overflow.
constexpr size_t kExecutionBurstChannelSmallLength{8};
47
48///////////////////////// UTILITY FUNCTIONS /////////////////////////
49
50static bool badTiming(Timing timing) {
51 return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
52}
53
54static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
55 std::unique_ptr<RequestChannelSender>* sender,
56 std::unique_ptr<ResultChannelReceiver>* receiver,
Michael Butler353a6242019-04-30 13:51:24 -070057 sp<IBurstContext>* context,
58 size_t resultChannelLength = kExecutionBurstChannelLength) {
Michael Butlerd6e38fd2019-04-26 17:46:08 -070059 ASSERT_NE(nullptr, preparedModel.get());
60 ASSERT_NE(nullptr, sender);
61 ASSERT_NE(nullptr, receiver);
62 ASSERT_NE(nullptr, context);
63
64 // create FMQ objects
65 auto [fmqRequestChannel, fmqRequestDescriptor] =
66 RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true);
67 auto [fmqResultChannel, fmqResultDescriptor] =
Michael Butler353a6242019-04-30 13:51:24 -070068 ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true);
Michael Butlerd6e38fd2019-04-26 17:46:08 -070069 ASSERT_NE(nullptr, fmqRequestChannel.get());
70 ASSERT_NE(nullptr, fmqResultChannel.get());
71 ASSERT_NE(nullptr, fmqRequestDescriptor);
72 ASSERT_NE(nullptr, fmqResultDescriptor);
73
74 // configure burst
75 ErrorStatus errorStatus;
76 sp<IBurstContext> burstContext;
77 const Return<void> ret = preparedModel->configureExecutionBurst(
78 callback, *fmqRequestDescriptor, *fmqResultDescriptor,
79 [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
80 errorStatus = status;
81 burstContext = context;
82 });
83 ASSERT_TRUE(ret.isOk());
84 ASSERT_EQ(ErrorStatus::NONE, errorStatus);
85 ASSERT_NE(nullptr, burstContext.get());
86
87 // return values
88 *sender = std::move(fmqRequestChannel);
89 *receiver = std::move(fmqResultChannel);
90 *context = burstContext;
91}
92
// Creates an ExecutionBurstController backed by a newly configured burst whose
// result FMQ has `resultChannelLength` elements (the request FMQ length is
// fixed inside createBurst). The controller is returned through `controller`;
// on any fatal assertion failure the function returns early and `controller`
// is left unchanged.
static void createBurstWithResultChannelLength(
        const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
        std::shared_ptr<ExecutionBurstController>* controller) {
    ASSERT_NE(nullptr, preparedModel.get());
    ASSERT_NE(nullptr, controller);

    // configure the burst (createBurst creates both FMQs)
    std::unique_ptr<RequestChannelSender> sender;
    std::unique_ptr<ResultChannelReceiver> receiver;
    // The same callback object is handed to the service (via createBurst) and
    // to the controller constructed below.
    sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
    sp<IBurstContext> context;
    ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
                                        resultChannelLength));
    ASSERT_NE(nullptr, sender.get());
    ASSERT_NE(nullptr, receiver.get());
    ASSERT_NE(nullptr, context.get());

    // return values: the controller takes ownership of both FMQ endpoints
    *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
                                                             context, callback);
}
114
// Primary validation function. Takes a valid serialized request, applies
// `mutation` to turn it into an invalid one, sends it over the burst and checks
// that the service rejects it: a status other than NONE, no output shapes, and
// both timing fields equal to UINT64_MAX (see badTiming).
//
// `serialized` is taken by value, so the mutation only affects this call's copy
// and never leaks back to the caller. A mutated packet with more than
// kExecutionBurstChannelLength elements is skipped, because it would not fit
// the request FMQ that createBurst sizes to that length. `message` is attached
// to any failure via SCOPED_TRACE to identify which mutation was applied.
static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
                     const std::string& message, std::vector<FmqRequestDatum> serialized,
                     const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) {
    mutation(&serialized);

    // skip if packet is too large to send
    if (serialized.size() > kExecutionBurstChannelLength) {
        return;
    }

    SCOPED_TRACE(message);

    // send invalid packet
    ASSERT_TRUE(sender->sendPacket(serialized));

    // receive error (presumably blocks until the service replies -- the FMQs
    // are created in blocking mode)
    auto results = receiver->getBlocking();
    ASSERT_TRUE(results.has_value());
    const auto [status, outputShapes, timing] = std::move(*results);
    EXPECT_NE(ErrorStatus::NONE, status);
    EXPECT_EQ(0u, outputShapes.size());
    EXPECT_TRUE(badTiming(timing));
}
143
Michael Butler353a6242019-04-30 13:51:24 -0700144// For validation, valid packet entries are mutated to invalid packet entries,
145// or invalid packet entries are inserted into valid packets. This function
146// creates pre-set invalid packet entries for convenience.
147static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700148 const FmqRequestDatum::PacketInformation packetInformation = {
149 /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
150 /*.numberOfPools=*/10};
151 const FmqRequestDatum::OperandInformation operandInformation = {
152 /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
153 const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
Michael Butler353a6242019-04-30 13:51:24 -0700154 std::vector<FmqRequestDatum> bad(7);
155 bad[0].packetInformation(packetInformation);
156 bad[1].inputOperandInformation(operandInformation);
157 bad[2].inputOperandDimensionValue(0);
158 bad[3].outputOperandInformation(operandInformation);
159 bad[4].outputOperandDimensionValue(0);
160 bad[5].poolIdentifier(invalidPoolIdentifier);
161 bad[6].measureTiming(MeasureTiming::YES);
162 return bad;
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700163}
164
Michael Butler353a6242019-04-30 13:51:24 -0700165// For validation, valid packet entries are mutated to invalid packet entries,
166// or invalid packet entries are inserted into valid packets. This function
167// retrieves pre-set invalid packet entries for convenience. This function
168// caches these data so they can be reused on subsequent validation checks.
169static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
170 static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
171 return bad;
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700172}
173
174///////////////////////// REMOVE DATUM ////////////////////////////////////
175
176static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
177 const std::vector<FmqRequestDatum>& serialized) {
178 for (size_t index = 0; index < serialized.size(); ++index) {
179 const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
180 validate(sender, receiver, message, serialized,
181 [index](std::vector<FmqRequestDatum>* serialized) {
182 serialized->erase(serialized->begin() + index);
183 });
184 }
185}
186
187///////////////////////// ADD DATUM ////////////////////////////////////
188
189static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
190 const std::vector<FmqRequestDatum>& serialized) {
Michael Butler353a6242019-04-30 13:51:24 -0700191 const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700192 for (size_t index = 0; index <= serialized.size(); ++index) {
193 for (size_t type = 0; type < extra.size(); ++type) {
194 const std::string message = "addDatum: added datum type " + std::to_string(type) +
195 " at index " + std::to_string(index);
196 validate(sender, receiver, message, serialized,
197 [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
198 serialized->insert(serialized->begin() + index, extra[type]);
199 });
200 }
201 }
202}
203
204///////////////////////// MUTATE DATUM ////////////////////////////////////
205
206static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
207 using Discriminator = FmqRequestDatum::hidl_discriminator;
208
209 const bool differentValues = (lhs != rhs);
Michael Butler353a6242019-04-30 13:51:24 -0700210 const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700211 const auto discriminator = rhs.getDiscriminator();
212 const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
213 discriminator == Discriminator::outputOperandDimensionValue);
214
Michael Butler353a6242019-04-30 13:51:24 -0700215 return differentValues && !(sameDiscriminator && isDimensionValue);
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700216}
217
218static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
219 const std::vector<FmqRequestDatum>& serialized) {
Michael Butler353a6242019-04-30 13:51:24 -0700220 const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700221 for (size_t index = 0; index < serialized.size(); ++index) {
222 for (size_t type = 0; type < change.size(); ++type) {
223 if (interestingCase(serialized[index], change[type])) {
224 const std::string message = "mutateDatum: changed datum at index " +
225 std::to_string(index) + " to datum type " +
226 std::to_string(type);
227 validate(sender, receiver, message, serialized,
228 [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
229 (*serialized)[index] = change[type];
230 });
231 }
232 }
233 }
234}
235
236///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
237
// Runs the request-packet mutation suite (remove, add and mutate a single
// FmqRequestDatum) against one freshly configured burst. Every mutated packet
// is checked by validate(), which expects the service to reject it.
static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
                                       const Request& request) {
    // create burst
    std::unique_ptr<RequestChannelSender> sender;
    std::unique_ptr<ResultChannelReceiver> receiver;
    sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
    sp<IBurstContext> context;
    ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
    ASSERT_NE(nullptr, sender.get());
    ASSERT_NE(nullptr, receiver.get());
    ASSERT_NE(nullptr, context.get());

    // load memory into callback slots
    // Each pool is keyed by its own address inside `request`; the callback
    // returns one slot identifier per pool, which is embedded in the packet.
    std::vector<intptr_t> keys;
    keys.reserve(request.pools.size());
    std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
                   [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
    const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);

    // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
    // subsequent slot validation testing: createBadRequestPacketEntries uses
    // that value as its invalid pool identifier)
    ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
        return slot != std::numeric_limits<int32_t>::max();
    }));

    // serialize the request
    const auto serialized = ::android::nn::serialize(request, MeasureTiming::YES, slots);

    // validations
    removeDatumTest(sender.get(), receiver.get(), serialized);
    addDatumTest(sender.get(), receiver.get(), serialized);
    mutateDatumTest(sender.get(), receiver.get(), serialized);
}
271
Michael Butler353a6242019-04-30 13:51:24 -0700272// This test validates that when the Result message size exceeds length of the
273// result FMQ, the service instance gracefully fails and returns an error.
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700274static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
Xusong Wangead950d2019-08-09 16:45:24 -0700275 const Request& request) {
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700276 // create regular burst
277 std::shared_ptr<ExecutionBurstController> controllerRegular;
Michael Butler353a6242019-04-30 13:51:24 -0700278 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
279 preparedModel, kExecutionBurstChannelLength, &controllerRegular));
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700280 ASSERT_NE(nullptr, controllerRegular.get());
281
282 // create burst with small output channel
283 std::shared_ptr<ExecutionBurstController> controllerSmall;
Michael Butler353a6242019-04-30 13:51:24 -0700284 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
285 preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700286 ASSERT_NE(nullptr, controllerSmall.get());
287
Xusong Wangead950d2019-08-09 16:45:24 -0700288 // load memory into callback slots
289 std::vector<intptr_t> keys(request.pools.size());
290 for (size_t i = 0; i < keys.size(); ++i) {
291 keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700292 }
Xusong Wangead950d2019-08-09 16:45:24 -0700293
294 // collect serialized result by running regular burst
295 const auto [statusRegular, outputShapesRegular, timingRegular] =
296 controllerRegular->compute(request, MeasureTiming::NO, keys);
297
298 // skip test if regular burst output isn't useful for testing a failure
299 // caused by having too small of a length for the result FMQ
300 const std::vector<FmqResultDatum> serialized =
301 ::android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
302 if (statusRegular != ErrorStatus::NONE ||
303 serialized.size() <= kExecutionBurstChannelSmallLength) {
304 return;
305 }
306
307 // by this point, execution should fail because the result channel isn't
308 // large enough to return the serialized result
309 const auto [statusSmall, outputShapesSmall, timingSmall] =
310 controllerSmall->compute(request, MeasureTiming::NO, keys);
311 EXPECT_NE(ErrorStatus::NONE, statusSmall);
312 EXPECT_EQ(0u, outputShapesSmall.size());
313 EXPECT_TRUE(badTiming(timingSmall));
Michael Butlerd6e38fd2019-04-26 17:46:08 -0700314}
315
316///////////////////////////// ENTRY POINT //////////////////////////////////
317
// Entry point for burst validation. It first runs the request-serialization
// mutation tests, then the undersized-result-FMQ test; each stage creates its
// own burst(s). A fatal failure in the first stage prevents the second stage
// from running.
void ValidationTest::validateBurst(const sp<IPreparedModel>& preparedModel,
                                   const Request& request) {
    ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request));
    ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request));
}
323
Michael Butler62749b92019-08-26 23:55:47 -0700324} // namespace android::hardware::neuralnetworks::V1_2::vts::functional