blob: ff0f51c7d900139e12db66f4153088c27ce2f728 [file] [log] [blame]
/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#pragma once
18
19#include <array>
20#include <cstddef>
21#include <cstdint>
22#include <memory>
23#include <optional>
24#include <span>
25#include <string>
26#include <vector>
27
28#include <tensorflow/lite/core/api/error_reporter.h>
29#include <tensorflow/lite/interpreter.h>
30#include <tensorflow/lite/model.h>
31#include <tensorflow/lite/signature_runner.h>
32
33namespace android {
34
// A single motion sample, as passed to TfLiteMotionPredictorBuffers::pushSample().
//
// Every field has a default member initializer so that a default-initialized sample
// (e.g. `TfLiteMotionPredictorSample s;`) is all zeros instead of indeterminate. The type
// remains an aggregate, so positional and designated initialization keep working.
struct TfLiteMotionPredictorSample {
    // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
    struct Point {
        float x = 0.0f;
        float y = 0.0f;
    } position;
    // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION.
    float pressure = 0.0f;
    float tilt = 0.0f;
    float orientation = 0.0f;
};
46
47inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
48 const TfLiteMotionPredictorSample::Point& rhs) {
49 return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
50}
51
52class TfLiteMotionPredictorModel;
53
54// Buffer storage for a TfLiteMotionPredictorModel.
55class TfLiteMotionPredictorBuffers {
56public:
57 // Creates buffer storage for a model with the given input length.
58 TfLiteMotionPredictorBuffers(size_t inputLength);
59
60 // Adds a motion sample to the buffers.
61 void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);
62
63 // Returns true if the buffers are complete enough to generate a prediction.
64 bool isReady() const {
65 // Predictions can't be applied unless there are at least two points to determine
66 // the direction to apply them in.
67 return mAxisFrom && mAxisTo;
68 }
69
70 // Resets all buffers to their initial state.
71 void reset();
72
73 // Copies the buffers to those of a model for prediction.
74 void copyTo(TfLiteMotionPredictorModel& model) const;
75
76 // Returns the current axis of the buffer's samples. Only valid if isReady().
77 TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
78 TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }
79
80 // Returns the timestamp of the last sample.
81 int64_t lastTimestamp() const { return mTimestamp; }
82
83private:
84 int64_t mTimestamp = 0;
85
86 std::vector<float> mInputR;
87 std::vector<float> mInputPhi;
88 std::vector<float> mInputPressure;
89 std::vector<float> mInputTilt;
90 std::vector<float> mInputOrientation;
91
92 // The samples defining the current polar axis.
93 std::optional<TfLiteMotionPredictorSample> mAxisFrom;
94 std::optional<TfLiteMotionPredictorSample> mAxisTo;
95};
96
97// A TFLite model for generating motion predictions.
98class TfLiteMotionPredictorModel {
99public:
100 // Creates a model from an encoded Flatbuffer model.
101 static std::unique_ptr<TfLiteMotionPredictorModel> create(const char* modelPath);
102
103 // Returns the length of the model's input buffers.
104 size_t inputLength() const;
105
106 // Executes the model.
107 // Returns true if the model successfully executed and the output tensors can be read.
108 bool invoke();
109
110 // Returns mutable buffers to the input tensors of inputLength() elements.
111 std::span<float> inputR();
112 std::span<float> inputPhi();
113 std::span<float> inputPressure();
114 std::span<float> inputOrientation();
115 std::span<float> inputTilt();
116
117 // Returns immutable buffers to the output tensors of identical length. Only valid after a
118 // successful call to invoke().
119 std::span<const float> outputR() const;
120 std::span<const float> outputPhi() const;
121 std::span<const float> outputPressure() const;
122
123private:
124 explicit TfLiteMotionPredictorModel(std::string model);
125
126 void allocateTensors();
127 void attachInputTensors();
128 void attachOutputTensors();
129
130 TfLiteTensor* mInputR = nullptr;
131 TfLiteTensor* mInputPhi = nullptr;
132 TfLiteTensor* mInputPressure = nullptr;
133 TfLiteTensor* mInputTilt = nullptr;
134 TfLiteTensor* mInputOrientation = nullptr;
135
136 const TfLiteTensor* mOutputR = nullptr;
137 const TfLiteTensor* mOutputPhi = nullptr;
138 const TfLiteTensor* mOutputPressure = nullptr;
139
140 std::string mFlatBuffer;
141 std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
142 std::unique_ptr<tflite::FlatBufferModel> mModel;
143 std::unique_ptr<tflite::Interpreter> mInterpreter;
144 tflite::SignatureRunner* mRunner = nullptr;
145};
146
147} // namespace android