Merge "Fixing robolectric tests not running on UI thread after changing looper mode" into main
diff --git a/tests/multivalentTests/src/com/android/launcher3/icons/FastBitmapDrawableTest.java b/tests/multivalentTests/src/com/android/launcher3/icons/FastBitmapDrawableTest.java
index fbbfb2a..7e9b68d 100644
--- a/tests/multivalentTests/src/com/android/launcher3/icons/FastBitmapDrawableTest.java
+++ b/tests/multivalentTests/src/com/android/launcher3/icons/FastBitmapDrawableTest.java
@@ -37,11 +37,15 @@
 import android.view.animation.PathInterpolator;
 
 import androidx.test.annotation.UiThreadTest;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
 import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
+
+import com.android.launcher3.util.rule.RobolectricUiThreadRule;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TestRule;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
@@ -51,11 +55,14 @@
  * Tests for FastBitmapDrawable.
  */
 @SmallTest
-@UiThreadTest
 @RunWith(AndroidJUnit4.class)
+@UiThreadTest
 public class FastBitmapDrawableTest {
     private static final float EPSILON = 0.00001f;
 
+    @Rule
+    public final TestRule roboUiThreadRule = new RobolectricUiThreadRule();
+
     @Spy
     FastBitmapDrawable mFastBitmapDrawable =
             spy(new FastBitmapDrawable(Bitmap.createBitmap(100, 100, Bitmap.Config.ARGB_8888)));
diff --git a/tests/multivalentTests/src/com/android/launcher3/logging/StartupLatencyLoggerTest.kt b/tests/multivalentTests/src/com/android/launcher3/logging/StartupLatencyLoggerTest.kt
index 130dfad..12f6c8c 100644
--- a/tests/multivalentTests/src/com/android/launcher3/logging/StartupLatencyLoggerTest.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/logging/StartupLatencyLoggerTest.kt
@@ -4,8 +4,10 @@
 import androidx.test.annotation.UiThreadTest
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.filters.SmallTest
+import com.android.launcher3.util.rule.RobolectricUiThreadRule
 import com.google.common.truth.Truth.assertThat
 import org.junit.Before
+import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
 
@@ -14,6 +16,8 @@
 @RunWith(AndroidJUnit4::class)
 class StartupLatencyLoggerTest {
 
+    @get:Rule val roboUiThreadRule = RobolectricUiThreadRule()
+
     private val underTest = ColdRebootStartupLatencyLogger()
 
     @Before
diff --git a/tests/multivalentTests/src/com/android/launcher3/util/rule/RobolectricUiThreadRule.kt b/tests/multivalentTests/src/com/android/launcher3/util/rule/RobolectricUiThreadRule.kt
new file mode 100644
index 0000000..18cd1e4
--- /dev/null
+++ b/tests/multivalentTests/src/com/android/launcher3/util/rule/RobolectricUiThreadRule.kt
@@ -0,0 +1,105 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.launcher3.util.rule
+
+import androidx.test.annotation.UiThreadTest
+import androidx.test.platform.app.InstrumentationRegistry
+import java.util.Locale
+import java.util.concurrent.atomic.AtomicReference
+import org.junit.rules.TestRule
+import org.junit.runner.Description
+import org.junit.runners.model.Statement
+
+/**
+ * A test rule that adds support for [UiThreadTest] annotations when running in Robolectric, until
+ * it is natively supported by the Robolectric runner:
+ * https://github.com/robolectric/robolectric/issues/9026
+ *
+ * Outside Robolectric the rule is a no-op, leaving [UiThreadTest] handling to the default runner.
+ */
+class RobolectricUiThreadRule : TestRule {
+
+    override fun apply(base: Statement, description: Description): Statement =
+        if (shouldRunOnUiThread(description)) UiThreadStatement(base) else base
+
+    /**
+     * Returns true when running in Robolectric and either the test method, the test class, or one of
+     * its superclasses is annotated with [UiThreadTest].
+     */
+    private fun shouldRunOnUiThread(description: Description): Boolean {
+        if (!isRunningInRobolectric()) {
+            // If not running in robolectric, let the default runner handle this
+            return false
+        }
+        // testClass may be null (e.g. an unresolvable class); then there is nothing to inspect
+        val testClass: Class<*> = description.testClass ?: return false
+        if (isMethodAnnotated(testClass, description.methodName)) {
+            return true
+        }
+        // Fall back to a class-level annotation on the test class or any of its superclasses
+        return generateSequence<Class<*>>(testClass) { it.superclass }
+            .any { it.isAnnotationPresent(UiThreadTest::class.java) }
+    }
+
+    /**
+     * Returns true if the public method [methodName] of [testClass] (declared or inherited) is
+     * annotated with [UiThreadTest]. A null name, or a name that does not resolve to a no-arg
+     * method (e.g. a decorated display name such as a parameterized "test[0]"), yields false.
+     */
+    private fun isMethodAnnotated(testClass: Class<*>, methodName: String?): Boolean {
+        if (methodName == null) return false
+        return try {
+            testClass.getMethod(methodName).isAnnotationPresent(UiThreadTest::class.java)
+        } catch (e: NoSuchMethodException) {
+            false
+        } catch (e: SecurityException) {
+            false
+        }
+    }
+
+    private fun isRunningInRobolectric(): Boolean {
+        // An Android runtime name means we are on a device/emulator rather than in Robolectric.
+        // The property may be unset, so treat null as empty; Locale.ROOT keeps this locale-neutral.
+        val runtimeName = System.getProperty("java.runtime.name").orEmpty()
+        if (runtimeName.lowercase(Locale.ROOT).contains("android")) {
+            return false
+        }
+        return try {
+            // Robolectric is available if its runner class is on the classpath
+            Class.forName("org.robolectric.RobolectricTestRunner")
+            true
+        } catch (e: ClassNotFoundException) {
+            false
+        }
+    }
+
+    /** Runs [base] synchronously on the main thread and rethrows any failure on the caller. */
+    private class UiThreadStatement(private val base: Statement) : Statement() {
+
+        override fun evaluate() {
+            val exceptionRef = AtomicReference<Throwable>()
+            InstrumentationRegistry.getInstrumentation().runOnMainSync {
+                try {
+                    base.evaluate()
+                } catch (throwable: Throwable) {
+                    exceptionRef.set(throwable)
+                }
+            }
+            exceptionRef.get()?.let { throw it }
+        }
+    }
+}