Break down input device usage by source

Bug: 275726706
Test: atest inputflinger_tests
Test: statsd_testdrive
Change-Id: Ic652ecd7c66b9e9c3b46d84454485cf51140a14b
diff --git a/services/inputflinger/InputDeviceMetricsCollector.cpp b/services/inputflinger/InputDeviceMetricsCollector.cpp
index 71f3717..3e25cc3 100644
--- a/services/inputflinger/InputDeviceMetricsCollector.cpp
+++ b/services/inputflinger/InputDeviceMetricsCollector.cpp
@@ -53,18 +53,28 @@
     nanoseconds getCurrentTime() override { return nanoseconds(systemTime(SYSTEM_TIME_MONOTONIC)); }
 
     void logInputDeviceUsageReported(const InputDeviceIdentifier& identifier,
-                                     nanoseconds sessionDuration) override {
+                                     const DeviceUsageReport& report) override {
         const int32_t durationMillis =
-                std::chrono::duration_cast<std::chrono::milliseconds>(sessionDuration).count();
+                std::chrono::duration_cast<std::chrono::milliseconds>(report.usageDuration).count();
         const static std::vector<int32_t> empty;
 
         ALOGD_IF(DEBUG, "Usage session reported for device: %s", identifier.name.c_str());
         ALOGD_IF(DEBUG, "    Total duration: %dms", durationMillis);
+        ALOGD_IF(DEBUG, "    Source breakdown:");
+
+        std::vector<int32_t> sources;
+        std::vector<int32_t> durationsPerSource;
+        for (auto& [src, dur] : report.sourceBreakdown) {
+            sources.push_back(ftl::to_underlying(src));
+            int32_t durMillis = std::chrono::duration_cast<std::chrono::milliseconds>(dur).count();
+            durationsPerSource.emplace_back(durMillis);
+            ALOGD_IF(DEBUG, "        - usageSource: %s\t duration: %dms",
+                     ftl::enum_string(src).c_str(), durMillis);
+        }
 
         util::stats_write(util::INPUTDEVICE_USAGE_REPORTED, identifier.vendor, identifier.product,
                           identifier.version, linuxBusToInputDeviceBusEnum(identifier.bus),
-                          durationMillis, /*usage_sources=*/empty,
-                          /*usage_durations_per_source=*/empty, /*uids=*/empty,
+                          durationMillis, sources, durationsPerSource, /*uids=*/empty,
                           /*usage_durations_per_uid=*/empty);
     }
 } sStatsdLogger;
@@ -181,54 +191,58 @@
 
 void InputDeviceMetricsCollector::notifyInputDevicesChanged(
         const NotifyInputDevicesChangedArgs& args) {
-    processUsages();
+    reportCompletedSessions();
     onInputDevicesChanged(args.inputDeviceInfos);
     mNextListener.notify(args);
 }
 
 void InputDeviceMetricsCollector::notifyConfigurationChanged(
         const NotifyConfigurationChangedArgs& args) {
-    processUsages();
+    reportCompletedSessions();
     mNextListener.notify(args);
 }
 
 void InputDeviceMetricsCollector::notifyKey(const NotifyKeyArgs& args) {
-    processUsages();
-    onInputDeviceUsage(DeviceId{args.deviceId}, nanoseconds(args.eventTime));
+    reportCompletedSessions();
+    const SourceProvider getSources = [&args](const InputDeviceInfo& info) {
+        return std::set{getUsageSourceForKeyArgs(info, args)};
+    };
+    onInputDeviceUsage(DeviceId{args.deviceId}, nanoseconds(args.eventTime), getSources);
 
     mNextListener.notify(args);
 }
 
 void InputDeviceMetricsCollector::notifyMotion(const NotifyMotionArgs& args) {
-    processUsages();
-    onInputDeviceUsage(DeviceId{args.deviceId}, nanoseconds(args.eventTime));
+    reportCompletedSessions();
+    onInputDeviceUsage(DeviceId{args.deviceId}, nanoseconds(args.eventTime),
+                       [&args](const auto&) { return getUsageSourcesForMotionArgs(args); });
 
     mNextListener.notify(args);
 }
 
 void InputDeviceMetricsCollector::notifySwitch(const NotifySwitchArgs& args) {
-    processUsages();
+    reportCompletedSessions();
     mNextListener.notify(args);
 }
 
 void InputDeviceMetricsCollector::notifySensor(const NotifySensorArgs& args) {
-    processUsages();
+    reportCompletedSessions();
     mNextListener.notify(args);
 }
 
 void InputDeviceMetricsCollector::notifyVibratorState(const NotifyVibratorStateArgs& args) {
-    processUsages();
+    reportCompletedSessions();
     mNextListener.notify(args);
 }
 
 void InputDeviceMetricsCollector::notifyDeviceReset(const NotifyDeviceResetArgs& args) {
-    processUsages();
+    reportCompletedSessions();
     mNextListener.notify(args);
 }
 
 void InputDeviceMetricsCollector::notifyPointerCaptureChanged(
         const NotifyPointerCaptureChangedArgs& args) {
-    processUsages();
+    reportCompletedSessions();
     mNextListener.notify(args);
 }
 
