Fix One Euro filter's units of computation
Changed the units that the One Euro filter uses to compute the filtered
coordinates. This was causing a crash because if two timestamps were
sufficiently close to each other, by the time of implicitly converting
from nanoseconds to seconds, they were considered equal. This led to a
zero division when calculating the sampling frequency. Now, everything
is handled in the scale of nanoseconds, and conversion are done if and
only if they're necessary.
Bug: 297226446
Flag: EXEMPT bugfix
Test: TEST=libinput_tests; m $TEST && $ANDROID_HOST_OUT/nativetest64/$TEST/$TEST
Change-Id: I7fced6db447074cccb3d938eb9dc7a9707433f53
diff --git a/include/input/CoordinateFilter.h b/include/input/CoordinateFilter.h
index f36472d..8f2e605 100644
--- a/include/input/CoordinateFilter.h
+++ b/include/input/CoordinateFilter.h
@@ -44,7 +44,7 @@
* the previous call.
* @param coords Coordinates to be overwritten by the corresponding filtered coordinates.
*/
- void filter(std::chrono::duration<float> timestamp, PointerCoords& coords);
+ void filter(std::chrono::nanoseconds timestamp, PointerCoords& coords);
private:
OneEuroFilter mXFilter;
diff --git a/include/input/OneEuroFilter.h b/include/input/OneEuroFilter.h
index a0168e4..bdd82b2 100644
--- a/include/input/OneEuroFilter.h
+++ b/include/input/OneEuroFilter.h
@@ -56,7 +56,7 @@
* provided in the previous call.
* @param rawPosition Position to be filtered.
*/
- float filter(std::chrono::duration<float> timestamp, float rawPosition);
+ float filter(std::chrono::nanoseconds timestamp, float rawPosition);
private:
/**
@@ -67,7 +67,7 @@
/**
* Slope of the cutoff frequency criterion. This is the term scaling the absolute value of the
- * filtered signal's speed. The data member is dimensionless, that is, it does not have units.
+ * filtered signal's speed. Units are 1 / position.
*/
const float mBeta;
@@ -78,9 +78,9 @@
const float mSpeedCutoffFreq;
/**
- * The timestamp from the previous call. Units are seconds.
+ * The timestamp from the previous call.
*/
- std::optional<std::chrono::duration<float>> mPrevTimestamp;
+ std::optional<std::chrono::nanoseconds> mPrevTimestamp;
/**
* The raw position from the previous call.
@@ -88,7 +88,7 @@
std::optional<float> mPrevRawPosition;
/**
- * The filtered velocity from the previous call. Units are position per second.
+ * The filtered velocity from the previous call. Units are position per nanosecond.
*/
std::optional<float> mPrevFilteredVelocity;
diff --git a/libs/input/CoordinateFilter.cpp b/libs/input/CoordinateFilter.cpp
index d231474..a32685b 100644
--- a/libs/input/CoordinateFilter.cpp
+++ b/libs/input/CoordinateFilter.cpp
@@ -23,7 +23,7 @@
CoordinateFilter::CoordinateFilter(float minCutoffFreq, float beta)
: mXFilter{minCutoffFreq, beta}, mYFilter{minCutoffFreq, beta} {}
-void CoordinateFilter::filter(std::chrono::duration<float> timestamp, PointerCoords& coords) {
+void CoordinateFilter::filter(std::chrono::nanoseconds timestamp, PointerCoords& coords) {
coords.setAxisValue(AMOTION_EVENT_AXIS_X, mXFilter.filter(timestamp, coords.getX()));
coords.setAxisValue(AMOTION_EVENT_AXIS_Y, mYFilter.filter(timestamp, coords.getY()));
}
diff --git a/libs/input/OneEuroFilter.cpp b/libs/input/OneEuroFilter.cpp
index 400d7c9..7b0d104 100644
--- a/libs/input/OneEuroFilter.cpp
+++ b/libs/input/OneEuroFilter.cpp
@@ -25,16 +25,24 @@
namespace android {
namespace {
+using namespace std::literals::chrono_literals;
+
+const float kHertzPerGigahertz = 1E9f;
+const float kGigahertzPerHertz = 1E-9f;
+
+// filteredSpeed's units are position per nanosecond. beta's units are 1 / position.
inline float cutoffFreq(float minCutoffFreq, float beta, float filteredSpeed) {
- return minCutoffFreq + beta * std::abs(filteredSpeed);
+ return kHertzPerGigahertz *
+ ((minCutoffFreq * kGigahertzPerHertz) + beta * std::abs(filteredSpeed));
}
-inline float smoothingFactor(std::chrono::duration<float> samplingPeriod, float cutoffFreq) {
- return samplingPeriod.count() / (samplingPeriod.count() + (1.0 / (2.0 * M_PI * cutoffFreq)));
+inline float smoothingFactor(std::chrono::nanoseconds samplingPeriod, float cutoffFreq) {
+ const float constant = 2.0f * M_PI * samplingPeriod.count() * (cutoffFreq * kGigahertzPerHertz);
+ return constant / (constant + 1);
}
-inline float lowPassFilter(float rawPosition, float prevFilteredPosition, float smoothingFactor) {
- return smoothingFactor * rawPosition + (1 - smoothingFactor) * prevFilteredPosition;
+inline float lowPassFilter(float rawValue, float prevFilteredValue, float smoothingFactor) {
+ return smoothingFactor * rawValue + (1 - smoothingFactor) * prevFilteredValue;
}
} // namespace
@@ -42,17 +50,17 @@
OneEuroFilter::OneEuroFilter(float minCutoffFreq, float beta, float speedCutoffFreq)
: mMinCutoffFreq{minCutoffFreq}, mBeta{beta}, mSpeedCutoffFreq{speedCutoffFreq} {}
-float OneEuroFilter::filter(std::chrono::duration<float> timestamp, float rawPosition) {
- LOG_IF(FATAL, mPrevFilteredPosition.has_value() && (timestamp <= *mPrevTimestamp))
- << "Timestamp must be greater than mPrevTimestamp";
+float OneEuroFilter::filter(std::chrono::nanoseconds timestamp, float rawPosition) {
+ LOG_IF(FATAL, mPrevTimestamp.has_value() && (*mPrevTimestamp >= timestamp))
+ << "Timestamp must be greater than mPrevTimestamp. Timestamp: " << timestamp.count()
+ << "ns. mPrevTimestamp: " << mPrevTimestamp->count() << "ns";
- const std::chrono::duration<float> samplingPeriod = (mPrevTimestamp.has_value())
- ? (timestamp - *mPrevTimestamp)
- : std::chrono::duration<float>{1.0};
+ const std::chrono::nanoseconds samplingPeriod =
+ (mPrevTimestamp.has_value()) ? (timestamp - *mPrevTimestamp) : 1s;
const float rawVelocity = (mPrevFilteredPosition.has_value())
- ? ((rawPosition - *mPrevFilteredPosition) / samplingPeriod.count())
- : 0.0;
+ ? ((rawPosition - *mPrevFilteredPosition) / (samplingPeriod.count()))
+ : 0.0f;
const float speedSmoothingFactor = smoothingFactor(samplingPeriod, mSpeedCutoffFreq);
diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp
index 46e8190..d1c564d 100644
--- a/libs/input/tests/Android.bp
+++ b/libs/input/tests/Android.bp
@@ -17,6 +17,7 @@
"IdGenerator_test.cpp",
"InputChannel_test.cpp",
"InputConsumer_test.cpp",
+ "InputConsumerFilteredResampling_test.cpp",
"InputConsumerResampling_test.cpp",
"InputDevice_test.cpp",
"InputEvent_test.cpp",
diff --git a/libs/input/tests/InputConsumerFilteredResampling_test.cpp b/libs/input/tests/InputConsumerFilteredResampling_test.cpp
new file mode 100644
index 0000000..757cd18
--- /dev/null
+++ b/libs/input/tests/InputConsumerFilteredResampling_test.cpp
@@ -0,0 +1,218 @@
+/**
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <input/InputConsumerNoResampling.h>
+
+#include <chrono>
+#include <iostream>
+#include <memory>
+#include <queue>
+
+#include <TestEventMatchers.h>
+#include <TestInputChannel.h>
+#include <android-base/logging.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <input/Input.h>
+#include <input/InputEventBuilders.h>
+#include <input/Resampler.h>
+#include <utils/Looper.h>
+#include <utils/StrongPointer.h>
+
+namespace android {
+namespace {
+
+using std::chrono::nanoseconds;
+
+using ::testing::AllOf;
+using ::testing::Matcher;
+
+const int32_t ACTION_DOWN = AMOTION_EVENT_ACTION_DOWN;
+const int32_t ACTION_MOVE = AMOTION_EVENT_ACTION_MOVE;
+
+struct Pointer {
+ int32_t id{0};
+ ToolType toolType{ToolType::FINGER};
+ float x{0.0f};
+ float y{0.0f};
+ bool isResampled{false};
+
+ PointerBuilder asPointerBuilder() const {
+ return PointerBuilder{id, toolType}.x(x).y(y).isResampled(isResampled);
+ }
+};
+
+} // namespace
+
+class InputConsumerFilteredResamplingTest : public ::testing::Test, public InputConsumerCallbacks {
+protected:
+ InputConsumerFilteredResamplingTest()
+ : mClientTestChannel{std::make_shared<TestInputChannel>("TestChannel")},
+ mLooper{sp<Looper>::make(/*allowNonCallbacks=*/false)} {
+ Looper::setForThread(mLooper);
+ mConsumer = std::make_unique<
+ InputConsumerNoResampling>(mClientTestChannel, mLooper, *this, []() {
+ return std::make_unique<FilteredLegacyResampler>(/*minCutoffFreq=*/4.7, /*beta=*/0.01);
+ });
+ }
+
+ void invokeLooperCallback() const {
+ sp<LooperCallback> callback;
+ ASSERT_TRUE(mLooper->getFdStateDebug(mClientTestChannel->getFd(), /*ident=*/nullptr,
+ /*events=*/nullptr, &callback, /*data=*/nullptr));
+ ASSERT_NE(callback, nullptr);
+ callback->handleEvent(mClientTestChannel->getFd(), ALOOPER_EVENT_INPUT, /*data=*/nullptr);
+ }
+
+ void assertOnBatchedInputEventPendingWasCalled() {
+ ASSERT_GT(mOnBatchedInputEventPendingInvocationCount, 0UL)
+ << "onBatchedInputEventPending was not called";
+ --mOnBatchedInputEventPendingInvocationCount;
+ }
+
+ void assertReceivedMotionEvent(const Matcher<MotionEvent>& matcher) {
+ ASSERT_TRUE(!mMotionEvents.empty()) << "No motion events were received";
+ std::unique_ptr<MotionEvent> motionEvent = std::move(mMotionEvents.front());
+ mMotionEvents.pop();
+ ASSERT_NE(motionEvent, nullptr) << "The consumed motion event must not be nullptr";
+ EXPECT_THAT(*motionEvent, matcher);
+ }
+
+ InputMessage nextPointerMessage(nanoseconds eventTime, int32_t action, const Pointer& pointer);
+
+ std::shared_ptr<TestInputChannel> mClientTestChannel;
+ sp<Looper> mLooper;
+ std::unique_ptr<InputConsumerNoResampling> mConsumer;
+
+ // Batched input events
+ std::queue<std::unique_ptr<KeyEvent>> mKeyEvents;
+ std::queue<std::unique_ptr<MotionEvent>> mMotionEvents;
+ std::queue<std::unique_ptr<FocusEvent>> mFocusEvents;
+ std::queue<std::unique_ptr<CaptureEvent>> mCaptureEvents;
+ std::queue<std::unique_ptr<DragEvent>> mDragEvents;
+ std::queue<std::unique_ptr<TouchModeEvent>> mTouchModeEvents;
+
+private:
+ // InputConsumer callbacks
+ void onKeyEvent(std::unique_ptr<KeyEvent> event, uint32_t seq) override {
+ mKeyEvents.push(std::move(event));
+ mConsumer->finishInputEvent(seq, /*handled=*/true);
+ }
+
+ void onMotionEvent(std::unique_ptr<MotionEvent> event, uint32_t seq) override {
+ mMotionEvents.push(std::move(event));
+ mConsumer->finishInputEvent(seq, /*handled=*/true);
+ }
+
+ void onBatchedInputEventPending(int32_t pendingBatchSource) override {
+ if (!mConsumer->probablyHasInput()) {
+ ADD_FAILURE() << "Should deterministically have input because there is a batch";
+ }
+ ++mOnBatchedInputEventPendingInvocationCount;
+ }
+
+ void onFocusEvent(std::unique_ptr<FocusEvent> event, uint32_t seq) override {
+ mFocusEvents.push(std::move(event));
+ mConsumer->finishInputEvent(seq, /*handled=*/true);
+ }
+
+ void onCaptureEvent(std::unique_ptr<CaptureEvent> event, uint32_t seq) override {
+ mCaptureEvents.push(std::move(event));
+ mConsumer->finishInputEvent(seq, /*handled=*/true);
+ }
+
+ void onDragEvent(std::unique_ptr<DragEvent> event, uint32_t seq) override {
+ mDragEvents.push(std::move(event));
+ mConsumer->finishInputEvent(seq, /*handled=*/true);
+ }
+
+ void onTouchModeEvent(std::unique_ptr<TouchModeEvent> event, uint32_t seq) override {
+ mTouchModeEvents.push(std::move(event));
+ mConsumer->finishInputEvent(seq, /*handled=*/true);
+ }
+
+ uint32_t mLastSeq{0};
+ size_t mOnBatchedInputEventPendingInvocationCount{0};
+};
+
+InputMessage InputConsumerFilteredResamplingTest::nextPointerMessage(nanoseconds eventTime,
+ int32_t action,
+ const Pointer& pointer) {
+ ++mLastSeq;
+ return InputMessageBuilder{InputMessage::Type::MOTION, mLastSeq}
+ .eventTime(eventTime.count())
+ .source(AINPUT_SOURCE_TOUCHSCREEN)
+ .action(action)
+ .pointer(pointer.asPointerBuilder())
+ .build();
+}
+
+TEST_F(InputConsumerFilteredResamplingTest, NeighboringTimestampsDoNotResultInZeroDivision) {
+ mClientTestChannel->enqueueMessage(
+ nextPointerMessage(0ms, ACTION_DOWN, Pointer{.x = 0.0f, .y = 0.0f}));
+
+ invokeLooperCallback();
+
+ assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_DOWN), WithSampleCount(1)));
+
+ const std::chrono::nanoseconds initialTime{56'821'700'000'000};
+
+ mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 4'929'000ns, ACTION_MOVE,
+ Pointer{.x = 1.0f, .y = 1.0f}));
+ mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 9'352'000ns, ACTION_MOVE,
+ Pointer{.x = 2.0f, .y = 2.0f}));
+ mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 14'531'000ns, ACTION_MOVE,
+ Pointer{.x = 3.0f, .y = 3.0f}));
+
+ invokeLooperCallback();
+ mConsumer->consumeBatchedInputEvents(initialTime.count() + 18'849'395 /*ns*/);
+
+ assertOnBatchedInputEventPendingWasCalled();
+ // Three samples are expected. The first two of the batch, and the resampled one. The
+ // coordinates of the resampled sample are hardcoded because the matcher requires them. However,
+ // the primary intention here is to check that the last sample is resampled.
+ assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_MOVE), WithSampleCount(3),
+ WithSample(/*sampleIndex=*/2,
+ Sample{initialTime + 13'849'395ns,
+ {PointerArgs{.x = 1.3286f,
+ .y = 1.3286f,
+ .isResampled = true}}})));
+
+ mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 20'363'000ns, ACTION_MOVE,
+ Pointer{.x = 4.0f, .y = 4.0f}));
+ mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 25'745'000ns, ACTION_MOVE,
+ Pointer{.x = 5.0f, .y = 5.0f}));
+ // This sample is part of the stream of messages, but should not be consumed because its
+ // timestamp is greater than the ajusted frame time.
+ mClientTestChannel->enqueueMessage(nextPointerMessage(initialTime + 31'337'000ns, ACTION_MOVE,
+ Pointer{.x = 6.0f, .y = 6.0f}));
+
+ invokeLooperCallback();
+ mConsumer->consumeBatchedInputEvents(initialTime.count() + 35'516'062 /*ns*/);
+
+ assertOnBatchedInputEventPendingWasCalled();
+ // Four samples are expected because the last sample of the previous batch was not consumed.
+ assertReceivedMotionEvent(AllOf(WithMotionAction(ACTION_MOVE), WithSampleCount(4)));
+
+ mClientTestChannel->assertFinishMessage(/*seq=*/1, /*handled=*/true);
+ mClientTestChannel->assertFinishMessage(/*seq=*/2, /*handled=*/true);
+ mClientTestChannel->assertFinishMessage(/*seq=*/3, /*handled=*/true);
+ mClientTestChannel->assertFinishMessage(/*seq=*/4, /*handled=*/true);
+ mClientTestChannel->assertFinishMessage(/*seq=*/5, /*handled=*/true);
+ mClientTestChannel->assertFinishMessage(/*seq=*/6, /*handled=*/true);
+}
+
+} // namespace android
diff --git a/libs/input/tests/OneEuroFilter_test.cpp b/libs/input/tests/OneEuroFilter_test.cpp
index 270e789..8645508 100644
--- a/libs/input/tests/OneEuroFilter_test.cpp
+++ b/libs/input/tests/OneEuroFilter_test.cpp
@@ -98,7 +98,10 @@
std::vector<Sample> filteredSignal;
for (const Sample& sample : signal) {
filteredSignal.push_back(
- Sample{sample.timestamp, mFilter.filter(sample.timestamp, sample.value)});
+ Sample{sample.timestamp,
+ mFilter.filter(std::chrono::duration_cast<std::chrono::nanoseconds>(
+ sample.timestamp),
+ sample.value)});
}
return filteredSignal;
}
diff --git a/libs/input/tests/TestEventMatchers.h b/libs/input/tests/TestEventMatchers.h
index 3589de5..de96600 100644
--- a/libs/input/tests/TestEventMatchers.h
+++ b/libs/input/tests/TestEventMatchers.h
@@ -17,6 +17,7 @@
#pragma once
#include <chrono>
+#include <cmath>
#include <ostream>
#include <vector>
@@ -150,14 +151,18 @@
++pointerIndex) {
const PointerCoords& pointerCoords =
*(motionEvent.getHistoricalRawPointerCoords(pointerIndex, mSampleIndex));
- if ((pointerCoords.getX() != mSample.pointers[pointerIndex].x) ||
- (pointerCoords.getY() != mSample.pointers[pointerIndex].y)) {
+
+ if ((std::abs(pointerCoords.getX() - mSample.pointers[pointerIndex].x) >
+ MotionEvent::ROUNDING_PRECISION) ||
+ (std::abs(pointerCoords.getY() - mSample.pointers[pointerIndex].y) >
+ MotionEvent::ROUNDING_PRECISION)) {
*os << "sample coordinates mismatch at pointer index " << pointerIndex
<< ". sample: (" << pointerCoords.getX() << ", " << pointerCoords.getY()
<< ") expected: (" << mSample.pointers[pointerIndex].x << ", "
<< mSample.pointers[pointerIndex].y << ")";
return false;
}
+
if (motionEvent.isResampled(pointerIndex, mSampleIndex) !=
mSample.pointers[pointerIndex].isResampled) {
*os << "resampling flag mismatch. sample: "