blob: 704349c2eb2e81bcd8e13c4404918b65a1b73cc8 [file] [log] [blame]
/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#pragma once
18
19#include <array>
20#include <cstddef>
21#include <cstdint>
22#include <memory>
23#include <optional>
24#include <span>
25#include <string>
Philip Quinn9b8926e2023-01-31 14:50:02 -080026
27#include <input/RingBuffer.h>
Philip Quinn8f953ab2022-12-06 15:37:07 -080028
29#include <tensorflow/lite/core/api/error_reporter.h>
30#include <tensorflow/lite/interpreter.h>
31#include <tensorflow/lite/model.h>
32#include <tensorflow/lite/signature_runner.h>
33
34namespace android {
35
// A single motion sample as consumed by the TFLite motion predictor.
//
// Every member has a zero default initializer, so a default-constructed sample
// (`TfLiteMotionPredictorSample s;`) is fully initialized rather than holding indeterminate
// values. The type is still an aggregate, so brace and designated initialization keep working.
struct TfLiteMotionPredictorSample {
    // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
    struct Point {
        float x = 0.0f;
        float y = 0.0f;
    } position;
    // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION.
    float pressure = 0.0f;
    float tilt = 0.0f;
    float orientation = 0.0f;
};
47
48inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
49 const TfLiteMotionPredictorSample::Point& rhs) {
50 return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
51}
52
53class TfLiteMotionPredictorModel;
54
55// Buffer storage for a TfLiteMotionPredictorModel.
56class TfLiteMotionPredictorBuffers {
57public:
58 // Creates buffer storage for a model with the given input length.
59 TfLiteMotionPredictorBuffers(size_t inputLength);
60
61 // Adds a motion sample to the buffers.
62 void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);
63
64 // Returns true if the buffers are complete enough to generate a prediction.
65 bool isReady() const {
66 // Predictions can't be applied unless there are at least two points to determine
67 // the direction to apply them in.
68 return mAxisFrom && mAxisTo;
69 }
70
71 // Resets all buffers to their initial state.
72 void reset();
73
74 // Copies the buffers to those of a model for prediction.
75 void copyTo(TfLiteMotionPredictorModel& model) const;
76
77 // Returns the current axis of the buffer's samples. Only valid if isReady().
78 TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
79 TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }
80
81 // Returns the timestamp of the last sample.
82 int64_t lastTimestamp() const { return mTimestamp; }
83
84private:
85 int64_t mTimestamp = 0;
86
Philip Quinn9b8926e2023-01-31 14:50:02 -080087 RingBuffer<float> mInputR;
88 RingBuffer<float> mInputPhi;
89 RingBuffer<float> mInputPressure;
90 RingBuffer<float> mInputTilt;
91 RingBuffer<float> mInputOrientation;
Philip Quinn8f953ab2022-12-06 15:37:07 -080092
93 // The samples defining the current polar axis.
94 std::optional<TfLiteMotionPredictorSample> mAxisFrom;
95 std::optional<TfLiteMotionPredictorSample> mAxisTo;
96};
97
98// A TFLite model for generating motion predictions.
99class TfLiteMotionPredictorModel {
100public:
101 // Creates a model from an encoded Flatbuffer model.
102 static std::unique_ptr<TfLiteMotionPredictorModel> create(const char* modelPath);
103
104 // Returns the length of the model's input buffers.
105 size_t inputLength() const;
106
107 // Executes the model.
108 // Returns true if the model successfully executed and the output tensors can be read.
109 bool invoke();
110
111 // Returns mutable buffers to the input tensors of inputLength() elements.
112 std::span<float> inputR();
113 std::span<float> inputPhi();
114 std::span<float> inputPressure();
115 std::span<float> inputOrientation();
116 std::span<float> inputTilt();
117
118 // Returns immutable buffers to the output tensors of identical length. Only valid after a
119 // successful call to invoke().
120 std::span<const float> outputR() const;
121 std::span<const float> outputPhi() const;
122 std::span<const float> outputPressure() const;
123
124private:
125 explicit TfLiteMotionPredictorModel(std::string model);
126
127 void allocateTensors();
128 void attachInputTensors();
129 void attachOutputTensors();
130
131 TfLiteTensor* mInputR = nullptr;
132 TfLiteTensor* mInputPhi = nullptr;
133 TfLiteTensor* mInputPressure = nullptr;
134 TfLiteTensor* mInputTilt = nullptr;
135 TfLiteTensor* mInputOrientation = nullptr;
136
137 const TfLiteTensor* mOutputR = nullptr;
138 const TfLiteTensor* mOutputPhi = nullptr;
139 const TfLiteTensor* mOutputPressure = nullptr;
140
141 std::string mFlatBuffer;
142 std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
143 std::unique_ptr<tflite::FlatBufferModel> mModel;
144 std::unique_ptr<tflite::Interpreter> mInterpreter;
145 tflite::SignatureRunner* mRunner = nullptr;
146};
147
148} // namespace android