blob: 12e50ba3b49a0fd8a8cdcbfd3c96c078a42a79e5 [file] [log] [blame]
/*
 * Copyright 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Cody Heiner52db4742023-06-29 13:19:01 -070017#include <cstddef>
18#include <cstdint>
19#include <functional>
20#include <limits>
21#include <optional>
22#include <vector>
23
24#include <input/Input.h> // for MotionEvent
25#include <input/RingBuffer.h>
26#include <utils/Timers.h> // for nsecs_t
27
28#include "Eigen/Core"
Cody Heiner088c63e2023-06-15 12:06:09 -070029
30namespace android {
31
32/**
33 * Class to handle computing and reporting metrics for MotionPredictor.
34 *
Cody Heiner52db4742023-06-29 13:19:01 -070035 * The public API provides two methods: `onRecord` and `onPredict`, which expect to receive the
36 * MotionEvents from the corresponding methods in MotionPredictor.
37 *
38 * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When
39 * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final
40 * AtomFields are computed and reported to the stats library.
41 *
42 * If mMockLoggedAtomFields is set, the batch of AtomFields that are reported to the stats library
43 * for one stroke are also stored in mMockLoggedAtomFields at the time they're reported.
Cody Heiner088c63e2023-06-15 12:06:09 -070044 */
45class MotionPredictorMetricsManager {
46public:
47 // Note: the MetricsManager assumes that the input interval equals the prediction interval.
Cody Heiner52db4742023-06-29 13:19:01 -070048 MotionPredictorMetricsManager(nsecs_t predictionInterval, size_t maxNumPredictions);
Cody Heiner088c63e2023-06-15 12:06:09 -070049
Cody Heiner52db4742023-06-29 13:19:01 -070050 // This method should be called once for each call to MotionPredictor::record, receiving the
51 // forwarded MotionEvent argument.
52 void onRecord(const MotionEvent& inputEvent);
Cody Heiner088c63e2023-06-15 12:06:09 -070053
Cody Heiner52db4742023-06-29 13:19:01 -070054 // This method should be called once for each call to MotionPredictor::predict, receiving the
55 // MotionEvent that will be returned by MotionPredictor::predict.
56 void onPredict(const MotionEvent& predictionEvent);
57
58 // Simple structs to hold relevant touch input information. Public so they can be used in tests.
59
60 struct TouchPoint {
61 Eigen::Vector2f position; // (y, x) in pixels
62 float pressure;
63 };
64
65 struct GroundTruthPoint : TouchPoint {
66 nsecs_t timestamp;
67 };
68
69 struct PredictionPoint : TouchPoint {
70 // The timestamp of the last ground truth point when the prediction was made.
71 nsecs_t originTimestamp;
72
73 nsecs_t targetTimestamp;
74
75 // Order by targetTimestamp when sorting.
76 bool operator<(const PredictionPoint& other) const {
77 return this->targetTimestamp < other.targetTimestamp;
78 }
79 };
80
81 // Metrics aggregated so far for the current stroke. These are not the final fields to be
82 // reported in the atom (see AtomFields below), but rather an intermediate representation of the
83 // data that can be conveniently aggregated and from which the atom fields can be derived later.
84 //
85 // Displacement units are in pixels.
86 //
87 // "Along-trajectory error" is the dot product of the prediction error with the unit vector
88 // pointing towards the ground truth point whose timestamp corresponds to the prediction
89 // target timestamp, originating from the preceding ground truth point.
90 //
91 // "Off-trajectory error" is the component of the prediction error orthogonal to the
92 // "along-trajectory" unit vector described above.
93 //
94 // "High-velocity" errors are errors that are only accumulated when the velocity between the
95 // most recent two input events exceeds a certain threshold.
96 //
97 // "Scale-invariant errors" are the errors produced when the path length of the stroke is
98 // scaled to 1. (In other words, the error distances are normalized by the path length.)
99 struct AggregatedStrokeMetrics {
100 // General errors
101 float alongTrajectoryErrorSum = 0;
102 float alongTrajectorySumSquaredErrors = 0;
103 float offTrajectorySumSquaredErrors = 0;
104 float pressureSumSquaredErrors = 0;
105 size_t generalErrorsCount = 0;
106
107 // High-velocity errors
108 float highVelocityAlongTrajectorySse = 0;
109 float highVelocityOffTrajectorySse = 0;
110 size_t highVelocityErrorsCount = 0;
111
112 // Scale-invariant errors
113 float scaleInvariantAlongTrajectorySse = 0;
114 float scaleInvariantOffTrajectorySse = 0;
115 size_t scaleInvariantErrorsCount = 0;
116 };
117
118 // In order to explicitly indicate "no relevant data" for a metric, we report this
119 // large-magnitude negative sentinel value. (Most metrics are non-negative, so this value is
120 // completely unobtainable. For along-trajectory error mean, which can be negative, the
121 // magnitude makes it unobtainable in practice.)
122 static const int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min();
123
124 // Final metrics reported in the atom.
125 struct AtomFields {
126 int deltaTimeBucketMilliseconds = 0;
127
128 // General errors
129 int alongTrajectoryErrorMeanMillipixels = NO_DATA_SENTINEL;
130 int alongTrajectoryErrorStdMillipixels = NO_DATA_SENTINEL;
131 int offTrajectoryRmseMillipixels = NO_DATA_SENTINEL;
132 int pressureRmseMilliunits = NO_DATA_SENTINEL;
133
134 // High-velocity errors
135 int highVelocityAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
136 int highVelocityOffTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
137
138 // Scale-invariant errors
139 int scaleInvariantAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
140 int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
141 };
142
143 // Allow tests to pass in a mock AtomFields pointer.
144 //
145 // When metrics are reported to the stats library on stroke end, they will also be written to
146 // mockLoggedAtomFields, overwriting existing data. The size of mockLoggedAtomFields will equal
147 // the number of calls to stats_write for that stroke.
148 void setMockLoggedAtomFields(std::vector<AtomFields>* mockLoggedAtomFields) {
149 mMockLoggedAtomFields = mockLoggedAtomFields;
150 }
151
152private:
153 // The interval between consecutive predictions' target timestamps. We assume that the input
154 // interval also equals this value.
155 const nsecs_t mPredictionInterval;
156
157 // The maximum number of input frames into the future the model can predict.
158 // Used to perform time-bucketing of metrics.
159 const size_t mMaxNumPredictions;
160
161 // History of mMaxNumPredictions + 1 ground truth points, used to compute scale-invariant
162 // error. (Also, the last two points are used to compute the ground truth trajectory.)
163 RingBuffer<GroundTruthPoint> mRecentGroundTruthPoints;
164
165 // Predictions having a targetTimestamp after the most recent ground truth point's timestamp.
166 // Invariant: sorted in ascending order of targetTimestamp.
167 std::vector<PredictionPoint> mRecentPredictions;
168
169 // Containers for the intermediate representation of stroke metrics and the final atom fields.
170 // These are indexed by the number of input frames into the future being predicted minus one,
171 // and always have size mMaxNumPredictions.
172 std::vector<AggregatedStrokeMetrics> mAggregatedMetrics;
173 std::vector<AtomFields> mAtomFields;
174
175 // Non-owning pointer to the location of mock AtomFields. If present, will be filled with the
176 // values reported to stats_write on each batch of reported metrics.
177 //
178 // This pointer must remain valid as long as the MotionPredictorMetricsManager exists.
179 std::vector<AtomFields>* mMockLoggedAtomFields = nullptr;
180
181 // Helper methods for the implementation of onRecord and onPredict.
182
183 // Clears stored ground truth and prediction points, as well as all stored metrics for the
184 // current stroke.
185 void clearStrokeData();
186
187 // Adds the new ground truth point to mRecentGroundTruths, removes outdated predictions from
188 // mRecentPredictions, and updates the aggregated metrics to include the recent predictions that
189 // fuzzily match with the new ground truth point.
190 void incorporateNewGroundTruth(const GroundTruthPoint& groundTruthPoint);
191
192 // Given a new prediction with targetTimestamp matching the latest ground truth point's
193 // timestamp, computes the corresponding metrics and updates mAggregatedMetrics.
194 void updateAggregatedMetrics(const PredictionPoint& predictionPoint);
195
196 // Computes the atom fields to mAtomFields from the values in mAggregatedMetrics.
197 void computeAtomFields();
198
199 // Reports the metrics given by the current data in mAtomFields:
200 // • If on an Android device, reports the metrics to stats_write.
201 // • If mMockLoggedAtomFields is present, it will be overwritten with logged metrics, with one
202 // AtomFields element per call to stats_write.
203 void reportMetrics();
Cody Heiner088c63e2023-06-15 12:06:09 -0700204};
205
206} // namespace android