blob: 7d11ef25750cd53c35ae7533cca4e36771a1e9a0 [file] [log] [blame]
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -08001/*
2 * Copyright (C) 2022 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "MotionPredictor"
18
19#include <input/MotionPredictor.h>
20
Philip Quinn8f953ab2022-12-06 15:37:07 -080021#include <cinttypes>
22#include <cmath>
23#include <cstddef>
24#include <cstdint>
25#include <string>
26#include <vector>
27
28#include <android-base/strings.h>
29#include <android/input.h>
30#include <log/log.h>
31
32#include <attestation/HmacKeyManager.h>
33#include <input/TfLiteMotionPredictor.h>
34
35namespace android {
36namespace {
37
// Location of the prediction model used when the caller does not supply a model path.
constexpr char DEFAULT_MODEL_PATH[] = "/system/etc/motion_predictor_model.fb";

// Time step, in nanoseconds, added between consecutive predicted samples.
// TODO(b/266747937): Get this from the model.
constexpr int64_t PREDICTION_INTERVAL_NANOS = 12500000 / 3;
41
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080042/**
43 * Log debug messages about predictions.
44 * Enable this via "adb shell setprop log.tag.MotionPredictor DEBUG"
45 */
Philip Quinn8f953ab2022-12-06 15:37:07 -080046bool isDebug() {
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080047 return __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG, ANDROID_LOG_INFO);
48}
49
Philip Quinn8f953ab2022-12-06 15:37:07 -080050// Converts a prediction of some polar (r, phi) to Cartesian (x, y) when applied to an axis.
51TfLiteMotionPredictorSample::Point convertPrediction(
52 const TfLiteMotionPredictorSample::Point& axisFrom,
53 const TfLiteMotionPredictorSample::Point& axisTo, float r, float phi) {
54 const TfLiteMotionPredictorSample::Point axis = axisTo - axisFrom;
55 const float axis_phi = std::atan2(axis.y, axis.x);
56 const float x_delta = r * std::cos(axis_phi + phi);
57 const float y_delta = r * std::sin(axis_phi + phi);
58 return {.x = axisTo.x + x_delta, .y = axisTo.y + y_delta};
59}
60
61} // namespace
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080062
63// --- MotionPredictor ---
64
Philip Quinn8f953ab2022-12-06 15:37:07 -080065MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath,
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080066 std::function<bool()> checkMotionPredictionEnabled)
67 : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
Philip Quinnbd66e622023-02-10 11:45:01 -080068 mModelPath(modelPath == nullptr ? DEFAULT_MODEL_PATH : modelPath),
69 mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {}
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080070
71void MotionPredictor::record(const MotionEvent& event) {
Philip Quinn8f953ab2022-12-06 15:37:07 -080072 if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
73 ALOGE("Prediction not supported for device %d's %s source", event.getDeviceId(),
74 inputEventSourceToString(event.getSource()).c_str());
75 return;
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080076 }
Philip Quinn8f953ab2022-12-06 15:37:07 -080077
Philip Quinnbd66e622023-02-10 11:45:01 -080078 // Initialise the model now that it's likely to be used.
79 if (!mModel) {
80 mModel = TfLiteMotionPredictorModel::create(mModelPath.c_str());
81 }
82
Philip Quinn8f953ab2022-12-06 15:37:07 -080083 TfLiteMotionPredictorBuffers& buffers =
84 mDeviceBuffers.try_emplace(event.getDeviceId(), mModel->inputLength()).first->second;
85
86 const int32_t action = event.getActionMasked();
87 if (action == AMOTION_EVENT_ACTION_UP) {
88 ALOGD_IF(isDebug(), "End of event stream");
89 buffers.reset();
90 return;
91 } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
92 ALOGD_IF(isDebug(), "Skipping unsupported %s action",
93 MotionEvent::actionToString(action).c_str());
94 return;
95 }
96
97 if (event.getPointerCount() != 1) {
98 ALOGD_IF(isDebug(), "Prediction not supported for multiple pointers");
99 return;
100 }
101
102 const int32_t toolType = event.getPointerProperties(0)->toolType;
103 if (toolType != AMOTION_EVENT_TOOL_TYPE_STYLUS) {
104 ALOGD_IF(isDebug(), "Prediction not supported for non-stylus tool: %s",
105 motionToolTypeToString(toolType));
106 return;
107 }
108
109 for (size_t i = 0; i <= event.getHistorySize(); ++i) {
110 if (event.isResampled(0, i)) {
111 continue;
112 }
113 const PointerCoords* coords = event.getHistoricalRawPointerCoords(0, i);
114 buffers.pushSample(event.getHistoricalEventTime(i),
115 {
116 .position.x = coords->getAxisValue(AMOTION_EVENT_AXIS_X),
117 .position.y = coords->getAxisValue(AMOTION_EVENT_AXIS_Y),
118 .pressure = event.getHistoricalPressure(0, i),
119 .tilt = event.getHistoricalAxisValue(AMOTION_EVENT_AXIS_TILT, 0,
120 i),
121 .orientation = event.getHistoricalOrientation(0, i),
122 });
123 }
124
125 mLastEvents.try_emplace(event.getDeviceId())
126 .first->second.copyFrom(&event, /*keepHistory=*/false);
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800127}
128
Siarhei Vishniakou0839bd62023-01-05 17:20:00 -0800129std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t timestamp) {
Philip Quinn8f953ab2022-12-06 15:37:07 -0800130 std::vector<std::unique_ptr<MotionEvent>> predictions;
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800131
Philip Quinn8f953ab2022-12-06 15:37:07 -0800132 for (const auto& [deviceId, buffer] : mDeviceBuffers) {
133 if (!buffer.isReady()) {
134 continue;
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800135 }
136
Philip Quinnbd66e622023-02-10 11:45:01 -0800137 LOG_ALWAYS_FATAL_IF(!mModel);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800138 buffer.copyTo(*mModel);
139 LOG_ALWAYS_FATAL_IF(!mModel->invoke());
140
141 // Read out the predictions.
142 const std::span<const float> predictedR = mModel->outputR();
143 const std::span<const float> predictedPhi = mModel->outputPhi();
144 const std::span<const float> predictedPressure = mModel->outputPressure();
145
146 TfLiteMotionPredictorSample::Point axisFrom = buffer.axisFrom().position;
147 TfLiteMotionPredictorSample::Point axisTo = buffer.axisTo().position;
148
149 if (isDebug()) {
150 ALOGD("deviceId: %d", deviceId);
151 ALOGD("axisFrom: %f, %f", axisFrom.x, axisFrom.y);
152 ALOGD("axisTo: %f, %f", axisTo.x, axisTo.y);
153 ALOGD("mInputR: %s", base::Join(mModel->inputR(), ", ").c_str());
154 ALOGD("mInputPhi: %s", base::Join(mModel->inputPhi(), ", ").c_str());
155 ALOGD("mInputPressure: %s", base::Join(mModel->inputPressure(), ", ").c_str());
156 ALOGD("mInputTilt: %s", base::Join(mModel->inputTilt(), ", ").c_str());
157 ALOGD("mInputOrientation: %s", base::Join(mModel->inputOrientation(), ", ").c_str());
158 ALOGD("predictedR: %s", base::Join(predictedR, ", ").c_str());
159 ALOGD("predictedPhi: %s", base::Join(predictedPhi, ", ").c_str());
160 ALOGD("predictedPressure: %s", base::Join(predictedPressure, ", ").c_str());
161 }
162
163 const MotionEvent& event = mLastEvents[deviceId];
164 bool hasPredictions = false;
165 std::unique_ptr<MotionEvent> prediction = std::make_unique<MotionEvent>();
166 int64_t predictionTime = buffer.lastTimestamp();
167 const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
168
169 for (int i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) {
170 const TfLiteMotionPredictorSample::Point point =
171 convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
172 // TODO(b/266747654): Stop predictions if confidence is < some threshold.
173
174 ALOGD_IF(isDebug(), "prediction %d: %f, %f", i, point.x, point.y);
175 PointerCoords coords;
176 coords.clear();
177 coords.setAxisValue(AMOTION_EVENT_AXIS_X, point.x);
178 coords.setAxisValue(AMOTION_EVENT_AXIS_Y, point.y);
179 // TODO(b/266747654): Stop predictions if predicted pressure is < some threshold.
180 coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);
181
182 predictionTime += PREDICTION_INTERVAL_NANOS;
183 if (i == 0) {
184 hasPredictions = true;
185 prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
186 event.getDisplayId(), INVALID_HMAC,
187 AMOTION_EVENT_ACTION_MOVE, event.getActionButton(),
188 event.getFlags(), event.getEdgeFlags(), event.getMetaState(),
189 event.getButtonState(), event.getClassification(),
190 event.getTransform(), event.getXPrecision(),
191 event.getYPrecision(), event.getRawXCursorPosition(),
192 event.getRawYCursorPosition(), event.getRawTransform(),
193 event.getDownTime(), predictionTime, event.getPointerCount(),
194 event.getPointerProperties(), &coords);
195 } else {
196 prediction->addSample(predictionTime, &coords);
197 }
198
199 axisFrom = axisTo;
200 axisTo = point;
201 }
202 // TODO(b/266747511): Interpolate to futureTime?
203 if (hasPredictions) {
204 predictions.push_back(std::move(prediction));
205 }
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800206 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800207 return predictions;
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800208}
209
210bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source) {
211 // Global flag override
212 if (!mCheckMotionPredictionEnabled()) {
213 ALOGD_IF(isDebug(), "Prediction not available due to flag override");
214 return false;
215 }
216
217 // Prediction is only supported for stylus sources.
218 if (!isFromSource(source, AINPUT_SOURCE_STYLUS)) {
219 ALOGD_IF(isDebug(), "Prediction not available for non-stylus source: %s",
220 inputEventSourceToString(source).c_str());
221 return false;
222 }
223 return true;
224}
225
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800226} // namespace android