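Add reset() and setEnabled() to MultiStateCounter

setEnabled(false, timestamp) stops accruing time-in-state, so later value
deltas are not attributed to any state; time accrued before disabling is
still paid out by the next updateValue(). reset() clears all accumulated
counts and forgets the previous update/state-change timestamps.
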
Bug: 197162116
Test: atest libbattery_test
Change-Id: I14006af5e6a3e16d7849beb6def53de77b15b7bc
diff --git a/libs/battery/MultiStateCounter.h b/libs/battery/MultiStateCounter.h
index 40de068..e1ee07c 100644
--- a/libs/battery/MultiStateCounter.h
+++ b/libs/battery/MultiStateCounter.h
@@ -42,6 +42,7 @@
     T lastValue;
     time_t lastUpdateTimestamp;
     T deltaValue;
+    bool isEnabled;  // When false, setState() accrues no time-in-state; see setEnabled().
 
     struct State {
         time_t timeInStateSinceUpdate;
@@ -55,12 +56,16 @@
 
     virtual ~MultiStateCounter();
 
+    void setEnabled(bool enabled, time_t timestamp);  // Stops/resumes time-in-state accrual.
+
     void setState(state_t state, time_t timestamp);
 
     void setValue(state_t state, const T& value);
 
     void updateValue(const T& value, time_t timestamp);
 
+    void reset();  // Clears all counts and the previous update/state-change timestamps.
+
     uint16_t getStateCount();
 
     const T& getCount(state_t state);
@@ -96,7 +101,8 @@
         emptyValue(emptyValue),
         lastValue(emptyValue),
         lastUpdateTimestamp(-1),
-        deltaValue(emptyValue) {
+        deltaValue(emptyValue),
+        isEnabled(true) {  // New counters start out enabled.
     states = new State[stateCount];
     for (int i = 0; i < stateCount; i++) {
         states[i].timeInStateSinceUpdate = 0;
@@ -110,8 +116,29 @@
 };
 
 template <class T>
-void MultiStateCounter<T>::setState(state_t state, time_t timestamp) {
-    if (lastStateChangeTimestamp >= 0) {
+void MultiStateCounter<T>::setEnabled(bool enabled, time_t timestamp) {
+    // Re-applying the current setting is a no-op.
+    if (enabled == isEnabled) {
+        return;
+    }
+
+    // Switching off: settle the time spent in the current state up to `timestamp` while
+    // accrual is still on, since setState() adds nothing once isEnabled is false.
+    if (isEnabled) {
+        setState(currentState, timestamp);
+    }
+    isEnabled = enabled;
+
+    // Re-anchor the state-change clock at `timestamp` (if it has been started), so that after
+    // re-enabling, the interval spent disabled is not credited to the current state.
+    if (lastStateChangeTimestamp >= 0) {
+        lastStateChangeTimestamp = timestamp;
+    }
+}
+
+template <class T>
+void MultiStateCounter<T>::setState(state_t state, time_t timestamp) {
+    if (isEnabled && lastStateChangeTimestamp >= 0) {
         if (timestamp >= lastStateChangeTimestamp) {
             states[currentState].timeInStateSinceUpdate += timestamp - lastStateChangeTimestamp;
         } else {
@@ -137,31 +162,35 @@
 
 template <class T>
 void MultiStateCounter<T>::updateValue(const T& value, time_t timestamp) {
-    // Confirm the current state for the side-effect of updating the time-in-state
-    // counter for the current state.
-    setState(currentState, timestamp);
+    // Disabled counters skip this block, unless lastStateChangeTimestamp moved past the previous
+    // update (e.g. setEnabled(false)): time-in-state accrued before that still needs its delta.
+    if (isEnabled || lastUpdateTimestamp < lastStateChangeTimestamp) {
+        // Confirm the current state for the side-effect of updating the time-in-state
+        // counter for the current state.
+        setState(currentState, timestamp);
 
-    if (lastUpdateTimestamp >= 0) {
-        if (timestamp > lastUpdateTimestamp) {
-            if (delta(lastValue, value, &deltaValue)) {
-                time_t timeSinceUpdate = timestamp - lastUpdateTimestamp;
-                for (int i = 0; i < stateCount; i++) {
-                    time_t timeInState = states[i].timeInStateSinceUpdate;
-                    if (timeInState) {
-                        add(&states[i].counter, deltaValue, timeInState, timeSinceUpdate);
-                        states[i].timeInStateSinceUpdate = 0;
+        if (lastUpdateTimestamp >= 0) {  // -1 until the first update (or after reset())
+            if (timestamp > lastUpdateTimestamp) {
+                if (delta(lastValue, value, &deltaValue)) {
+                    time_t timeSinceUpdate = timestamp - lastUpdateTimestamp;
+                    for (int i = 0; i < stateCount; i++) {
+                        time_t timeInState = states[i].timeInStateSinceUpdate;
+                        if (timeInState) {
+                            add(&states[i].counter, deltaValue, timeInState, timeSinceUpdate);
+                            states[i].timeInStateSinceUpdate = 0;
+                        }
                     }
+                } else {
+                    std::stringstream str;
+                    str << "updateValue is called with a value " << valueToString(value)
+                        << ", which is lower than the previous value " << valueToString(lastValue)
+                        << "\n";
+                    ALOGE("%s", str.str().c_str());
                 }
-            } else {
-                std::stringstream str;
-                str << "updateValue is called with a value " << valueToString(value)
-                    << ", which is lower than the previous value " << valueToString(lastValue)
-                    << "\n";
-                ALOGE("%s", str.str().c_str());
+            } else if (timestamp < lastUpdateTimestamp) {
+                ALOGE("updateValue is called with an earlier timestamp: %lu, previous: %lu\n",
+                      (unsigned long)timestamp, (unsigned long)lastUpdateTimestamp);
             }
-        } else if (timestamp < lastUpdateTimestamp) {
-            ALOGE("updateValue is called with an earlier timestamp: %lu, previous timestamp: %lu\n",
-                  (unsigned long)timestamp, (unsigned long)lastUpdateTimestamp);
         }
     }
     lastValue = value;
@@ -169,6 +198,20 @@
 }
 
+// Clears every per-state count and any pending time-in-state, and sets lastUpdateTimestamp and
+// lastStateChangeTimestamp back to -1, so the next updateValue() only records a new baseline
+// value. The enabled flag and the current state are not changed.
+template <class T>
+void MultiStateCounter<T>::reset() {
+    lastUpdateTimestamp = -1;
+    lastStateChangeTimestamp = -1;
+    for (int i = 0; i < stateCount; i++) {
+        State& entry = states[i];
+        entry.counter = emptyValue;
+        entry.timeInStateSinceUpdate = 0;
+    }
+}
+
 template <class T>
 uint16_t MultiStateCounter<T>::getStateCount() {
     return stateCount;
 }
diff --git a/libs/battery/MultiStateCounterTest.cpp b/libs/battery/MultiStateCounterTest.cpp
index 87c80c5..319ba76 100644
--- a/libs/battery/MultiStateCounterTest.cpp
+++ b/libs/battery/MultiStateCounterTest.cpp
@@ -71,6 +71,83 @@
     EXPECT_DOUBLE_EQ(4.0, testCounter.getCount(2));
 }
 
+TEST_F(MultiStateCounterTest, setEnabled) {
+    DoubleMultiStateCounter testCounter(3, 0);
+    testCounter.updateValue(0, 0);
+    testCounter.setState(1, 0);
+    testCounter.setEnabled(false, 1000);
+    testCounter.setState(2, 2000);
+    testCounter.updateValue(6.0, 3000);
+
+    // In state 1: accumulated 1000 before disabled, that's 6.0 * 1000/3000 = 2.0
+    // In state 2: 0, since it is still disabled
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(0));
+    EXPECT_DOUBLE_EQ(2.0, testCounter.getCount(1));
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(2));
+
+    // Should have no effect since the counter is disabled
+    testCounter.setState(0, 3500);
+
+    // Should have no effect since the counter is disabled
+    testCounter.updateValue(10.0, 4000);
+
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(0));
+    EXPECT_DOUBLE_EQ(2.0, testCounter.getCount(1));
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(2));
+
+    testCounter.setState(2, 4500);
+
+    // Enable the counter to partially accumulate deltas for the current state, 2
+    testCounter.setEnabled(true, 5000);
+    testCounter.setEnabled(false, 6000);
+    testCounter.setEnabled(true, 7000);
+    testCounter.updateValue(20.0, 8000);
+
+    // The delta is 10.0 over 8000-4000=4000 (since the previous update at 4000).
+    // Counter has been enabled in state 2 for (6000-5000)+(8000-7000) = 2000,
+    // so its share is (20.0-10.0) * 2000/(8000-4000) = 5.0
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(0));
+    EXPECT_DOUBLE_EQ(2.0, testCounter.getCount(1));
+    EXPECT_DOUBLE_EQ(5.0, testCounter.getCount(2));
+
+    testCounter.reset();
+    testCounter.setState(0, 0);
+    testCounter.updateValue(0, 0);
+    testCounter.setState(1, 2000);
+    testCounter.setEnabled(false, 3000);
+    testCounter.updateValue(200, 5000);
+
+    // 200 over 5000 = 40 per second
+    // Counter was in state 0 from 0 to 2000, so 2 sec, so the count should be 40 * 2 = 80
+    // It stayed in state 1 from 2000 to 3000, at which point the counter was disabled,
+    // so the count for state 1 should be 40 * 1 = 40.
+    // The remaining 2 seconds from 3000 to 5000 don't count because the counter was disabled.
+    EXPECT_DOUBLE_EQ(80.0, testCounter.getCount(0));
+    EXPECT_DOUBLE_EQ(40.0, testCounter.getCount(1));
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(2));
+}
+
+TEST_F(MultiStateCounterTest, reset) {
+    DoubleMultiStateCounter testCounter(3, 0);
+    testCounter.updateValue(0, 0);
+    testCounter.setState(1, 0);
+    testCounter.updateValue(2.72, 3000);  // state 1 for the whole 0..3000 interval
+
+    testCounter.reset();
+
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(0));
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(1));
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(2));
+
+    // Accumulation continues after a reset; reset() does not change the current state (1)
+    testCounter.updateValue(0, 4000);  // first update after reset() only sets the baseline
+    testCounter.updateValue(3.14, 5000);
+
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(0));
+    EXPECT_DOUBLE_EQ(3.14, testCounter.getCount(1));
+    EXPECT_DOUBLE_EQ(0, testCounter.getCount(2));
+}
+
 TEST_F(MultiStateCounterTest, timeAdjustment_setState) {
     DoubleMultiStateCounter testCounter(3, 0);
     testCounter.updateValue(0, 0);