Pass all input events to MetricsManager

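The MetricsManager needs to receive UP/CANCEL events to trigger
atom reporting, but the code that forwards events to it had ended up
after the UP/CANCEL handling in MotionPredictor::record(), so those
events never reached it. This regression was introduced when these
lines were moved during an earlier refactor.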

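This change also makes MotionPredictor and MotionPredictorMetricsManager
accept a ReportAtomFunction, so that tests can capture reported atoms
instead of relying on the removed setMockLoggedAtomFields() hook. When
no function is provided, the metrics manager falls back to logging via
stats_write().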

Test: with SELinux in permissive mode (`adb shell setenforce 0`), `statsd_testdrive 718` shows the atoms being reported.
Test: `atest frameworks/native/libs/input/tests/MotionPredictor_test.cpp -c` passes.
Test: `atest frameworks/native/libs/input/tests/MotionPredictorMetricsManager_test.cpp -c` passes.

Bug: 311066949

Change-Id: Icbb709bbb7cf548512e0d9aa062783d554b857e3
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 412931b..c4e3ff6 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -60,9 +60,11 @@
 // --- MotionPredictor ---
 
 MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
-                                 std::function<bool()> checkMotionPredictionEnabled)
+                                 std::function<bool()> checkMotionPredictionEnabled,
+                                 ReportAtomFunction reportAtomFunction)
       : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
-        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {}
+        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
+        mReportAtomFunction(std::move(reportAtomFunction)) {}
 
 android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
     if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) {
@@ -90,6 +92,13 @@
         mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());
     }
 
+    // Pass input event to the MetricsManager.
+    if (!mMetricsManager) {
+        mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength(),
+                                mReportAtomFunction);
+    }
+    mMetricsManager->onRecord(event);
+
     const int32_t action = event.getActionMasked();
     if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) {
         ALOGD_IF(isDebug(), "End of event stream");
@@ -135,12 +144,6 @@
     }
     mLastEvent->copyFrom(&event, /*keepHistory=*/false);
 
-    // Pass input event to the MetricsManager.
-    if (!mMetricsManager) {
-        mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength());
-    }
-    mMetricsManager->onRecord(event);
-
     return {};
 }
 
diff --git a/libs/input/MotionPredictorMetricsManager.cpp b/libs/input/MotionPredictorMetricsManager.cpp
index 67b1032..0412d08 100644
--- a/libs/input/MotionPredictorMetricsManager.cpp
+++ b/libs/input/MotionPredictorMetricsManager.cpp
@@ -46,13 +46,36 @@
 
 } // namespace
 
-MotionPredictorMetricsManager::MotionPredictorMetricsManager(nsecs_t predictionInterval,
-                                                             size_t maxNumPredictions)
+void MotionPredictorMetricsManager::defaultReportAtomFunction(
+        const MotionPredictorMetricsManager::AtomFields& atomFields) {
+    // Call stats_write logging function only on Android targets (not supported on host).
+#ifdef __ANDROID__
+    android::stats::libinput::
+            stats_write(android::stats::libinput::STYLUS_PREDICTION_METRICS_REPORTED,
+                            /*stylus_vendor_id=*/0,
+                            /*stylus_product_id=*/0,
+                            atomFields.deltaTimeBucketMilliseconds,
+                            atomFields.alongTrajectoryErrorMeanMillipixels,
+                            atomFields.alongTrajectoryErrorStdMillipixels,
+                            atomFields.offTrajectoryRmseMillipixels,
+                            atomFields.pressureRmseMilliunits,
+                            atomFields.highVelocityAlongTrajectoryRmse,
+                            atomFields.highVelocityOffTrajectoryRmse,
+                            atomFields.scaleInvariantAlongTrajectoryRmse,
+                            atomFields.scaleInvariantOffTrajectoryRmse);
+#endif
+}
+
+MotionPredictorMetricsManager::MotionPredictorMetricsManager(
+        nsecs_t predictionInterval,
+        size_t maxNumPredictions,
+        ReportAtomFunction reportAtomFunction)
       : mPredictionInterval(predictionInterval),
         mMaxNumPredictions(maxNumPredictions),
         mRecentGroundTruthPoints(maxNumPredictions + 1),
         mAggregatedMetrics(maxNumPredictions),
-        mAtomFields(maxNumPredictions) {}
+        mAtomFields(maxNumPredictions),
+        mReportAtomFunction(reportAtomFunction ? reportAtomFunction : defaultReportAtomFunction) {}
 
 void MotionPredictorMetricsManager::onRecord(const MotionEvent& inputEvent) {
     // Convert MotionEvent to GroundTruthPoint.
@@ -81,8 +104,8 @@
             if (mRecentGroundTruthPoints.size() >= 2) {
                 computeAtomFields();
                 reportMetrics();
-                break;
             }
+            break;
         }
     }
 }
@@ -345,28 +368,10 @@
 }
 
 void MotionPredictorMetricsManager::reportMetrics() {
-    // Report one atom for each time bucket.
+    LOG_ALWAYS_FATAL_IF(!mReportAtomFunction);
+    // Report one atom for each prediction time bucket.
     for (size_t i = 0; i < mAtomFields.size(); ++i) {
-        // Call stats_write logging function only on Android targets (not supported on host).
-#ifdef __ANDROID__
-        android::stats::libinput::
-                stats_write(android::stats::libinput::STYLUS_PREDICTION_METRICS_REPORTED,
-                            /*stylus_vendor_id=*/0,
-                            /*stylus_product_id=*/0, mAtomFields[i].deltaTimeBucketMilliseconds,
-                            mAtomFields[i].alongTrajectoryErrorMeanMillipixels,
-                            mAtomFields[i].alongTrajectoryErrorStdMillipixels,
-                            mAtomFields[i].offTrajectoryRmseMillipixels,
-                            mAtomFields[i].pressureRmseMilliunits,
-                            mAtomFields[i].highVelocityAlongTrajectoryRmse,
-                            mAtomFields[i].highVelocityOffTrajectoryRmse,
-                            mAtomFields[i].scaleInvariantAlongTrajectoryRmse,
-                            mAtomFields[i].scaleInvariantOffTrajectoryRmse);
-#endif
-    }
-
-    // Set mock atom fields, if available.
-    if (mMockLoggedAtomFields != nullptr) {
-        *mMockLoggedAtomFields = mAtomFields;
+        mReportAtomFunction(mAtomFields[i]);
     }
 }
 
diff --git a/libs/input/tests/MotionPredictorMetricsManager_test.cpp b/libs/input/tests/MotionPredictorMetricsManager_test.cpp
index b420a5a..31cc145 100644
--- a/libs/input/tests/MotionPredictorMetricsManager_test.cpp
+++ b/libs/input/tests/MotionPredictorMetricsManager_test.cpp
@@ -39,6 +39,7 @@
 using GroundTruthPoint = MotionPredictorMetricsManager::GroundTruthPoint;
 using PredictionPoint = MotionPredictorMetricsManager::PredictionPoint;
 using AtomFields = MotionPredictorMetricsManager::AtomFields;
+using ReportAtomFunction = MotionPredictorMetricsManager::ReportAtomFunction;
 
 inline constexpr int NANOS_PER_MILLIS = 1'000'000;
 
@@ -664,9 +665,16 @@
 
 // --- MotionPredictorMetricsManager tests. ---
 
-// Helper function that instantiates a MetricsManager with the given mock logged AtomFields. Takes
-// vectors of ground truth and prediction points of the same length, and passes these points to the
-// MetricsManager. The format of these vectors is expected to be:
+// Creates a mock atom reporting function that appends the reported atom to the given vector.
+ReportAtomFunction createMockReportAtomFunction(std::vector<AtomFields>& reportedAtomFields) {
+    return [&reportedAtomFields](const AtomFields& atomFields) -> void {
+        reportedAtomFields.push_back(atomFields);
+    };
+}
+
+// Helper function that instantiates a MetricsManager that reports metrics to outReportedAtomFields.
+// Takes vectors of ground truth and prediction points of the same length, and passes these points
+// to the MetricsManager. The format of these vectors is expected to be:
 //  • groundTruthPoints: chronologically-ordered ground truth points, with at least 2 elements.
 //  • predictionPoints: the first index points to a vector of predictions corresponding to the
 //    source ground truth point with the same index.
@@ -678,15 +686,16 @@
 //       prediction sets (that is, excluding the first and last). Thus, groundTruthPoints and
 //       predictionPoints should have size at least TEST_MAX_NUM_PREDICTIONS + 2.
 //
-// The passed-in outAtomFields will contain the logged AtomFields when the function returns.
+// When the function returns, outReportedAtomFields will contain the reported AtomFields.
 //
 // This function returns void so that it can use test assertions.
 void runMetricsManager(const std::vector<GroundTruthPoint>& groundTruthPoints,
                        const std::vector<std::vector<PredictionPoint>>& predictionPoints,
-                       std::vector<AtomFields>& outAtomFields) {
+                       std::vector<AtomFields>& outReportedAtomFields) {
     MotionPredictorMetricsManager metricsManager(TEST_PREDICTION_INTERVAL_NANOS,
-                                                 TEST_MAX_NUM_PREDICTIONS);
-    metricsManager.setMockLoggedAtomFields(&outAtomFields);
+                                                 TEST_MAX_NUM_PREDICTIONS,
+                                                 createMockReportAtomFunction(
+                                                         outReportedAtomFields));
 
     // Validate structure of groundTruthPoints and predictionPoints.
     ASSERT_EQ(predictionPoints.size(), groundTruthPoints.size());
@@ -712,18 +721,18 @@
 //  • Input: no prediction data.
 //  • Expectation: no metrics should be logged.
 TEST(MotionPredictorMetricsManagerTest, NoPredictions) {
-    std::vector<AtomFields> mockLoggedAtomFields;
+    std::vector<AtomFields> reportedAtomFields;
     MotionPredictorMetricsManager metricsManager(TEST_PREDICTION_INTERVAL_NANOS,
-                                                 TEST_MAX_NUM_PREDICTIONS);
-    metricsManager.setMockLoggedAtomFields(&mockLoggedAtomFields);
+                                                 TEST_MAX_NUM_PREDICTIONS,
+                                                 createMockReportAtomFunction(reportedAtomFields));
 
     metricsManager.onRecord(makeMotionEvent(
             GroundTruthPoint{{.position = Eigen::Vector2f(0, 0), .pressure = 0}, .timestamp = 0}));
     metricsManager.onRecord(makeLiftMotionEvent());
 
-    // Check that mockLoggedAtomFields is still empty (as it was initialized empty), ensuring that
+    // Check that reportedAtomFields is still empty (as it was initialized empty), ensuring that
     // no metrics were logged.
-    EXPECT_EQ(0u, mockLoggedAtomFields.size());
+    EXPECT_EQ(0u, reportedAtomFields.size());
 }
 
 // Perfect predictions test:
@@ -744,14 +753,14 @@
         groundTruthPoint.timestamp += TEST_PREDICTION_INTERVAL_NANOS;
     }
 
-    std::vector<AtomFields> atomFields;
-    runMetricsManager(groundTruthPoints, predictionPoints, atomFields);
+    std::vector<AtomFields> reportedAtomFields;
+    runMetricsManager(groundTruthPoints, predictionPoints, reportedAtomFields);
 
-    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size());
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, reportedAtomFields.size());
     // Check that errors are all zero, or NO_DATA_SENTINEL for unreported metrics.
-    for (size_t i = 0; i < atomFields.size(); ++i) {
+    for (size_t i = 0; i < reportedAtomFields.size(); ++i) {
         SCOPED_TRACE(testing::Message() << "i = " << i);
-        const AtomFields& atom = atomFields[i];
+        const AtomFields& atom = reportedAtomFields[i];
         const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1);
         EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds);
         // General errors: reported for every time bucket.
@@ -764,7 +773,7 @@
         EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityAlongTrajectoryRmse);
         EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityOffTrajectoryRmse);
         // Scale-invariant errors: reported only for the last time bucket.
-        if (i + 1 == atomFields.size()) {
+        if (i + 1 == reportedAtomFields.size()) {
             EXPECT_EQ(0, atom.scaleInvariantAlongTrajectoryRmse);
             EXPECT_EQ(0, atom.scaleInvariantOffTrajectoryRmse);
         } else {
@@ -801,14 +810,14 @@
             computePressureRmses(groundTruthPoints, predictionPoints);
 
     // Run test.
-    std::vector<AtomFields> atomFields;
-    runMetricsManager(groundTruthPoints, predictionPoints, atomFields);
+    std::vector<AtomFields> reportedAtomFields;
+    runMetricsManager(groundTruthPoints, predictionPoints, reportedAtomFields);
 
     // Check logged metrics match expectations.
-    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size());
-    for (size_t i = 0; i < atomFields.size(); ++i) {
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, reportedAtomFields.size());
+    for (size_t i = 0; i < reportedAtomFields.size(); ++i) {
         SCOPED_TRACE(testing::Message() << "i = " << i);
-        const AtomFields& atom = atomFields[i];
+        const AtomFields& atom = reportedAtomFields[i];
         // Check time bucket delta matches expectation based on index and prediction interval.
         const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1);
         EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds);
@@ -845,14 +854,14 @@
             computeGeneralPositionErrors(groundTruthPoints, predictionPoints);
 
     // Run test.
-    std::vector<AtomFields> atomFields;
-    runMetricsManager(groundTruthPoints, predictionPoints, atomFields);
+    std::vector<AtomFields> reportedAtomFields;
+    runMetricsManager(groundTruthPoints, predictionPoints, reportedAtomFields);
 
     // Check logged metrics match expectations.
-    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size());
-    for (size_t i = 0; i < atomFields.size(); ++i) {
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, reportedAtomFields.size());
+    for (size_t i = 0; i < reportedAtomFields.size(); ++i) {
         SCOPED_TRACE(testing::Message() << "i = " << i);
-        const AtomFields& atom = atomFields[i];
+        const AtomFields& atom = reportedAtomFields[i];
         // Check time bucket delta matches expectation based on index and prediction interval.
         const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1);
         EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds);
@@ -896,14 +905,14 @@
             computeGeneralPositionErrors(groundTruthPoints, predictionPoints);
 
     // Run test.
-    std::vector<AtomFields> atomFields;
-    runMetricsManager(groundTruthPoints, predictionPoints, atomFields);
+    std::vector<AtomFields> reportedAtomFields;
+    runMetricsManager(groundTruthPoints, predictionPoints, reportedAtomFields);
 
     // Check logged metrics match expectations.
-    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size());
-    for (size_t i = 0; i < atomFields.size(); ++i) {
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, reportedAtomFields.size());
+    for (size_t i = 0; i < reportedAtomFields.size(); ++i) {
         SCOPED_TRACE(testing::Message() << "i = " << i);
-        const AtomFields& atom = atomFields[i];
+        const AtomFields& atom = reportedAtomFields[i];
         const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1);
         EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds);
 
@@ -926,7 +935,7 @@
         // to general errors (where reported).
         //
         // As above, use absolute value for RMSE, since it must be non-negative.
-        if (i + 2 >= atomFields.size()) {
+        if (i + 2 >= reportedAtomFields.size()) {
             EXPECT_NEAR(static_cast<int>(
                                 1000 * std::abs(generalPositionErrors[i].alongTrajectoryErrorMean)),
                         atom.highVelocityAlongTrajectoryRmse, 1);
@@ -946,7 +955,7 @@
         // to scale-invariant errors by dividing by `strokeVelocty * TEST_MAX_NUM_PREDICTIONS`.
         //
         // As above, use absolute value for RMSE, since it must be non-negative.
-        if (i + 1 == atomFields.size()) {
+        if (i + 1 == reportedAtomFields.size()) {
             const float pathLength = strokeVelocity * TEST_MAX_NUM_PREDICTIONS;
             std::vector<float> alongTrajectoryAbsoluteErrors;
             std::vector<float> offTrajectoryAbsoluteErrors;
diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp
index 4ac7ae9..3343114 100644
--- a/libs/input/tests/MotionPredictor_test.cpp
+++ b/libs/input/tests/MotionPredictor_test.cpp
@@ -147,4 +147,35 @@
     ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
 }
 
+using AtomFields = MotionPredictorMetricsManager::AtomFields;
+using ReportAtomFunction = MotionPredictorMetricsManager::ReportAtomFunction;
+
+// Creates a mock atom reporting function that appends the reported atom to the given vector.
+// The passed-in pointer must not be nullptr.
+ReportAtomFunction createMockReportAtomFunction(std::vector<AtomFields>* reportedAtomFields) {
+    return [reportedAtomFields](const AtomFields& atomFields) -> void {
+        reportedAtomFields->push_back(atomFields);
+    };
+}
+
+TEST(MotionPredictorMetricsManagerIntegrationTest, ReportsMetrics) {
+    std::vector<AtomFields> reportedAtomFields;
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+                              []() { return true /*enable prediction*/; },
+                              createMockReportAtomFunction(&reportedAtomFields));
+
+    ASSERT_TRUE(predictor.record(getMotionEvent(DOWN, 1, 1, 0ms, /*deviceId=*/0)).ok());
+    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 2, 2, 4ms, /*deviceId=*/0)).ok());
+    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 3, 3, 8ms, /*deviceId=*/0)).ok());
+    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 4, 4, 12ms, /*deviceId=*/0)).ok());
+    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 5, 5, 16ms, /*deviceId=*/0)).ok());
+    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 6, 6, 20ms, /*deviceId=*/0)).ok());
+    ASSERT_TRUE(predictor.record(getMotionEvent(UP, 7, 7, 24ms, /*deviceId=*/0)).ok());
+
+    // The number of atoms reported should equal the number of prediction time buckets, which is
+    // given by the prediction model's output length. For now, this value is always 5, and we
+    // hardcode it because it's not publicly accessible from the MotionPredictor.
+    EXPECT_EQ(5u, reportedAtomFields.size());
+}
+
 } // namespace android