blob: 667189b2a041b553b9589d628f514d18c209984e [file] [log] [blame]
Michael Butlera685c3d2020-02-22 22:37:59 -08001/*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "CommonUtils.h"

#include <android-base/logging.h>
#include <nnapi/Result.h>
#include <nnapi/SharedMemory.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>

#include <algorithm>
#include <any>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <optional>
#include <utility>
#include <variant>
#include <vector>
31
32namespace android::hardware::neuralnetworks::utils {
33namespace {
34
35bool hasNoPointerData(const nn::Operand& operand);
36bool hasNoPointerData(const nn::Model::Subgraph& subgraph);
37bool hasNoPointerData(const nn::Request::Argument& argument);
38
39template <typename Type>
40bool hasNoPointerData(const std::vector<Type>& objects) {
41 return std::all_of(objects.begin(), objects.end(),
42 [](const auto& object) { return hasNoPointerData(object); });
43}
44
45bool hasNoPointerData(const nn::DataLocation& location) {
46 return std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer);
47}
48
49bool hasNoPointerData(const nn::Operand& operand) {
50 return hasNoPointerData(operand.location);
51}
52
53bool hasNoPointerData(const nn::Model::Subgraph& subgraph) {
54 return hasNoPointerData(subgraph.operands);
55}
56
57bool hasNoPointerData(const nn::Request::Argument& argument) {
58 return hasNoPointerData(argument.location);
59}
60
61void copyPointersToSharedMemory(nn::Operand* operand, nn::ConstantMemoryBuilder* memoryBuilder) {
62 CHECK(operand != nullptr);
63 CHECK(memoryBuilder != nullptr);
64
65 if (operand->lifetime != nn::Operand::LifeTime::POINTER) {
66 return;
67 }
68
69 const void* data = std::visit([](auto ptr) { return static_cast<const void*>(ptr); },
70 operand->location.pointer);
71 CHECK(data != nullptr);
72 operand->lifetime = nn::Operand::LifeTime::CONSTANT_REFERENCE;
73 operand->location = memoryBuilder->append(data, operand->location.length);
74}
75
76void copyPointersToSharedMemory(nn::Model::Subgraph* subgraph,
77 nn::ConstantMemoryBuilder* memoryBuilder) {
78 CHECK(subgraph != nullptr);
79 std::for_each(subgraph->operands.begin(), subgraph->operands.end(),
80 [memoryBuilder](auto& operand) {
81 copyPointersToSharedMemory(&operand, memoryBuilder);
82 });
83}
84
85} // anonymous namespace
86
87nn::Capabilities::OperandPerformanceTable makeQuantized8PerformanceConsistentWithP(
88 const nn::Capabilities::PerformanceInfo& float32Performance,
89 const nn::Capabilities::PerformanceInfo& quantized8Performance) {
90 // In Android P, most data types are treated as having the same performance as
91 // TENSOR_QUANT8_ASYMM. This collection must be in sorted order.
92 std::vector<nn::Capabilities::OperandPerformance> operandPerformances = {
93 {.type = nn::OperandType::FLOAT32, .info = float32Performance},
94 {.type = nn::OperandType::INT32, .info = quantized8Performance},
95 {.type = nn::OperandType::UINT32, .info = quantized8Performance},
96 {.type = nn::OperandType::TENSOR_FLOAT32, .info = float32Performance},
97 {.type = nn::OperandType::TENSOR_INT32, .info = quantized8Performance},
98 {.type = nn::OperandType::TENSOR_QUANT8_ASYMM, .info = quantized8Performance},
99 {.type = nn::OperandType::OEM, .info = quantized8Performance},
100 {.type = nn::OperandType::TENSOR_OEM_BYTE, .info = quantized8Performance},
101 };
102 return nn::Capabilities::OperandPerformanceTable::create(std::move(operandPerformances))
103 .value();
104}
105
106bool hasNoPointerData(const nn::Model& model) {
107 return hasNoPointerData(model.main) && hasNoPointerData(model.referenced);
108}
109
110bool hasNoPointerData(const nn::Request& request) {
111 return hasNoPointerData(request.inputs) && hasNoPointerData(request.outputs);
112}
113
114nn::Result<nn::Model> flushDataFromPointerToShared(const nn::Model& model) {
115 auto modelInShared = model;
116
117 nn::ConstantMemoryBuilder memoryBuilder(modelInShared.pools.size());
118 copyPointersToSharedMemory(&modelInShared.main, &memoryBuilder);
119 std::for_each(modelInShared.referenced.begin(), modelInShared.referenced.end(),
120 [&memoryBuilder](auto& subgraph) {
121 copyPointersToSharedMemory(&subgraph, &memoryBuilder);
122 });
123
124 if (!memoryBuilder.empty()) {
125 auto memory = NN_TRY(memoryBuilder.finish());
126 modelInShared.pools.push_back(std::move(memory));
127 }
128
129 return modelInShared;
130}
131
132nn::Result<nn::Request> flushDataFromPointerToShared(const nn::Request& request) {
133 auto requestInShared = request;
134
135 // Change input pointers to shared memory.
136 nn::ConstantMemoryBuilder inputBuilder(requestInShared.pools.size());
137 for (auto& input : requestInShared.inputs) {
138 const auto& location = input.location;
139 if (input.lifetime != nn::Request::Argument::LifeTime::POINTER) {
140 continue;
141 }
142
143 input.lifetime = nn::Request::Argument::LifeTime::POOL;
144 const void* data = std::visit([](auto ptr) { return static_cast<const void*>(ptr); },
145 location.pointer);
146 CHECK(data != nullptr);
147 input.location = inputBuilder.append(data, location.length);
148 }
149
150 // Allocate input memory.
151 if (!inputBuilder.empty()) {
152 auto memory = NN_TRY(inputBuilder.finish());
153 requestInShared.pools.push_back(std::move(memory));
154 }
155
156 // Change output pointers to shared memory.
157 nn::MutableMemoryBuilder outputBuilder(requestInShared.pools.size());
158 for (auto& output : requestInShared.outputs) {
159 const auto& location = output.location;
160 if (output.lifetime != nn::Request::Argument::LifeTime::POINTER) {
161 continue;
162 }
163
164 output.lifetime = nn::Request::Argument::LifeTime::POOL;
165 output.location = outputBuilder.append(location.length);
166 }
167
168 // Allocate output memory.
169 if (!outputBuilder.empty()) {
170 auto memory = NN_TRY(outputBuilder.finish());
171 requestInShared.pools.push_back(std::move(memory));
172 }
173
174 return requestInShared;
175}
176
177nn::Result<void> unflushDataFromSharedToPointer(const nn::Request& request,
178 const nn::Request& requestInShared) {
179 if (requestInShared.pools.empty() ||
180 !std::holds_alternative<nn::Memory>(requestInShared.pools.back())) {
181 return {};
182 }
183
184 // Map the memory.
185 const auto& outputMemory = std::get<nn::Memory>(requestInShared.pools.back());
186 const auto [pointer, size, context] = NN_TRY(map(outputMemory));
187 const uint8_t* constantPointer =
188 std::visit([](const auto& o) { return static_cast<const uint8_t*>(o); }, pointer);
189
190 // Flush each output pointer.
191 CHECK_EQ(request.outputs.size(), requestInShared.outputs.size());
192 for (size_t i = 0; i < request.outputs.size(); ++i) {
193 const auto& location = request.outputs[i].location;
194 const auto& locationInShared = requestInShared.outputs[i].location;
195 if (!std::holds_alternative<void*>(location.pointer)) {
196 continue;
197 }
198
199 // Get output pointer and size.
200 void* data = std::get<void*>(location.pointer);
201 CHECK(data != nullptr);
202 const size_t length = location.length;
203
204 // Get output pool location.
205 CHECK(requestInShared.outputs[i].lifetime == nn::Request::Argument::LifeTime::POOL);
206 const size_t index = locationInShared.poolIndex;
207 const size_t offset = locationInShared.offset;
208 const size_t outputPoolIndex = requestInShared.pools.size() - 1;
209 CHECK(locationInShared.length == length);
210 CHECK(index == outputPoolIndex);
211
212 // Flush memory.
213 std::memcpy(data, constantPointer + offset, length);
214 }
215
216 return {};
217}
218
219std::vector<uint32_t> countNumberOfConsumers(size_t numberOfOperands,
220 const std::vector<nn::Operation>& operations) {
221 return nn::countNumberOfConsumers(numberOfOperands, operations);
222}
223
224} // namespace android::hardware::neuralnetworks::utils