Cody Heiner | 088c63e | 2023-06-15 12:06:09 -0700 | [diff] [blame] | 1 | /* |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 2 | * Copyright 2023 The Android Open Source Project |
Cody Heiner | 088c63e | 2023-06-15 12:06:09 -0700 | [diff] [blame] | 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 17 | #include <cstddef> |
| 18 | #include <cstdint> |
| 19 | #include <functional> |
| 20 | #include <limits> |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 21 | #include <vector> |
| 22 | |
| 23 | #include <input/Input.h> // for MotionEvent |
| 24 | #include <input/RingBuffer.h> |
| 25 | #include <utils/Timers.h> // for nsecs_t |
| 26 | |
| 27 | #include "Eigen/Core" |
Cody Heiner | 088c63e | 2023-06-15 12:06:09 -0700 | [diff] [blame] | 28 | |
| 29 | namespace android { |
| 30 | |
| 31 | /** |
| 32 | * Class to handle computing and reporting metrics for MotionPredictor. |
| 33 | * |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 34 | * The public API provides two methods: `onRecord` and `onPredict`, which expect to receive the |
| 35 | * MotionEvents from the corresponding methods in MotionPredictor. |
| 36 | * |
| 37 | * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When |
| 38 | * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final |
Cody Heiner | 7b26dbe | 2023-11-14 14:47:10 -0800 | [diff] [blame] | 39 | * AtomFields are computed and reported to the stats library. The number of atoms reported is equal |
| 40 | * to the value of `maxNumPredictions` passed to the constructor. Each atom corresponds to one |
| 41 | * "prediction time bucket" — the amount of time into the future being predicted. |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 42 | * |
| 43 | * If mMockLoggedAtomFields is set, the batch of AtomFields that are reported to the stats library |
| 44 | * for one stroke are also stored in mMockLoggedAtomFields at the time they're reported. |
Cody Heiner | 088c63e | 2023-06-15 12:06:09 -0700 | [diff] [blame] | 45 | */ |
| 46 | class MotionPredictorMetricsManager { |
| 47 | public: |
Cody Heiner | 7b26dbe | 2023-11-14 14:47:10 -0800 | [diff] [blame] | 48 | struct AtomFields; |
| 49 | |
| 50 | using ReportAtomFunction = std::function<void(const AtomFields&)>; |
| 51 | |
| 52 | static void defaultReportAtomFunction(const AtomFields& atomFields); |
| 53 | |
| 54 | // Parameters: |
| 55 | // • predictionInterval: the time interval between successive prediction target timestamps. |
| 56 | // Note: the MetricsManager assumes that the input interval equals the prediction interval. |
| 57 | // • maxNumPredictions: the maximum number of distinct target timestamps the prediction model |
| 58 | // will generate predictions for. The MetricsManager reports this many atoms per stroke. |
| 59 | // • [Optional] reportAtomFunction: the function that will be called to report metrics. If |
| 60 | // omitted (or if an empty function is given), the `stats_write(…)` function from the Android |
| 61 | // stats library will be used. |
| 62 | MotionPredictorMetricsManager( |
| 63 | nsecs_t predictionInterval, |
| 64 | size_t maxNumPredictions, |
| 65 | ReportAtomFunction reportAtomFunction = defaultReportAtomFunction); |
Cody Heiner | 088c63e | 2023-06-15 12:06:09 -0700 | [diff] [blame] | 66 | |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 67 | // This method should be called once for each call to MotionPredictor::record, receiving the |
| 68 | // forwarded MotionEvent argument. |
| 69 | void onRecord(const MotionEvent& inputEvent); |
Cody Heiner | 088c63e | 2023-06-15 12:06:09 -0700 | [diff] [blame] | 70 | |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 71 | // This method should be called once for each call to MotionPredictor::predict, receiving the |
| 72 | // MotionEvent that will be returned by MotionPredictor::predict. |
| 73 | void onPredict(const MotionEvent& predictionEvent); |
| 74 | |
| 75 | // Simple structs to hold relevant touch input information. Public so they can be used in tests. |
| 76 | |
| 77 | struct TouchPoint { |
| 78 | Eigen::Vector2f position; // (y, x) in pixels |
| 79 | float pressure; |
| 80 | }; |
| 81 | |
| 82 | struct GroundTruthPoint : TouchPoint { |
| 83 | nsecs_t timestamp; |
| 84 | }; |
| 85 | |
| 86 | struct PredictionPoint : TouchPoint { |
| 87 | // The timestamp of the last ground truth point when the prediction was made. |
| 88 | nsecs_t originTimestamp; |
| 89 | |
| 90 | nsecs_t targetTimestamp; |
| 91 | |
| 92 | // Order by targetTimestamp when sorting. |
| 93 | bool operator<(const PredictionPoint& other) const { |
| 94 | return this->targetTimestamp < other.targetTimestamp; |
| 95 | } |
| 96 | }; |
| 97 | |
| 98 | // Metrics aggregated so far for the current stroke. These are not the final fields to be |
| 99 | // reported in the atom (see AtomFields below), but rather an intermediate representation of the |
| 100 | // data that can be conveniently aggregated and from which the atom fields can be derived later. |
| 101 | // |
| 102 | // Displacement units are in pixels. |
| 103 | // |
| 104 | // "Along-trajectory error" is the dot product of the prediction error with the unit vector |
| 105 | // pointing towards the ground truth point whose timestamp corresponds to the prediction |
| 106 | // target timestamp, originating from the preceding ground truth point. |
| 107 | // |
| 108 | // "Off-trajectory error" is the component of the prediction error orthogonal to the |
| 109 | // "along-trajectory" unit vector described above. |
| 110 | // |
| 111 | // "High-velocity" errors are errors that are only accumulated when the velocity between the |
| 112 | // most recent two input events exceeds a certain threshold. |
| 113 | // |
| 114 | // "Scale-invariant errors" are the errors produced when the path length of the stroke is |
| 115 | // scaled to 1. (In other words, the error distances are normalized by the path length.) |
| 116 | struct AggregatedStrokeMetrics { |
| 117 | // General errors |
| 118 | float alongTrajectoryErrorSum = 0; |
| 119 | float alongTrajectorySumSquaredErrors = 0; |
| 120 | float offTrajectorySumSquaredErrors = 0; |
| 121 | float pressureSumSquaredErrors = 0; |
| 122 | size_t generalErrorsCount = 0; |
| 123 | |
| 124 | // High-velocity errors |
| 125 | float highVelocityAlongTrajectorySse = 0; |
| 126 | float highVelocityOffTrajectorySse = 0; |
| 127 | size_t highVelocityErrorsCount = 0; |
| 128 | |
| 129 | // Scale-invariant errors |
| 130 | float scaleInvariantAlongTrajectorySse = 0; |
| 131 | float scaleInvariantOffTrajectorySse = 0; |
| 132 | size_t scaleInvariantErrorsCount = 0; |
| 133 | }; |
| 134 | |
| 135 | // In order to explicitly indicate "no relevant data" for a metric, we report this |
| 136 | // large-magnitude negative sentinel value. (Most metrics are non-negative, so this value is |
| 137 | // completely unobtainable. For along-trajectory error mean, which can be negative, the |
| 138 | // magnitude makes it unobtainable in practice.) |
| 139 | static const int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min(); |
| 140 | |
Cody Heiner | 7b26dbe | 2023-11-14 14:47:10 -0800 | [diff] [blame] | 141 | // Final metric values reported in the atom. |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 142 | struct AtomFields { |
| 143 | int deltaTimeBucketMilliseconds = 0; |
| 144 | |
| 145 | // General errors |
| 146 | int alongTrajectoryErrorMeanMillipixels = NO_DATA_SENTINEL; |
| 147 | int alongTrajectoryErrorStdMillipixels = NO_DATA_SENTINEL; |
| 148 | int offTrajectoryRmseMillipixels = NO_DATA_SENTINEL; |
| 149 | int pressureRmseMilliunits = NO_DATA_SENTINEL; |
| 150 | |
| 151 | // High-velocity errors |
| 152 | int highVelocityAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels |
| 153 | int highVelocityOffTrajectoryRmse = NO_DATA_SENTINEL; // millipixels |
| 154 | |
| 155 | // Scale-invariant errors |
| 156 | int scaleInvariantAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels |
| 157 | int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL; // millipixels |
| 158 | }; |
| 159 | |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 160 | private: |
| 161 | // The interval between consecutive predictions' target timestamps. We assume that the input |
| 162 | // interval also equals this value. |
| 163 | const nsecs_t mPredictionInterval; |
| 164 | |
| 165 | // The maximum number of input frames into the future the model can predict. |
| 166 | // Used to perform time-bucketing of metrics. |
| 167 | const size_t mMaxNumPredictions; |
| 168 | |
| 169 | // History of mMaxNumPredictions + 1 ground truth points, used to compute scale-invariant |
| 170 | // error. (Also, the last two points are used to compute the ground truth trajectory.) |
| 171 | RingBuffer<GroundTruthPoint> mRecentGroundTruthPoints; |
| 172 | |
| 173 | // Predictions having a targetTimestamp after the most recent ground truth point's timestamp. |
| 174 | // Invariant: sorted in ascending order of targetTimestamp. |
| 175 | std::vector<PredictionPoint> mRecentPredictions; |
| 176 | |
| 177 | // Containers for the intermediate representation of stroke metrics and the final atom fields. |
| 178 | // These are indexed by the number of input frames into the future being predicted minus one, |
| 179 | // and always have size mMaxNumPredictions. |
| 180 | std::vector<AggregatedStrokeMetrics> mAggregatedMetrics; |
| 181 | std::vector<AtomFields> mAtomFields; |
| 182 | |
Cody Heiner | 7b26dbe | 2023-11-14 14:47:10 -0800 | [diff] [blame] | 183 | const ReportAtomFunction mReportAtomFunction; |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 184 | |
| 185 | // Helper methods for the implementation of onRecord and onPredict. |
| 186 | |
| 187 | // Clears stored ground truth and prediction points, as well as all stored metrics for the |
| 188 | // current stroke. |
| 189 | void clearStrokeData(); |
| 190 | |
| 191 | // Adds the new ground truth point to mRecentGroundTruths, removes outdated predictions from |
| 192 | // mRecentPredictions, and updates the aggregated metrics to include the recent predictions that |
| 193 | // fuzzily match with the new ground truth point. |
| 194 | void incorporateNewGroundTruth(const GroundTruthPoint& groundTruthPoint); |
| 195 | |
| 196 | // Given a new prediction with targetTimestamp matching the latest ground truth point's |
| 197 | // timestamp, computes the corresponding metrics and updates mAggregatedMetrics. |
| 198 | void updateAggregatedMetrics(const PredictionPoint& predictionPoint); |
| 199 | |
| 200 | // Computes the atom fields to mAtomFields from the values in mAggregatedMetrics. |
| 201 | void computeAtomFields(); |
| 202 | |
Cody Heiner | 7b26dbe | 2023-11-14 14:47:10 -0800 | [diff] [blame] | 203 | // Reports the current data in mAtomFields by calling mReportAtomFunction. |
Cody Heiner | 52db474 | 2023-06-29 13:19:01 -0700 | [diff] [blame] | 204 | void reportMetrics(); |
Cody Heiner | 088c63e | 2023-06-15 12:06:09 -0700 | [diff] [blame] | 205 | }; |
| 206 | |
| 207 | } // namespace android |