blob: 0f889e8128096e98b403956b52f38cec8e6a68cb [file] [log] [blame]
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -08001/*
2 * Copyright (C) 2022 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "MotionPredictor"
18
19#include <input/MotionPredictor.h>
20
Philip Quinn8f953ab2022-12-06 15:37:07 -080021#include <cinttypes>
22#include <cmath>
23#include <cstddef>
24#include <cstdint>
25#include <string>
26#include <vector>
27
28#include <android-base/strings.h>
29#include <android/input.h>
30#include <log/log.h>
31
32#include <attestation/HmacKeyManager.h>
33#include <input/TfLiteMotionPredictor.h>
34
35namespace android {
36namespace {
37
38const char DEFAULT_MODEL_PATH[] = "/system/etc/motion_predictor_model.fb";
39const int64_t PREDICTION_INTERVAL_NANOS =
40 12500000 / 3; // TODO(b/266747937): Get this from the model.
41
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080042/**
43 * Log debug messages about predictions.
44 * Enable this via "adb shell setprop log.tag.MotionPredictor DEBUG"
45 */
Philip Quinn8f953ab2022-12-06 15:37:07 -080046bool isDebug() {
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080047 return __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG, ANDROID_LOG_INFO);
48}
49
Philip Quinn8f953ab2022-12-06 15:37:07 -080050// Converts a prediction of some polar (r, phi) to Cartesian (x, y) when applied to an axis.
51TfLiteMotionPredictorSample::Point convertPrediction(
52 const TfLiteMotionPredictorSample::Point& axisFrom,
53 const TfLiteMotionPredictorSample::Point& axisTo, float r, float phi) {
54 const TfLiteMotionPredictorSample::Point axis = axisTo - axisFrom;
55 const float axis_phi = std::atan2(axis.y, axis.x);
56 const float x_delta = r * std::cos(axis_phi + phi);
57 const float y_delta = r * std::sin(axis_phi + phi);
58 return {.x = axisTo.x + x_delta, .y = axisTo.y + y_delta};
59}
60
61} // namespace
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080062
63// --- MotionPredictor ---
64
Philip Quinn8f953ab2022-12-06 15:37:07 -080065MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath,
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080066 std::function<bool()> checkMotionPredictionEnabled)
67 : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
Philip Quinn8f953ab2022-12-06 15:37:07 -080068 mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
69 mModel(TfLiteMotionPredictorModel::create(modelPath == nullptr ? DEFAULT_MODEL_PATH
70 : modelPath)) {}
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080071
72void MotionPredictor::record(const MotionEvent& event) {
Philip Quinn8f953ab2022-12-06 15:37:07 -080073 if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
74 ALOGE("Prediction not supported for device %d's %s source", event.getDeviceId(),
75 inputEventSourceToString(event.getSource()).c_str());
76 return;
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080077 }
Philip Quinn8f953ab2022-12-06 15:37:07 -080078
79 TfLiteMotionPredictorBuffers& buffers =
80 mDeviceBuffers.try_emplace(event.getDeviceId(), mModel->inputLength()).first->second;
81
82 const int32_t action = event.getActionMasked();
83 if (action == AMOTION_EVENT_ACTION_UP) {
84 ALOGD_IF(isDebug(), "End of event stream");
85 buffers.reset();
86 return;
87 } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
88 ALOGD_IF(isDebug(), "Skipping unsupported %s action",
89 MotionEvent::actionToString(action).c_str());
90 return;
91 }
92
93 if (event.getPointerCount() != 1) {
94 ALOGD_IF(isDebug(), "Prediction not supported for multiple pointers");
95 return;
96 }
97
98 const int32_t toolType = event.getPointerProperties(0)->toolType;
99 if (toolType != AMOTION_EVENT_TOOL_TYPE_STYLUS) {
100 ALOGD_IF(isDebug(), "Prediction not supported for non-stylus tool: %s",
101 motionToolTypeToString(toolType));
102 return;
103 }
104
105 for (size_t i = 0; i <= event.getHistorySize(); ++i) {
106 if (event.isResampled(0, i)) {
107 continue;
108 }
109 const PointerCoords* coords = event.getHistoricalRawPointerCoords(0, i);
110 buffers.pushSample(event.getHistoricalEventTime(i),
111 {
112 .position.x = coords->getAxisValue(AMOTION_EVENT_AXIS_X),
113 .position.y = coords->getAxisValue(AMOTION_EVENT_AXIS_Y),
114 .pressure = event.getHistoricalPressure(0, i),
115 .tilt = event.getHistoricalAxisValue(AMOTION_EVENT_AXIS_TILT, 0,
116 i),
117 .orientation = event.getHistoricalOrientation(0, i),
118 });
119 }
120
121 mLastEvents.try_emplace(event.getDeviceId())
122 .first->second.copyFrom(&event, /*keepHistory=*/false);
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800123}
124
Siarhei Vishniakou0839bd62023-01-05 17:20:00 -0800125std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t timestamp) {
Philip Quinn8f953ab2022-12-06 15:37:07 -0800126 std::vector<std::unique_ptr<MotionEvent>> predictions;
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800127
Philip Quinn8f953ab2022-12-06 15:37:07 -0800128 for (const auto& [deviceId, buffer] : mDeviceBuffers) {
129 if (!buffer.isReady()) {
130 continue;
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800131 }
132
Philip Quinn8f953ab2022-12-06 15:37:07 -0800133 buffer.copyTo(*mModel);
134 LOG_ALWAYS_FATAL_IF(!mModel->invoke());
135
136 // Read out the predictions.
137 const std::span<const float> predictedR = mModel->outputR();
138 const std::span<const float> predictedPhi = mModel->outputPhi();
139 const std::span<const float> predictedPressure = mModel->outputPressure();
140
141 TfLiteMotionPredictorSample::Point axisFrom = buffer.axisFrom().position;
142 TfLiteMotionPredictorSample::Point axisTo = buffer.axisTo().position;
143
144 if (isDebug()) {
145 ALOGD("deviceId: %d", deviceId);
146 ALOGD("axisFrom: %f, %f", axisFrom.x, axisFrom.y);
147 ALOGD("axisTo: %f, %f", axisTo.x, axisTo.y);
148 ALOGD("mInputR: %s", base::Join(mModel->inputR(), ", ").c_str());
149 ALOGD("mInputPhi: %s", base::Join(mModel->inputPhi(), ", ").c_str());
150 ALOGD("mInputPressure: %s", base::Join(mModel->inputPressure(), ", ").c_str());
151 ALOGD("mInputTilt: %s", base::Join(mModel->inputTilt(), ", ").c_str());
152 ALOGD("mInputOrientation: %s", base::Join(mModel->inputOrientation(), ", ").c_str());
153 ALOGD("predictedR: %s", base::Join(predictedR, ", ").c_str());
154 ALOGD("predictedPhi: %s", base::Join(predictedPhi, ", ").c_str());
155 ALOGD("predictedPressure: %s", base::Join(predictedPressure, ", ").c_str());
156 }
157
158 const MotionEvent& event = mLastEvents[deviceId];
159 bool hasPredictions = false;
160 std::unique_ptr<MotionEvent> prediction = std::make_unique<MotionEvent>();
161 int64_t predictionTime = buffer.lastTimestamp();
162 const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
163
164 for (int i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) {
165 const TfLiteMotionPredictorSample::Point point =
166 convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
167 // TODO(b/266747654): Stop predictions if confidence is < some threshold.
168
169 ALOGD_IF(isDebug(), "prediction %d: %f, %f", i, point.x, point.y);
170 PointerCoords coords;
171 coords.clear();
172 coords.setAxisValue(AMOTION_EVENT_AXIS_X, point.x);
173 coords.setAxisValue(AMOTION_EVENT_AXIS_Y, point.y);
174 // TODO(b/266747654): Stop predictions if predicted pressure is < some threshold.
175 coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);
176
177 predictionTime += PREDICTION_INTERVAL_NANOS;
178 if (i == 0) {
179 hasPredictions = true;
180 prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
181 event.getDisplayId(), INVALID_HMAC,
182 AMOTION_EVENT_ACTION_MOVE, event.getActionButton(),
183 event.getFlags(), event.getEdgeFlags(), event.getMetaState(),
184 event.getButtonState(), event.getClassification(),
185 event.getTransform(), event.getXPrecision(),
186 event.getYPrecision(), event.getRawXCursorPosition(),
187 event.getRawYCursorPosition(), event.getRawTransform(),
188 event.getDownTime(), predictionTime, event.getPointerCount(),
189 event.getPointerProperties(), &coords);
190 } else {
191 prediction->addSample(predictionTime, &coords);
192 }
193
194 axisFrom = axisTo;
195 axisTo = point;
196 }
197 // TODO(b/266747511): Interpolate to futureTime?
198 if (hasPredictions) {
199 predictions.push_back(std::move(prediction));
200 }
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800201 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800202 return predictions;
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800203}
204
205bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source) {
206 // Global flag override
207 if (!mCheckMotionPredictionEnabled()) {
208 ALOGD_IF(isDebug(), "Prediction not available due to flag override");
209 return false;
210 }
211
212 // Prediction is only supported for stylus sources.
213 if (!isFromSource(source, AINPUT_SOURCE_STYLUS)) {
214 ALOGD_IF(isDebug(), "Prediction not available for non-stylus source: %s",
215 inputEventSourceToString(source).c_str());
216 return false;
217 }
218 return true;
219}
220
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800221} // namespace android