blob: 77292d4798dc2a7232a586f5d29b788232bb8ad9 [file] [log] [blame]
/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#define LOG_TAG "MotionPredictor"
18
19#include <input/MotionPredictor.h>
20
Derek Wu705068d2024-03-20 10:41:37 -070021#include <array>
Philip Quinn8f953ab2022-12-06 15:37:07 -080022#include <cinttypes>
23#include <cmath>
24#include <cstddef>
25#include <cstdint>
Derek Wuea36ee72024-03-25 13:17:51 -070026#include <limits>
Derek Wu705068d2024-03-20 10:41:37 -070027#include <optional>
Philip Quinn8f953ab2022-12-06 15:37:07 -080028#include <string>
Derek Wu705068d2024-03-20 10:41:37 -070029#include <utility>
Philip Quinn8f953ab2022-12-06 15:37:07 -080030#include <vector>
31
Yeabkal Wubshit64f090f2023-03-03 17:35:11 -080032#include <android-base/logging.h>
Philip Quinn8f953ab2022-12-06 15:37:07 -080033#include <android-base/strings.h>
34#include <android/input.h>
Derek Wuea36ee72024-03-25 13:17:51 -070035#include <com_android_input_flags.h>
Philip Quinn8f953ab2022-12-06 15:37:07 -080036
37#include <attestation/HmacKeyManager.h>
Siarhei Vishniakou09a8fe42022-07-21 17:27:03 -070038#include <ftl/enum.h>
Philip Quinn8f953ab2022-12-06 15:37:07 -080039#include <input/TfLiteMotionPredictor.h>
40
Derek Wuea36ee72024-03-25 13:17:51 -070041namespace input_flags = com::android::input::flags;
42
Philip Quinn8f953ab2022-12-06 15:37:07 -080043namespace android {
44namespace {
45
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080046/**
47 * Log debug messages about predictions.
48 * Enable this via "adb shell setprop log.tag.MotionPredictor DEBUG"
49 */
Philip Quinn8f953ab2022-12-06 15:37:07 -080050bool isDebug() {
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080051 return __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG, ANDROID_LOG_INFO);
52}
53
Philip Quinn8f953ab2022-12-06 15:37:07 -080054// Converts a prediction of some polar (r, phi) to Cartesian (x, y) when applied to an axis.
55TfLiteMotionPredictorSample::Point convertPrediction(
56 const TfLiteMotionPredictorSample::Point& axisFrom,
57 const TfLiteMotionPredictorSample::Point& axisTo, float r, float phi) {
58 const TfLiteMotionPredictorSample::Point axis = axisTo - axisFrom;
59 const float axis_phi = std::atan2(axis.y, axis.x);
60 const float x_delta = r * std::cos(axis_phi + phi);
61 const float y_delta = r * std::sin(axis_phi + phi);
62 return {.x = axisTo.x + x_delta, .y = axisTo.y + y_delta};
63}
64
65} // namespace
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -080066
// --- JerkTracker ---
68
69JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {}
70
71void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
72 mTimestamps.pushBack(timestamp);
73 const int numSamples = mTimestamps.size();
74
75 std::array<float, 4> newXDerivatives;
76 std::array<float, 4> newYDerivatives;
77
78 /**
79 * Diagram showing the calculation of higher order derivatives of sample x3
80 * collected at time=t3.
81 * Terms in parentheses are not stored (and not needed for calculations)
82 * t0 ----- t1 ----- t2 ----- t3
83 * (x0)-----(x1) ----- x2 ----- x3
84 * (x'0) --- x'1 --- x'2
85 * x''0 - x''1
86 * x'''0
87 *
88 * In this example:
89 * x'2 = (x3 - x2) / (t3 - t2)
90 * x''1 = (x'2 - x'1) / (t2 - t1)
91 * x'''0 = (x''1 - x''0) / (t1 - t0)
92 * Therefore, timestamp history is needed to calculate higher order derivatives,
93 * compared to just the last calculated derivative sample.
94 *
95 * If mNormalizedDt = true, then dt = 1 and the division is moot.
96 */
97 for (int i = 0; i < numSamples; ++i) {
98 if (i == 0) {
99 newXDerivatives[i] = xPos;
100 newYDerivatives[i] = yPos;
101 } else {
102 newXDerivatives[i] = newXDerivatives[i - 1] - mXDerivatives[i - 1];
103 newYDerivatives[i] = newYDerivatives[i - 1] - mYDerivatives[i - 1];
104 if (!mNormalizedDt) {
105 const float dt = mTimestamps[numSamples - i] - mTimestamps[numSamples - i - 1];
106 newXDerivatives[i] = newXDerivatives[i] / dt;
107 newYDerivatives[i] = newYDerivatives[i] / dt;
108 }
109 }
110 }
111
112 std::swap(newXDerivatives, mXDerivatives);
113 std::swap(newYDerivatives, mYDerivatives);
114}
115
116void JerkTracker::reset() {
117 mTimestamps.clear();
118}
119
120std::optional<float> JerkTracker::jerkMagnitude() const {
121 if (mTimestamps.size() == mTimestamps.capacity()) {
122 return std::hypot(mXDerivatives[3], mYDerivatives[3]);
123 }
124 return std::nullopt;
125}
126
// --- MotionPredictor ---
128
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -0800129MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
Cody Heiner7b26dbe2023-11-14 14:47:10 -0800130 std::function<bool()> checkMotionPredictionEnabled,
131 ReportAtomFunction reportAtomFunction)
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800132 : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
Cody Heiner7b26dbe2023-11-14 14:47:10 -0800133 mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
134 mReportAtomFunction(reportAtomFunction) {}
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800135
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800136android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
137 if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) {
138 // We still have an active gesture for another device. The provided MotionEvent is not
Cody Heiner088c63e2023-06-15 12:06:09 -0700139 // consistent with the previous gesture.
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800140 LOG(ERROR) << "Inconsistent event stream: last event is " << *mLastEvent << ", but "
141 << __func__ << " is called with " << event;
142 return android::base::Error()
143 << "Inconsistent event stream: still have an active gesture from device "
144 << mLastEvent->getDeviceId() << ", but received " << event;
145 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800146 if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
147 ALOGE("Prediction not supported for device %d's %s source", event.getDeviceId(),
148 inputEventSourceToString(event.getSource()).c_str());
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800149 return {};
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800150 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800151
Philip Quinnbd66e622023-02-10 11:45:01 -0800152 // Initialise the model now that it's likely to be used.
153 if (!mModel) {
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -0800154 mModel = TfLiteMotionPredictorModel::create();
Cody Heiner088c63e2023-06-15 12:06:09 -0700155 LOG_ALWAYS_FATAL_IF(!mModel);
Philip Quinnbd66e622023-02-10 11:45:01 -0800156 }
157
Cody Heiner088c63e2023-06-15 12:06:09 -0700158 if (!mBuffers) {
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800159 mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());
160 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800161
Cody Heiner7b26dbe2023-11-14 14:47:10 -0800162 // Pass input event to the MetricsManager.
163 if (!mMetricsManager) {
164 mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength(),
165 mReportAtomFunction);
166 }
167 mMetricsManager->onRecord(event);
168
Philip Quinn8f953ab2022-12-06 15:37:07 -0800169 const int32_t action = event.getActionMasked();
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800170 if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) {
Philip Quinn8f953ab2022-12-06 15:37:07 -0800171 ALOGD_IF(isDebug(), "End of event stream");
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800172 mBuffers->reset();
Derek Wu705068d2024-03-20 10:41:37 -0700173 mJerkTracker.reset();
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800174 mLastEvent.reset();
175 return {};
Philip Quinn8f953ab2022-12-06 15:37:07 -0800176 } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
177 ALOGD_IF(isDebug(), "Skipping unsupported %s action",
178 MotionEvent::actionToString(action).c_str());
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800179 return {};
Philip Quinn8f953ab2022-12-06 15:37:07 -0800180 }
181
182 if (event.getPointerCount() != 1) {
183 ALOGD_IF(isDebug(), "Prediction not supported for multiple pointers");
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800184 return {};
Philip Quinn8f953ab2022-12-06 15:37:07 -0800185 }
186
Siarhei Vishniakou09a8fe42022-07-21 17:27:03 -0700187 const ToolType toolType = event.getPointerProperties(0)->toolType;
188 if (toolType != ToolType::STYLUS) {
Philip Quinn8f953ab2022-12-06 15:37:07 -0800189 ALOGD_IF(isDebug(), "Prediction not supported for non-stylus tool: %s",
Siarhei Vishniakou09a8fe42022-07-21 17:27:03 -0700190 ftl::enum_string(toolType).c_str());
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800191 return {};
Philip Quinn8f953ab2022-12-06 15:37:07 -0800192 }
193
194 for (size_t i = 0; i <= event.getHistorySize(); ++i) {
195 if (event.isResampled(0, i)) {
196 continue;
197 }
198 const PointerCoords* coords = event.getHistoricalRawPointerCoords(0, i);
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800199 mBuffers->pushSample(event.getHistoricalEventTime(i),
200 {
201 .position.x = coords->getAxisValue(AMOTION_EVENT_AXIS_X),
202 .position.y = coords->getAxisValue(AMOTION_EVENT_AXIS_Y),
203 .pressure = event.getHistoricalPressure(0, i),
204 .tilt = event.getHistoricalAxisValue(AMOTION_EVENT_AXIS_TILT,
205 0, i),
206 .orientation = event.getHistoricalOrientation(0, i),
207 });
Derek Wu705068d2024-03-20 10:41:37 -0700208 mJerkTracker.pushSample(event.getHistoricalEventTime(i),
209 coords->getAxisValue(AMOTION_EVENT_AXIS_X),
210 coords->getAxisValue(AMOTION_EVENT_AXIS_Y));
Philip Quinn8f953ab2022-12-06 15:37:07 -0800211 }
212
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800213 if (!mLastEvent) {
214 mLastEvent = MotionEvent();
215 }
216 mLastEvent->copyFrom(&event, /*keepHistory=*/false);
Cody Heiner088c63e2023-06-15 12:06:09 -0700217
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800218 return {};
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800219}
220
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800221std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
222 if (mBuffers == nullptr || !mBuffers->isReady()) {
223 return nullptr;
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800224 }
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800225
226 LOG_ALWAYS_FATAL_IF(!mModel);
227 mBuffers->copyTo(*mModel);
228 LOG_ALWAYS_FATAL_IF(!mModel->invoke());
229
230 // Read out the predictions.
231 const std::span<const float> predictedR = mModel->outputR();
232 const std::span<const float> predictedPhi = mModel->outputPhi();
233 const std::span<const float> predictedPressure = mModel->outputPressure();
234
235 TfLiteMotionPredictorSample::Point axisFrom = mBuffers->axisFrom().position;
236 TfLiteMotionPredictorSample::Point axisTo = mBuffers->axisTo().position;
237
238 if (isDebug()) {
239 ALOGD("axisFrom: %f, %f", axisFrom.x, axisFrom.y);
240 ALOGD("axisTo: %f, %f", axisTo.x, axisTo.y);
241 ALOGD("mInputR: %s", base::Join(mModel->inputR(), ", ").c_str());
242 ALOGD("mInputPhi: %s", base::Join(mModel->inputPhi(), ", ").c_str());
243 ALOGD("mInputPressure: %s", base::Join(mModel->inputPressure(), ", ").c_str());
244 ALOGD("mInputTilt: %s", base::Join(mModel->inputTilt(), ", ").c_str());
245 ALOGD("mInputOrientation: %s", base::Join(mModel->inputOrientation(), ", ").c_str());
246 ALOGD("predictedR: %s", base::Join(predictedR, ", ").c_str());
247 ALOGD("predictedPhi: %s", base::Join(predictedPhi, ", ").c_str());
248 ALOGD("predictedPressure: %s", base::Join(predictedPressure, ", ").c_str());
249 }
250
251 LOG_ALWAYS_FATAL_IF(!mLastEvent);
252 const MotionEvent& event = *mLastEvent;
253 bool hasPredictions = false;
254 std::unique_ptr<MotionEvent> prediction = std::make_unique<MotionEvent>();
255 int64_t predictionTime = mBuffers->lastTimestamp();
256 const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
257
Ryan Prichard5a8af502023-08-31 00:00:47 -0700258 for (size_t i = 0; i < static_cast<size_t>(predictedR.size()) && predictionTime <= futureTime;
259 ++i) {
Philip Quinn107ce702023-07-14 13:07:13 -0700260 if (predictedR[i] < mModel->config().distanceNoiseFloor) {
261 // Stop predicting when the predicted output is below the model's noise floor.
262 //
263 // We assume that all subsequent predictions in the batch are unreliable because later
264 // predictions are conditional on earlier predictions, and a state of noise is not a
265 // good basis for prediction.
266 //
267 // The UX trade-off is that this potentially sacrifices some predictions when the input
268 // device starts to speed up, but avoids producing noisy predictions as it slows down.
269 break;
270 }
Derek Wuea36ee72024-03-25 13:17:51 -0700271 if (input_flags::enable_prediction_pruning_via_jerk_thresholding()) {
272 // TODO(b/266747654): Stop predictions if confidence is < some threshold
273 // Arbitrarily high pruning index, will correct once jerk thresholding is implemented.
274 const size_t upperBoundPredictionIndex = std::numeric_limits<size_t>::max();
275 if (i > upperBoundPredictionIndex) {
276 break;
277 }
278 }
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800279
Cody Heiner088c63e2023-06-15 12:06:09 -0700280 const TfLiteMotionPredictorSample::Point predictedPoint =
281 convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
282
Ryan Prichard5a8af502023-08-31 00:00:47 -0700283 ALOGD_IF(isDebug(), "prediction %zu: %f, %f", i, predictedPoint.x, predictedPoint.y);
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800284 PointerCoords coords;
285 coords.clear();
Cody Heiner088c63e2023-06-15 12:06:09 -0700286 coords.setAxisValue(AMOTION_EVENT_AXIS_X, predictedPoint.x);
287 coords.setAxisValue(AMOTION_EVENT_AXIS_Y, predictedPoint.y);
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800288 coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);
Philip Quinn59fa9122023-09-18 13:35:54 -0700289 // Copy forward tilt and orientation from the last event until they are predicted
290 // (b/291789258).
291 coords.setAxisValue(AMOTION_EVENT_AXIS_TILT,
292 event.getAxisValue(AMOTION_EVENT_AXIS_TILT, 0));
293 coords.setAxisValue(AMOTION_EVENT_AXIS_ORIENTATION,
294 event.getRawPointerCoords(0)->getAxisValue(
295 AMOTION_EVENT_AXIS_ORIENTATION));
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800296
Philip Quinn107ce702023-07-14 13:07:13 -0700297 predictionTime += mModel->config().predictionInterval;
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800298 if (i == 0) {
299 hasPredictions = true;
300 prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
301 event.getDisplayId(), INVALID_HMAC, AMOTION_EVENT_ACTION_MOVE,
302 event.getActionButton(), event.getFlags(), event.getEdgeFlags(),
303 event.getMetaState(), event.getButtonState(),
304 event.getClassification(), event.getTransform(),
305 event.getXPrecision(), event.getYPrecision(),
306 event.getRawXCursorPosition(), event.getRawYCursorPosition(),
307 event.getRawTransform(), event.getDownTime(), predictionTime,
308 event.getPointerCount(), event.getPointerProperties(), &coords);
309 } else {
310 prediction->addSample(predictionTime, &coords);
311 }
312
313 axisFrom = axisTo;
Cody Heiner088c63e2023-06-15 12:06:09 -0700314 axisTo = predictedPoint;
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800315 }
Cody Heiner088c63e2023-06-15 12:06:09 -0700316
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800317 if (!hasPredictions) {
318 return nullptr;
319 }
Cody Heiner088c63e2023-06-15 12:06:09 -0700320
321 // Pass predictions to the MetricsManager.
322 LOG_ALWAYS_FATAL_IF(!mMetricsManager);
323 mMetricsManager->onPredict(*prediction);
324
Siarhei Vishniakou33cb38b2023-02-23 18:52:34 -0800325 return prediction;
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800326}
327
328bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source) {
329 // Global flag override
330 if (!mCheckMotionPredictionEnabled()) {
331 ALOGD_IF(isDebug(), "Prediction not available due to flag override");
332 return false;
333 }
334
335 // Prediction is only supported for stylus sources.
336 if (!isFromSource(source, AINPUT_SOURCE_STYLUS)) {
337 ALOGD_IF(isDebug(), "Prediction not available for non-stylus source: %s",
338 inputEventSourceToString(source).c_str());
339 return false;
340 }
341 return true;
342}
343
Siarhei Vishniakou39147ce2022-11-15 12:13:04 -0800344} // namespace android