blob: 38472d8df7140dd8384cc543d3007a02d0ae9a9b [file] [log] [blame]
Cody Heiner088c63e2023-06-15 12:06:09 -07001/*
Cody Heiner52db4742023-06-29 13:19:01 -07002 * Copyright 2023 The Android Open Source Project
Cody Heiner088c63e2023-06-15 12:06:09 -07003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Cody Heiner52db4742023-06-29 13:19:01 -070017#include <cstddef>
18#include <cstdint>
19#include <functional>
20#include <limits>
Cody Heiner52db4742023-06-29 13:19:01 -070021#include <vector>
22
23#include <input/Input.h> // for MotionEvent
24#include <input/RingBuffer.h>
25#include <utils/Timers.h> // for nsecs_t
26
27#include "Eigen/Core"
Cody Heiner088c63e2023-06-15 12:06:09 -070028
29namespace android {
30
31/**
32 * Class to handle computing and reporting metrics for MotionPredictor.
33 *
Cody Heiner52db4742023-06-29 13:19:01 -070034 * The public API provides two methods: `onRecord` and `onPredict`, which expect to receive the
35 * MotionEvents from the corresponding methods in MotionPredictor.
36 *
37 * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When
38 * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final
Cody Heiner7b26dbe2023-11-14 14:47:10 -080039 * AtomFields are computed and reported to the stats library. The number of atoms reported is equal
40 * to the value of `maxNumPredictions` passed to the constructor. Each atom corresponds to one
41 * "prediction time bucket" — the amount of time into the future being predicted.
Cody Heiner52db4742023-06-29 13:19:01 -070042 *
43 * If mMockLoggedAtomFields is set, the batch of AtomFields that are reported to the stats library
44 * for one stroke are also stored in mMockLoggedAtomFields at the time they're reported.
Cody Heiner088c63e2023-06-15 12:06:09 -070045 */
46class MotionPredictorMetricsManager {
47public:
Cody Heiner7b26dbe2023-11-14 14:47:10 -080048 struct AtomFields;
49
50 using ReportAtomFunction = std::function<void(const AtomFields&)>;
51
52 static void defaultReportAtomFunction(const AtomFields& atomFields);
53
54 // Parameters:
55 // • predictionInterval: the time interval between successive prediction target timestamps.
56 // Note: the MetricsManager assumes that the input interval equals the prediction interval.
57 // • maxNumPredictions: the maximum number of distinct target timestamps the prediction model
58 // will generate predictions for. The MetricsManager reports this many atoms per stroke.
59 // • [Optional] reportAtomFunction: the function that will be called to report metrics. If
60 // omitted (or if an empty function is given), the `stats_write(…)` function from the Android
61 // stats library will be used.
62 MotionPredictorMetricsManager(
63 nsecs_t predictionInterval,
64 size_t maxNumPredictions,
65 ReportAtomFunction reportAtomFunction = defaultReportAtomFunction);
Cody Heiner088c63e2023-06-15 12:06:09 -070066
Cody Heiner52db4742023-06-29 13:19:01 -070067 // This method should be called once for each call to MotionPredictor::record, receiving the
68 // forwarded MotionEvent argument.
69 void onRecord(const MotionEvent& inputEvent);
Cody Heiner088c63e2023-06-15 12:06:09 -070070
Cody Heiner52db4742023-06-29 13:19:01 -070071 // This method should be called once for each call to MotionPredictor::predict, receiving the
72 // MotionEvent that will be returned by MotionPredictor::predict.
73 void onPredict(const MotionEvent& predictionEvent);
74
75 // Simple structs to hold relevant touch input information. Public so they can be used in tests.
76
77 struct TouchPoint {
78 Eigen::Vector2f position; // (y, x) in pixels
79 float pressure;
80 };
81
82 struct GroundTruthPoint : TouchPoint {
83 nsecs_t timestamp;
84 };
85
86 struct PredictionPoint : TouchPoint {
87 // The timestamp of the last ground truth point when the prediction was made.
88 nsecs_t originTimestamp;
89
90 nsecs_t targetTimestamp;
91
92 // Order by targetTimestamp when sorting.
93 bool operator<(const PredictionPoint& other) const {
94 return this->targetTimestamp < other.targetTimestamp;
95 }
96 };
97
98 // Metrics aggregated so far for the current stroke. These are not the final fields to be
99 // reported in the atom (see AtomFields below), but rather an intermediate representation of the
100 // data that can be conveniently aggregated and from which the atom fields can be derived later.
101 //
102 // Displacement units are in pixels.
103 //
104 // "Along-trajectory error" is the dot product of the prediction error with the unit vector
105 // pointing towards the ground truth point whose timestamp corresponds to the prediction
106 // target timestamp, originating from the preceding ground truth point.
107 //
108 // "Off-trajectory error" is the component of the prediction error orthogonal to the
109 // "along-trajectory" unit vector described above.
110 //
111 // "High-velocity" errors are errors that are only accumulated when the velocity between the
112 // most recent two input events exceeds a certain threshold.
113 //
114 // "Scale-invariant errors" are the errors produced when the path length of the stroke is
115 // scaled to 1. (In other words, the error distances are normalized by the path length.)
116 struct AggregatedStrokeMetrics {
117 // General errors
118 float alongTrajectoryErrorSum = 0;
119 float alongTrajectorySumSquaredErrors = 0;
120 float offTrajectorySumSquaredErrors = 0;
121 float pressureSumSquaredErrors = 0;
122 size_t generalErrorsCount = 0;
123
124 // High-velocity errors
125 float highVelocityAlongTrajectorySse = 0;
126 float highVelocityOffTrajectorySse = 0;
127 size_t highVelocityErrorsCount = 0;
128
129 // Scale-invariant errors
130 float scaleInvariantAlongTrajectorySse = 0;
131 float scaleInvariantOffTrajectorySse = 0;
132 size_t scaleInvariantErrorsCount = 0;
133 };
134
135 // In order to explicitly indicate "no relevant data" for a metric, we report this
136 // large-magnitude negative sentinel value. (Most metrics are non-negative, so this value is
137 // completely unobtainable. For along-trajectory error mean, which can be negative, the
138 // magnitude makes it unobtainable in practice.)
139 static const int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min();
140
Cody Heiner7b26dbe2023-11-14 14:47:10 -0800141 // Final metric values reported in the atom.
Cody Heiner52db4742023-06-29 13:19:01 -0700142 struct AtomFields {
143 int deltaTimeBucketMilliseconds = 0;
144
145 // General errors
146 int alongTrajectoryErrorMeanMillipixels = NO_DATA_SENTINEL;
147 int alongTrajectoryErrorStdMillipixels = NO_DATA_SENTINEL;
148 int offTrajectoryRmseMillipixels = NO_DATA_SENTINEL;
149 int pressureRmseMilliunits = NO_DATA_SENTINEL;
150
151 // High-velocity errors
152 int highVelocityAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
153 int highVelocityOffTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
154
155 // Scale-invariant errors
156 int scaleInvariantAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
157 int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
158 };
159
Cody Heiner52db4742023-06-29 13:19:01 -0700160private:
161 // The interval between consecutive predictions' target timestamps. We assume that the input
162 // interval also equals this value.
163 const nsecs_t mPredictionInterval;
164
165 // The maximum number of input frames into the future the model can predict.
166 // Used to perform time-bucketing of metrics.
167 const size_t mMaxNumPredictions;
168
169 // History of mMaxNumPredictions + 1 ground truth points, used to compute scale-invariant
170 // error. (Also, the last two points are used to compute the ground truth trajectory.)
171 RingBuffer<GroundTruthPoint> mRecentGroundTruthPoints;
172
173 // Predictions having a targetTimestamp after the most recent ground truth point's timestamp.
174 // Invariant: sorted in ascending order of targetTimestamp.
175 std::vector<PredictionPoint> mRecentPredictions;
176
177 // Containers for the intermediate representation of stroke metrics and the final atom fields.
178 // These are indexed by the number of input frames into the future being predicted minus one,
179 // and always have size mMaxNumPredictions.
180 std::vector<AggregatedStrokeMetrics> mAggregatedMetrics;
181 std::vector<AtomFields> mAtomFields;
182
Cody Heiner7b26dbe2023-11-14 14:47:10 -0800183 const ReportAtomFunction mReportAtomFunction;
Cody Heiner52db4742023-06-29 13:19:01 -0700184
185 // Helper methods for the implementation of onRecord and onPredict.
186
187 // Clears stored ground truth and prediction points, as well as all stored metrics for the
188 // current stroke.
189 void clearStrokeData();
190
191 // Adds the new ground truth point to mRecentGroundTruths, removes outdated predictions from
192 // mRecentPredictions, and updates the aggregated metrics to include the recent predictions that
193 // fuzzily match with the new ground truth point.
194 void incorporateNewGroundTruth(const GroundTruthPoint& groundTruthPoint);
195
196 // Given a new prediction with targetTimestamp matching the latest ground truth point's
197 // timestamp, computes the corresponding metrics and updates mAggregatedMetrics.
198 void updateAggregatedMetrics(const PredictionPoint& predictionPoint);
199
200 // Computes the atom fields to mAtomFields from the values in mAggregatedMetrics.
201 void computeAtomFields();
202
Cody Heiner7b26dbe2023-11-14 14:47:10 -0800203 // Reports the current data in mAtomFields by calling mReportAtomFunction.
Cody Heiner52db4742023-06-29 13:19:01 -0700204 void reportMetrics();
Cody Heiner088c63e2023-06-15 12:06:09 -0700205};
206
207} // namespace android