@@ -241,69 +255,124 @@
 }
 
 void InputDeviceMetricsCollector::onInputDevicesChanged(const std::vector<InputDeviceInfo>& infos) {
-    std::map<DeviceId, InputDeviceIdentifier> newDeviceIds;
+    std::map<DeviceId, InputDeviceInfo> newDeviceInfos;
 
     for (const InputDeviceInfo& info : infos) {
         if (isIgnoredInputDeviceId(info.getId())) {
             continue;
         }
-        newDeviceIds.emplace(info.getId(), info.getIdentifier());
+        newDeviceInfos.emplace(info.getId(), info);
     }
 
-    for (auto [deviceId, identifier] : mLoggedDeviceInfos) {
-        if (newDeviceIds.count(deviceId) != 0) {
+    for (auto [deviceId, info] : mLoggedDeviceInfos) {
+        if (newDeviceInfos.count(deviceId) != 0) {
             continue;
         }
-        onInputDeviceRemoved(deviceId, identifier);
+        onInputDeviceRemoved(deviceId, info.getIdentifier());
     }
 
-    std::swap(newDeviceIds, mLoggedDeviceInfos);
+    std::swap(newDeviceInfos, mLoggedDeviceInfos);
 }
 
 void InputDeviceMetricsCollector::onInputDeviceRemoved(DeviceId deviceId,
                                                        const InputDeviceIdentifier& identifier) {
-    // Report usage for that device if there is an active session.
     auto it = mActiveUsageSessions.find(deviceId);
-    if (it != mActiveUsageSessions.end()) {
-        mLogger.logInputDeviceUsageReported(identifier, it->second.end - it->second.start);
-        mActiveUsageSessions.erase(it);
+    if (it == mActiveUsageSessions.end()) {
+        return;
     }
+    // Report usage for that device if there is an active session.
+    auto& [_, activeSession] = *it;
+    mLogger.logInputDeviceUsageReported(identifier, activeSession.finishSession());
+    mActiveUsageSessions.erase(it);
+
     // We don't remove this from mLoggedDeviceInfos because it will be updated in
     // onInputDevicesChanged().
 }
 
-void InputDeviceMetricsCollector::onInputDeviceUsage(DeviceId deviceId, nanoseconds eventTime) {
-    if (mLoggedDeviceInfos.count(deviceId) == 0) {
+void InputDeviceMetricsCollector::onInputDeviceUsage(DeviceId deviceId, nanoseconds eventTime,
+                                                     const SourceProvider& getSources) {
+    auto infoIt = mLoggedDeviceInfos.find(deviceId);
+    if (infoIt == mLoggedDeviceInfos.end()) {
         // Do not track usage for devices that are not logged.
         return;
     }
 
-    auto [it, inserted] = mActiveUsageSessions.try_emplace(deviceId, eventTime, eventTime);
-    if (!inserted) {
-        it->second.end = eventTime;
+    auto [sessionIt, _] =
+            mActiveUsageSessions.try_emplace(deviceId, mUsageSessionTimeout, eventTime);
+    for (InputDeviceUsageSource source : getSources(infoIt->second)) {
+        sessionIt->second.recordUsage(eventTime, source);
     }
 }
 
-void InputDeviceMetricsCollector::processUsages() {
-    const auto usageSessionExpiryTime = mLogger.getCurrentTime() - mUsageSessionTimeout;
+void InputDeviceMetricsCollector::reportCompletedSessions() {
+    const auto currentTime = mLogger.getCurrentTime();
 
     std::vector<DeviceId> completedUsageSessions;
 
-    for (const auto& [deviceId, usageSession] : mActiveUsageSessions) {
-        if (usageSession.end <= usageSessionExpiryTime) {
+    for (auto& [deviceId, activeSession] : mActiveUsageSessions) {
+        if (activeSession.checkIfCompletedAt(currentTime)) {
             completedUsageSessions.emplace_back(deviceId);
         }
     }
 
     for (DeviceId deviceId : completedUsageSessions) {
-        const auto it = mLoggedDeviceInfos.find(deviceId);
-        LOG_ALWAYS_FATAL_IF(it == mLoggedDeviceInfos.end());
+        const auto infoIt = mLoggedDeviceInfos.find(deviceId);
+        LOG_ALWAYS_FATAL_IF(infoIt == mLoggedDeviceInfos.end());
 
-        const auto& session = mActiveUsageSessions[deviceId];
-        mLogger.logInputDeviceUsageReported(it->second, session.end - session.start);
-
-        mActiveUsageSessions.erase(deviceId);
+        auto activeSessionIt = mActiveUsageSessions.find(deviceId);
+        LOG_ALWAYS_FATAL_IF(activeSessionIt == mActiveUsageSessions.end());
+        auto& [_, activeSession] = *activeSessionIt;
+        mLogger.logInputDeviceUsageReported(infoIt->second.getIdentifier(),
+                                            activeSession.finishSession());
+        mActiveUsageSessions.erase(activeSessionIt);
     }
 }
 
+// --- InputDeviceMetricsCollector::ActiveSession ---
+
+InputDeviceMetricsCollector::ActiveSession::ActiveSession(nanoseconds usageSessionTimeout,
+                                                          nanoseconds startTime)
+      : mUsageSessionTimeout(usageSessionTimeout), mDeviceSession({startTime, startTime}) {}
+
+void InputDeviceMetricsCollector::ActiveSession::recordUsage(nanoseconds eventTime,
+                                                             InputDeviceUsageSource source) {
+    // We assume that event times for subsequent events are always monotonically increasing for each
+    // input device.
+    auto [activeSourceIt, inserted] =
+            mActiveSessionsBySource.try_emplace(source, eventTime, eventTime);
+    if (!inserted) {
+        activeSourceIt->second.end = eventTime;
+    }
+    mDeviceSession.end = eventTime;
+}
+
+bool InputDeviceMetricsCollector::ActiveSession::checkIfCompletedAt(nanoseconds timestamp) {
+    const auto sessionExpiryTime = timestamp - mUsageSessionTimeout;
+    std::vector<InputDeviceUsageSource> completedSourceSessionsForDevice;
+    for (auto& [source, session] : mActiveSessionsBySource) {
+        if (session.end <= sessionExpiryTime) {
+            completedSourceSessionsForDevice.emplace_back(source);
+        }
+    }
+    for (InputDeviceUsageSource source : completedSourceSessionsForDevice) {
+        auto it = mActiveSessionsBySource.find(source);
+        const auto& [_, session] = *it;
+        mSourceUsageBreakdown.emplace_back(source, session.end - session.start);
+        mActiveSessionsBySource.erase(it);
+    }
+    return mActiveSessionsBySource.empty();
+}
+
+InputDeviceMetricsLogger::DeviceUsageReport
+InputDeviceMetricsCollector::ActiveSession::finishSession() {
+    const auto deviceUsageDuration = mDeviceSession.end - mDeviceSession.start;
+
+    for (const auto& [source, sourceSession] : mActiveSessionsBySource) {
+        mSourceUsageBreakdown.emplace_back(source, sourceSession.end - sourceSession.start);
+    }
+    mActiveSessionsBySource.clear();
+
+    return {deviceUsageDuration, mSourceUsageBreakdown};
+}
+
 } // namespace android
diff --git a/services/inputflinger/InputDeviceMetricsCollector.h b/services/inputflinger/InputDeviceMetricsCollector.h
index 4ef860f..e2e79e4 100644
--- a/services/inputflinger/InputDeviceMetricsCollector.h
+++ b/services/inputflinger/InputDeviceMetricsCollector.h
@@ -23,6 +23,7 @@
 #include <input/InputDevice.h>
 #include <statslog.h>
 #include <chrono>
+#include <functional>
 #include <map>
 #include <set>
 #include <vector>
@@ -79,8 +80,25 @@
 class InputDeviceMetricsLogger {
 public:
     virtual std::chrono::nanoseconds getCurrentTime() = 0;
+
+    // Describes the breakdown of an input device usage session by its usage sources.
+    // An input device can have more than one usage source. For example, some game controllers have
+    // buttons, joysticks, and touchpads. We track usage by these sources to get a better picture of
+    // the device usage. The source breakdown of a 10 minute usage session could look like this:
+    //   { {GAMEPAD, <9 mins>}, {TOUCHPAD, <2 mins>}, {TOUCHPAD, <3 mins>} }
+    // This would indicate that the GAMEPAD source was used first, and that source usage session
+    // lasted for 9 mins. During that time, the TOUCHPAD was used for 2 mins, until its source
+    // usage session expired. The TOUCHPAD was then used again later for another 3 mins.
+    using SourceUsageBreakdown =
+            std::vector<std::pair<InputDeviceUsageSource, std::chrono::nanoseconds /*duration*/>>;
+
+    struct DeviceUsageReport {
+        std::chrono::nanoseconds usageDuration;
+        SourceUsageBreakdown sourceBreakdown;
+    };
+
     virtual void logInputDeviceUsageReported(const InputDeviceIdentifier&,
-                                             std::chrono::nanoseconds duration) = 0;
+                                             const DeviceUsageReport&) = 0;
     virtual ~InputDeviceMetricsLogger() = default;
 };
 
@@ -116,23 +134,42 @@
                       ftl::Orderable<DeviceId> {
         using Constructible::Constructible;
     };
-    static std::string toString(const DeviceId& id) {
+    static inline std::string toString(const DeviceId& id) {
         return std::to_string(ftl::to_underlying(id));
     }
 
-    std::map<DeviceId, InputDeviceIdentifier> mLoggedDeviceInfos;
+    std::map<DeviceId, InputDeviceInfo> mLoggedDeviceInfos;
 
-    struct UsageSession {
-        std::chrono::nanoseconds start;
-        std::chrono::nanoseconds end;
+    class ActiveSession {
+    public:
+        explicit ActiveSession(std::chrono::nanoseconds usageSessionTimeout,
+                               std::chrono::nanoseconds startTime);
+        void recordUsage(std::chrono::nanoseconds eventTime, InputDeviceUsageSource source);
+        bool checkIfCompletedAt(std::chrono::nanoseconds timestamp);
+        InputDeviceMetricsLogger::DeviceUsageReport finishSession();
+
+    private:
+        struct UsageSession {
+            std::chrono::nanoseconds start{};
+            std::chrono::nanoseconds end{};
+        };
+
+        const std::chrono::nanoseconds mUsageSessionTimeout;
+        UsageSession mDeviceSession{};
+
+        std::map<InputDeviceUsageSource, UsageSession> mActiveSessionsBySource{};
+        InputDeviceMetricsLogger::SourceUsageBreakdown mSourceUsageBreakdown{};
     };
+
     // The input devices that currently have active usage sessions.
-    std::map<DeviceId, UsageSession> mActiveUsageSessions;
+    std::map<DeviceId, ActiveSession> mActiveUsageSessions;
 
     void onInputDevicesChanged(const std::vector<InputDeviceInfo>& infos);
     void onInputDeviceRemoved(DeviceId deviceId, const InputDeviceIdentifier& identifier);
-    void onInputDeviceUsage(DeviceId deviceId, std::chrono::nanoseconds eventTime);
-    void processUsages();
+    using SourceProvider = std::function<std::set<InputDeviceUsageSource>(const InputDeviceInfo&)>;
+    void onInputDeviceUsage(DeviceId deviceId, std::chrono::nanoseconds eventTime,
+                            const SourceProvider& getSources);
+    void reportCompletedSessions();
 };
 
 } // namespace android
diff --git a/services/inputflinger/tests/InputDeviceMetricsCollector_test.cpp b/services/inputflinger/tests/InputDeviceMetricsCollector_test.cpp
index 85529bd..e38f88c 100644
--- a/services/inputflinger/tests/InputDeviceMetricsCollector_test.cpp
+++ b/services/inputflinger/tests/InputDeviceMetricsCollector_test.cpp
@@ -336,16 +336,20 @@
 
 // --- InputDeviceMetricsCollectorTest ---
 
-class InputDeviceMetricsCollectorTest : public testing::Test, InputDeviceMetricsLogger {
+class InputDeviceMetricsCollectorTest : public testing::Test, public InputDeviceMetricsLogger {
 protected:
     TestInputListener mTestListener;
     InputDeviceMetricsCollector mMetricsCollector{mTestListener, *this, USAGE_TIMEOUT};
 
-    void assertUsageLogged(InputDeviceIdentifier identifier, nanoseconds duration) {
+    void assertUsageLogged(const InputDeviceIdentifier& identifier, nanoseconds duration,
+                           std::optional<SourceUsageBreakdown> sourceBreakdown = {}) {
         ASSERT_GE(mLoggedUsageSessions.size(), 1u);
-        const auto& session = *mLoggedUsageSessions.begin();
-        ASSERT_EQ(identifier, std::get<InputDeviceIdentifier>(session));
-        ASSERT_EQ(duration, std::get<nanoseconds>(session));
+        const auto& [loggedIdentifier, report] = *mLoggedUsageSessions.begin();
+        ASSERT_EQ(identifier, loggedIdentifier);
+        ASSERT_EQ(duration, report.usageDuration);
+        if (sourceBreakdown) {
+            ASSERT_EQ(sourceBreakdown, report.sourceBreakdown);
+        }
         mLoggedUsageSessions.erase(mLoggedUsageSessions.begin());
     }
 
@@ -367,14 +371,14 @@
     }
 
 private:
-    std::vector<std::tuple<InputDeviceIdentifier, nanoseconds>> mLoggedUsageSessions;
+    std::vector<std::tuple<InputDeviceIdentifier, DeviceUsageReport>> mLoggedUsageSessions;
     nanoseconds mCurrentTime{TIME};
 
     nanoseconds getCurrentTime() override { return mCurrentTime; }
 
     void logInputDeviceUsageReported(const InputDeviceIdentifier& identifier,
-                                     nanoseconds duration) override {
-        mLoggedUsageSessions.emplace_back(identifier, duration);
+                                     const DeviceUsageReport& report) override {
+        mLoggedUsageSessions.emplace_back(identifier, report);
     }
 };
 
@@ -509,4 +513,113 @@
     ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
 }
 
+TEST_F(InputDeviceMetricsCollectorTest, BreakdownUsageBySource) {
+    mMetricsCollector.notifyInputDevicesChanged({/*id=*/0, {generateTestDeviceInfo()}});
+    InputDeviceMetricsLogger::SourceUsageBreakdown expectedSourceBreakdown;
+
+    // Use touchscreen.
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN));
+    setCurrentTime(TIME + 100ns);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN));
+    ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
+
+    // Use a stylus with the same input device.
+    setCurrentTime(TIME + 200ns);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, STYLUS, {ToolType::STYLUS}));
+    setCurrentTime(TIME + 400ns);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, STYLUS, {ToolType::STYLUS}));
+    ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
+
+    // Touchscreen was used again after its usage timeout expired.
+    // This should be tracked as a separate usage of the source in the breakdown.
+    setCurrentTime(TIME + 300ns + USAGE_TIMEOUT);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID));
+    expectedSourceBreakdown.emplace_back(InputDeviceUsageSource::TOUCHSCREEN, 100ns);
+    ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
+
+    // Continue stylus and touchscreen usages.
+    setCurrentTime(TIME + 350ns + USAGE_TIMEOUT);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, STYLUS, {ToolType::STYLUS}));
+    setCurrentTime(TIME + 450ns + USAGE_TIMEOUT);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN));
+    ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
+
+    // Touchscreen was used after the stylus's usage timeout expired.
+    // The stylus usage should be tracked in the source breakdown.
+    setCurrentTime(TIME + 400ns + USAGE_TIMEOUT + USAGE_TIMEOUT);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN));
+    expectedSourceBreakdown.emplace_back(InputDeviceUsageSource::STYLUS_DIRECT,
+                                         150ns + USAGE_TIMEOUT);
+    ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
+
+    // Remove all devices to force the usage session to be logged.
+    setCurrentTime(TIME + 500ns + USAGE_TIMEOUT);
+    mMetricsCollector.notifyInputDevicesChanged({});
+    expectedSourceBreakdown.emplace_back(InputDeviceUsageSource::TOUCHSCREEN,
+                                         100ns + USAGE_TIMEOUT);
+    // Verify that only one usage session was logged for the device, and that session was broken
+    // down by source correctly.
+    ASSERT_NO_FATAL_FAILURE(assertUsageLogged(getIdentifier(),
+                                              400ns + USAGE_TIMEOUT + USAGE_TIMEOUT,
+                                              expectedSourceBreakdown));
+
+    ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
+}
+
+TEST_F(InputDeviceMetricsCollectorTest, BreakdownUsageBySource_TrackSourceByDevice) {
+    mMetricsCollector.notifyInputDevicesChanged(
+            {/*id=*/0, {generateTestDeviceInfo(DEVICE_ID), generateTestDeviceInfo(DEVICE_ID_2)}});
+    InputDeviceMetricsLogger::SourceUsageBreakdown expectedSourceBreakdown1;
+    InputDeviceMetricsLogger::SourceUsageBreakdown expectedSourceBreakdown2;
+
+    // Use both devices, with different sources.
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN));
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID_2, STYLUS, {ToolType::STYLUS}));
+    setCurrentTime(TIME + 100ns);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN));
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID_2, STYLUS, {ToolType::STYLUS}));
+    ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
+
+    // Remove all devices to force the usage session to be logged.
+    mMetricsCollector.notifyInputDevicesChanged({});
+    expectedSourceBreakdown1.emplace_back(InputDeviceUsageSource::TOUCHSCREEN, 100ns);
+    expectedSourceBreakdown2.emplace_back(InputDeviceUsageSource::STYLUS_DIRECT, 100ns);
+    ASSERT_NO_FATAL_FAILURE(
+            assertUsageLogged(getIdentifier(DEVICE_ID), 100ns, expectedSourceBreakdown1));
+    ASSERT_NO_FATAL_FAILURE(
+            assertUsageLogged(getIdentifier(DEVICE_ID_2), 100ns, expectedSourceBreakdown2));
+
+    ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
+}
+
+TEST_F(InputDeviceMetricsCollectorTest, BreakdownUsageBySource_MultiSourceEvent) {
+    mMetricsCollector.notifyInputDevicesChanged({/*id=*/0, {generateTestDeviceInfo(DEVICE_ID)}});
+    InputDeviceMetricsLogger::SourceUsageBreakdown expectedSourceBreakdown;
+
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN | STYLUS, //
+                                                      {ToolType::STYLUS}));
+    setCurrentTime(TIME + 100ns);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN | STYLUS, //
+                                                      {ToolType::STYLUS, ToolType::FINGER}));
+    setCurrentTime(TIME + 200ns);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN | STYLUS, //
+                                                      {ToolType::STYLUS, ToolType::FINGER}));
+    setCurrentTime(TIME + 300ns);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN | STYLUS, //
+                                                      {ToolType::FINGER}));
+    setCurrentTime(TIME + 400ns);
+    mMetricsCollector.notifyMotion(generateMotionArgs(DEVICE_ID, TOUCHSCREEN | STYLUS, //
+                                                      {ToolType::FINGER}));
+    ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
+
+    // Remove all devices to force the usage session to be logged.
+    mMetricsCollector.notifyInputDevicesChanged({});
+    expectedSourceBreakdown.emplace_back(InputDeviceUsageSource::STYLUS_DIRECT, 200ns);
+    expectedSourceBreakdown.emplace_back(InputDeviceUsageSource::TOUCHSCREEN, 300ns);
+    ASSERT_NO_FATAL_FAILURE(
+            assertUsageLogged(getIdentifier(DEVICE_ID), 400ns, expectedSourceBreakdown));
+
+    ASSERT_NO_FATAL_FAILURE(assertUsageNotLogged());
+}
+
 } // namespace android