SyncSM06: Add StateMachineShim

Test: atest StateMachineShimTest

Change-Id: Ic818aa55e7e0fd7a62dfce50a6ad719e6e1c44ec
diff --git a/Tethering/src/com/android/networkstack/tethering/TetheringConfiguration.java b/Tethering/src/com/android/networkstack/tethering/TetheringConfiguration.java
index 747cc20..502fee8 100644
--- a/Tethering/src/com/android/networkstack/tethering/TetheringConfiguration.java
+++ b/Tethering/src/com/android/networkstack/tethering/TetheringConfiguration.java
@@ -136,6 +136,9 @@
      */
     public static final int DEFAULT_TETHER_OFFLOAD_POLL_INTERVAL_MS = 5000;
 
+    /** A flag for using synchronous or asynchronous state machine. */
+    public static final boolean USE_SYNC_SM = false;
+
     public final String[] tetherableUsbRegexs;
     public final String[] tetherableWifiRegexs;
     public final String[] tetherableWigigRegexs;
diff --git a/Tethering/src/com/android/networkstack/tethering/util/StateMachineShim.java b/Tethering/src/com/android/networkstack/tethering/util/StateMachineShim.java
new file mode 100644
index 0000000..fc432f7
--- /dev/null
+++ b/Tethering/src/com/android/networkstack/tethering/util/StateMachineShim.java
@@ -0,0 +1,176 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.networkstack.tethering.util;
+
+import android.annotation.Nullable;
+import android.os.Looper;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.internal.util.State;
+import com.android.internal.util.StateMachine;
+import com.android.networkstack.tethering.util.SyncStateMachine.StateInfo;
+
+import java.util.List;
+
+/** A wrapper to decide whether use synchronous state machine for tethering. */
+public class StateMachineShim {
+    // Exactly one of mAsyncSM or mSyncSM is non-null.
+    private final StateMachine mAsyncSM;
+    private final SyncStateMachine mSyncSM;
+
+    /**
+     * The Looper parameter is only needed for AsyncSM, so if looper is null, the shim will be
+     * created for SyncSM.
+     */
+    public StateMachineShim(final String name, @Nullable final Looper looper) {
+        this(name, looper, new Dependencies());
+    }
+
+    @VisibleForTesting
+    public StateMachineShim(final String name, @Nullable final Looper looper,
+            final Dependencies deps) {
+        if (looper == null) {
+            mAsyncSM = null;
+            mSyncSM = deps.makeSyncStateMachine(name, Thread.currentThread());
+        } else {
+            mAsyncSM = deps.makeAsyncStateMachine(name, looper);
+            mSyncSM = null;
+        }
+    }
+
+    /** A dependencies class which used for testing injection. */
+    @VisibleForTesting
+    public static class Dependencies {
+        /** Create SyncSM instance, for injection. */
+        public SyncStateMachine makeSyncStateMachine(final String name, final Thread thread) {
+            return new SyncStateMachine(name, thread);
+        }
+
+        /** Create AsyncSM instance, for injection. */
+        public AsyncStateMachine makeAsyncStateMachine(final String name, final Looper looper) {
+            return new AsyncStateMachine(name, looper);
+        }
+    }
+
+    /** Start the state machine */
+    public void start(final State initialState) {
+        if (mSyncSM != null) {
+            mSyncSM.start(initialState);
+        } else {
+            mAsyncSM.setInitialState(initialState);
+            mAsyncSM.start();
+        }
+    }
+
+    /** Add states to state machine. */
+    public void addAllStates(final List<StateInfo> stateInfos) {
+        if (mSyncSM != null) {
+            mSyncSM.addAllStates(stateInfos);
+        } else {
+            for (final StateInfo info : stateInfos) {
+                mAsyncSM.addState(info.state, info.parent);
+            }
+        }
+    }
+
+    /**
+     * Transition to given state.
+     *
+     * SyncSM doesn't allow this be called during state transition (#enter() or #exit() methods),
+     * or multiple times while processing a single message.
+     */
+    public void transitionTo(final State state) {
+        if (mSyncSM != null) {
+            mSyncSM.transitionTo(state);
+        } else {
+            mAsyncSM.transitionTo(state);
+        }
+    }
+
+    /** Send message to state machine. */
+    public void sendMessage(int what) {
+        sendMessage(what, 0, 0, null);
+    }
+
+    /** Send message to state machine. */
+    public void sendMessage(int what, Object obj) {
+        sendMessage(what, 0, 0, obj);
+    }
+
+    /** Send message to state machine. */
+    public void sendMessage(int what, int arg1) {
+        sendMessage(what, arg1, 0, null);
+    }
+
+    /**
+     * Send message to state machine.
+     *
+     * If using asynchronous state machine, putting the message into looper's message queue.
+     * Tethering runs on single looper thread that ipServers and mainSM all share with same message
+     * queue. The enqueued message will be processed by asynchronous state machine when all the
+     * messages before such enqueued message are processed.
+     * If using synchronous state machine, the message is processed right away without putting into
+     * looper's message queue.
+     */
+    public void sendMessage(int what, int arg1, int arg2, Object obj) {
+        if (mSyncSM != null) {
+            mSyncSM.processMessage(what, arg1, arg2, obj);
+        } else {
+            mAsyncSM.sendMessage(what, arg1, arg2, obj);
+        }
+    }
+
+    /**
+     * Send message after delayMillis millisecond.
+     *
+     * This can only be used with async state machine, so this will throw if using sync state
+     * machine.
+     */
+    public void sendMessageDelayedToAsyncSM(final int what, final long delayMillis) {
+        if (mSyncSM != null) {
+            throw new IllegalStateException("sendMessageDelayed can only be used with async SM");
+        }
+
+        mAsyncSM.sendMessageDelayed(what, delayMillis);
+    }
+
+    /**
+     * Send self message.
+     * This can only be used with sync state machine, so this will throw if using async state
+     * machine.
+     */
+    public void sendSelfMessageToSyncSM(final int what, final Object obj) {
+        if (mSyncSM == null) {
+            throw new IllegalStateException("sendSelfMessage can only be used with sync SM");
+        }
+
+        mSyncSM.sendSelfMessage(what, 0, 0, obj);
+    }
+
+    /**
+     * An alias StateMahchine class with public construtor.
+     *
+     * Since StateMachine.java only provides protected construtor, adding a child class so that this
+     * shim could create StateMachine instance.
+     */
+    @VisibleForTesting
+    public static class AsyncStateMachine extends StateMachine {
+        public AsyncStateMachine(final String name, final Looper looper) {
+            super(name, looper);
+        }
+    }
+}
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/util/StateMachineShimTest.kt b/Tethering/tests/unit/src/com/android/networkstack/tethering/util/StateMachineShimTest.kt
new file mode 100644
index 0000000..2c4df76
--- /dev/null
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/util/StateMachineShimTest.kt
@@ -0,0 +1,128 @@
+/**
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.networkstack.tethering.util
+
+import android.os.Looper
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import androidx.test.filters.SmallTest
+import com.android.internal.util.State
+import com.android.networkstack.tethering.util.StateMachineShim.AsyncStateMachine
+import com.android.networkstack.tethering.util.StateMachineShim.Dependencies
+import com.android.networkstack.tethering.util.SyncStateMachine.StateInfo
+import kotlin.test.assertFailsWith
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mockito.inOrder
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.verify
+import org.mockito.Mockito.verifyNoMoreInteractions
+
+@RunWith(AndroidJUnit4::class)
+@SmallTest
+class StateMachineShimTest {
+    private val mSyncSM = mock(SyncStateMachine::class.java)
+    private val mAsyncSM = mock(AsyncStateMachine::class.java)
+    private val mState1 = mock(State::class.java)
+    private val mState2 = mock(State::class.java)
+
+    inner class MyDependencies() : Dependencies() {
+
+        override fun makeSyncStateMachine(name: String, thread: Thread) = mSyncSM
+
+        override fun makeAsyncStateMachine(name: String, looper: Looper) = mAsyncSM
+    }
+
+    @Test
+    fun testUsingSyncStateMachine() {
+        val inOrder = inOrder(mSyncSM, mAsyncSM)
+        val shimUsingSyncSM = StateMachineShim("ShimTest", null, MyDependencies())
+        shimUsingSyncSM.start(mState1)
+        inOrder.verify(mSyncSM).start(mState1)
+
+        val allStates = ArrayList<StateInfo>()
+        allStates.add(StateInfo(mState1, null))
+        allStates.add(StateInfo(mState2, mState1))
+        shimUsingSyncSM.addAllStates(allStates)
+        inOrder.verify(mSyncSM).addAllStates(allStates)
+
+        shimUsingSyncSM.transitionTo(mState1)
+        inOrder.verify(mSyncSM).transitionTo(mState1)
+
+        val what = 10
+        shimUsingSyncSM.sendMessage(what)
+        inOrder.verify(mSyncSM).processMessage(what, 0, 0, null)
+        val obj = Object()
+        shimUsingSyncSM.sendMessage(what, obj)
+        inOrder.verify(mSyncSM).processMessage(what, 0, 0, obj)
+        val arg1 = 11
+        shimUsingSyncSM.sendMessage(what, arg1)
+        inOrder.verify(mSyncSM).processMessage(what, arg1, 0, null)
+        val arg2 = 12
+        shimUsingSyncSM.sendMessage(what, arg1, arg2, obj)
+        inOrder.verify(mSyncSM).processMessage(what, arg1, arg2, obj)
+
+        assertFailsWith(IllegalStateException::class) {
+            shimUsingSyncSM.sendMessageDelayedToAsyncSM(what, 1000 /* delayMillis */)
+        }
+
+        shimUsingSyncSM.sendSelfMessageToSyncSM(what, obj)
+        inOrder.verify(mSyncSM).sendSelfMessage(what, 0, 0, obj)
+
+        verifyNoMoreInteractions(mSyncSM, mAsyncSM)
+    }
+
+    @Test
+    fun testUsingAsyncStateMachine() {
+        val inOrder = inOrder(mSyncSM, mAsyncSM)
+        val shimUsingAsyncSM = StateMachineShim("ShimTest", mock(Looper::class.java),
+                MyDependencies())
+        shimUsingAsyncSM.start(mState1)
+        inOrder.verify(mAsyncSM).setInitialState(mState1)
+        inOrder.verify(mAsyncSM).start()
+
+        val allStates = ArrayList<StateInfo>()
+        allStates.add(StateInfo(mState1, null))
+        allStates.add(StateInfo(mState2, mState1))
+        shimUsingAsyncSM.addAllStates(allStates)
+        inOrder.verify(mAsyncSM).addState(mState1, null)
+        inOrder.verify(mAsyncSM).addState(mState2, mState1)
+
+        shimUsingAsyncSM.transitionTo(mState1)
+        inOrder.verify(mAsyncSM).transitionTo(mState1)
+
+        val what = 10
+        shimUsingAsyncSM.sendMessage(what)
+        inOrder.verify(mAsyncSM).sendMessage(what, 0, 0, null)
+        val obj = Object()
+        shimUsingAsyncSM.sendMessage(what, obj)
+        inOrder.verify(mAsyncSM).sendMessage(what, 0, 0, obj)
+        val arg1 = 11
+        shimUsingAsyncSM.sendMessage(what, arg1)
+        inOrder.verify(mAsyncSM).sendMessage(what, arg1, 0, null)
+        val arg2 = 12
+        shimUsingAsyncSM.sendMessage(what, arg1, arg2, obj)
+        inOrder.verify(mAsyncSM).sendMessage(what, arg1, arg2, obj)
+
+        shimUsingAsyncSM.sendMessageDelayedToAsyncSM(what, 1000 /* delayMillis */)
+        inOrder.verify(mAsyncSM).sendMessageDelayed(what, 1000)
+
+        assertFailsWith(IllegalStateException::class) {
+            shimUsingAsyncSM.sendSelfMessageToSyncSM(what, obj)
+        }
+
+        verifyNoMoreInteractions(mSyncSM, mAsyncSM)
+    }
+}