diff --git a/android/TerminalApp/Android.bp b/android/TerminalApp/Android.bp
index e1e236a..c18ada4 100644
--- a/android/TerminalApp/Android.bp
+++ b/android/TerminalApp/Android.bp
@@ -14,6 +14,7 @@
         // TODO(b/330257000): will be removed when binder RPC is used
         "android.system.virtualizationservice_internal-java",
         "androidx-constraintlayout_constraintlayout",
+        "androidx.navigation_navigation-fragment-ktx",
         "androidx.window_window",
         "androidx.work_work-runtime",
         "apache-commons-compress",
diff --git a/android/TerminalApp/java/com/android/virtualization/terminal/MainActivity.kt b/android/TerminalApp/java/com/android/virtualization/terminal/MainActivity.kt
index 487b7e8..3bdea72 100644
--- a/android/TerminalApp/java/com/android/virtualization/terminal/MainActivity.kt
+++ b/android/TerminalApp/java/com/android/virtualization/terminal/MainActivity.kt
@@ -45,7 +45,7 @@
 import androidx.activity.result.ActivityResultCallback
 import androidx.activity.result.ActivityResultLauncher
 import androidx.activity.result.contract.ActivityResultContracts.StartActivityForResult
-import androidx.lifecycle.ViewModelProvider
+import androidx.activity.viewModels
 import androidx.viewpager2.widget.ViewPager2
 import com.android.internal.annotations.VisibleForTesting
 import com.android.microdroid.test.common.DeviceProperties
@@ -72,11 +72,11 @@
     private lateinit var image: InstalledImage
     private lateinit var accessibilityManager: AccessibilityManager
     private lateinit var manageExternalStorageActivityResultLauncher: ActivityResultLauncher<Intent>
-    private lateinit var terminalViewModel: TerminalViewModel
     private lateinit var viewPager: ViewPager2
     private lateinit var tabLayout: TabLayout
     private lateinit var terminalTabAdapter: TerminalTabAdapter
     private val terminalInfo = CompletableFuture<TerminalInfo>()
+    private val terminalViewModel: TerminalViewModel by viewModels()
 
     override fun onCreate(savedInstanceState: Bundle?) {
         super.onCreate(savedInstanceState)
@@ -111,7 +111,6 @@
     }
 
     private fun initializeUi() {
-        terminalViewModel = ViewModelProvider(this)[TerminalViewModel::class.java]
         setContentView(R.layout.activity_headless)
         tabLayout = findViewById<TabLayout>(R.id.tab_layout)
         displayMenu = findViewById<Button>(R.id.display_button)
@@ -151,6 +150,20 @@
         TabLayoutMediator(tabLayout, viewPager, false, false) { _: TabLayout.Tab?, _: Int -> }
             .attach()
 
+        tabLayout.addOnTabSelectedListener(
+            object : TabLayout.OnTabSelectedListener {
+                override fun onTabSelected(tab: TabLayout.Tab?) {
+                    // Track the selected tab's id; ignore positions the adapter lacks.
+                    val selected = tab?.let { terminalTabAdapter.tabs.getOrNull(it.position) }
+                    selected?.let { terminalViewModel.selectedTabViewId = it.id }
+                }
+
+                // Only selection is tracked; the other callbacks are intentionally empty.
+                override fun onTabUnselected(tab: TabLayout.Tab?) {}
+                override fun onTabReselected(tab: TabLayout.Tab?) {}
+            }
+        )
+
         addTerminalTab()
 
         tabAddButton?.setOnClickListener { addTerminalTab() }
@@ -160,7 +173,7 @@
         val tab = tabLayout.newTab()
         tab.setCustomView(R.layout.tabitem_terminal)
         viewPager.offscreenPageLimit += 1
-        terminalTabAdapter.addTab()
+        terminalViewModel.selectedTabViewId = terminalTabAdapter.addTab()
         tab.customView!!
             .findViewById<Button>(R.id.tab_close_button)
             .setOnClickListener(
diff --git a/android/TerminalApp/java/com/android/virtualization/terminal/TerminalTabFragment.kt b/android/TerminalApp/java/com/android/virtualization/terminal/TerminalTabFragment.kt
index 7e78235..9644885 100644
--- a/android/TerminalApp/java/com/android/virtualization/terminal/TerminalTabFragment.kt
+++ b/android/TerminalApp/java/com/android/virtualization/terminal/TerminalTabFragment.kt
@@ -32,7 +32,7 @@
 import android.webkit.WebView
 import android.webkit.WebViewClient
 import androidx.fragment.app.Fragment
-import androidx.lifecycle.ViewModelProvider
+import androidx.fragment.app.activityViewModels
 import com.android.system.virtualmachine.flags.Flags.terminalGuiSupport
 import com.android.virtualization.terminal.CertificateUtils.createOrGetKey
 import com.android.virtualization.terminal.CertificateUtils.writeCertificateToFile
@@ -45,7 +45,7 @@
     private lateinit var id: String
     private var certificates: Array<X509Certificate>? = null
     private var privateKey: PrivateKey? = null
-    private lateinit var terminalViewModel: TerminalViewModel
+    private val terminalViewModel: TerminalViewModel by activityViewModels()
 
     override fun onCreateView(
         inflater: LayoutInflater,
@@ -59,7 +59,6 @@
 
     override fun onViewCreated(view: View, savedInstanceState: Bundle?) {
         super.onViewCreated(view, savedInstanceState)
-        terminalViewModel = ViewModelProvider(this)[TerminalViewModel::class.java]
         terminalView = view.findViewById(R.id.webview)
         bootProgressView = view.findViewById(R.id.boot_progress)
         initializeWebView()
@@ -79,6 +78,11 @@
         terminalView.saveState(outState)
     }
 
+    override fun onResume() {
+        super.onResume()
+        updateFocus() // Take focus if this is the selected tab.
+    }
+
     private fun initializeWebView() {
         terminalView.settings.databaseEnabled = true
         terminalView.settings.domStorageEnabled = true
@@ -148,6 +152,7 @@
                             terminalView.visibility = View.VISIBLE
                             terminalView.mapTouchToMouseEvent()
                             updateMainActivity()
+                            updateFocus()
                         }
                     }
                 },
@@ -189,6 +194,12 @@
         certificates = arrayOf<X509Certificate>(pke.certificate as X509Certificate)
     }
 
+    private fun updateFocus() {
+        // Only the fragment whose id matches the selected tab takes focus.
+        val isSelectedTab = terminalViewModel.selectedTabViewId == id
+        if (isSelectedTab) terminalView.requestFocus()
+    }
+
     companion object {
         const val TAG: String = "VmTerminalApp"
     }
diff --git a/android/TerminalApp/java/com/android/virtualization/terminal/TerminalViewModel.kt b/android/TerminalApp/java/com/android/virtualization/terminal/TerminalViewModel.kt
index 4a69f75..7c6aa15 100644
--- a/android/TerminalApp/java/com/android/virtualization/terminal/TerminalViewModel.kt
+++ b/android/TerminalApp/java/com/android/virtualization/terminal/TerminalViewModel.kt
@@ -19,4 +19,5 @@
 
 class TerminalViewModel : ViewModel() {
     val terminalViews: MutableSet<TerminalView> = mutableSetOf()
+    var selectedTabViewId: String? = null // Selected tab's id; null until set.
 }
