diff --git a/libc/bionic/heap_tagging.cpp b/libc/bionic/heap_tagging.cpp
index 540372b..7601ddd 100644
--- a/libc/bionic/heap_tagging.cpp
+++ b/libc/bionic/heap_tagging.cpp
@@ -33,6 +33,10 @@
 #include <platform/bionic/malloc.h>
 #include <platform/bionic/mte_kernel.h>
 
+#include <bionic/pthread_internal.h>
+
+#include "private/ScopedPthreadMutexLocker.h"
+
 extern "C" void scudo_malloc_disable_memory_tagging();
 extern "C" void scudo_malloc_set_track_allocation_stacks(int);
 
@@ -68,7 +72,32 @@
 #endif  // aarch64
 }
 
+#ifdef ANDROID_EXPERIMENTAL_MTE
+static bool set_tcf_on_all_threads(int tcf) {
+  static int g_tcf;
+  g_tcf = tcf;
+
+  return android_run_on_all_threads(
+      [](void*) {
+        int tagged_addr_ctrl = prctl(PR_GET_TAGGED_ADDR_CTRL, 0, 0, 0, 0);
+        if (tagged_addr_ctrl < 0) {
+          return false;
+        }
+
+        tagged_addr_ctrl = (tagged_addr_ctrl & ~PR_MTE_TCF_MASK) | g_tcf;
+        if (prctl(PR_SET_TAGGED_ADDR_CTRL, tagged_addr_ctrl, 0, 0, 0) < 0) {
+          return false;
+        }
+        return true;
+      },
+      nullptr);
+}
+#endif
+
 bool SetHeapTaggingLevel(void* arg, size_t arg_size) {
+  static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
+  ScopedPthreadMutexLocker locker(&mutex);
+
   if (arg_size != sizeof(HeapTaggingLevel)) {
     return false;
   }
@@ -80,9 +109,6 @@
 
   switch (tag_level) {
     case M_HEAP_TAGGING_LEVEL_NONE:
-#if defined(USE_SCUDO)
-      scudo_malloc_disable_memory_tagging();
-#endif
       if (heap_tagging_level == M_HEAP_TAGGING_LEVEL_TBI) {
         __libc_globals.mutate([](libc_globals* globals) {
           // Preserve the untag mask (we still want to untag pointers when passing them to the
@@ -90,7 +116,17 @@
           // tagged and checks no longer happen.
           globals->heap_pointer_tag = static_cast<uintptr_t>(0xffull << UNTAG_SHIFT);
         });
+      } else {
+#if defined(ANDROID_EXPERIMENTAL_MTE)
+        if (!set_tcf_on_all_threads(PR_MTE_TCF_NONE)) {
+          error_log("SetHeapTaggingLevel: set_tcf_on_all_threads failed");
+          return false;
+        }
+#endif
       }
+#if defined(USE_SCUDO)
+      scudo_malloc_disable_memory_tagging();
+#endif
       break;
     case M_HEAP_TAGGING_LEVEL_TBI:
     case M_HEAP_TAGGING_LEVEL_ASYNC:
@@ -106,10 +142,16 @@
       }
 
       if (tag_level == M_HEAP_TAGGING_LEVEL_ASYNC) {
+#if defined(ANDROID_EXPERIMENTAL_MTE)
+        set_tcf_on_all_threads(PR_MTE_TCF_ASYNC);
+#endif
 #if defined(USE_SCUDO)
         scudo_malloc_set_track_allocation_stacks(0);
 #endif
       } else if (tag_level == M_HEAP_TAGGING_LEVEL_SYNC) {
+#if defined(ANDROID_EXPERIMENTAL_MTE)
+        set_tcf_on_all_threads(PR_MTE_TCF_SYNC);
+#endif
 #if defined(USE_SCUDO)
         scudo_malloc_set_track_allocation_stacks(1);
 #endif
diff --git a/libc/bionic/memory_mitigation_state.cpp b/libc/bionic/memory_mitigation_state.cpp
index 82b0b7b..abb1e8d 100644
--- a/libc/bionic/memory_mitigation_state.cpp
+++ b/libc/bionic/memory_mitigation_state.cpp
@@ -36,36 +36,14 @@
 #include <sys/prctl.h>
 #include <sys/types.h>
 
+#include <bionic/malloc.h>
 #include <bionic/mte.h>
-#include <bionic/reserved_signals.h>
 
+#include "heap_tagging.h"
 #include "private/ScopedRWLock.h"
 #include "pthread_internal.h"
 
 extern "C" void scudo_malloc_set_zero_contents(int zero_contents);
-extern "C" void scudo_malloc_disable_memory_tagging();
-
-#ifdef ANDROID_EXPERIMENTAL_MTE
-static bool set_tcf_on_all_threads(int tcf) {
-  static int g_tcf;
-  g_tcf = tcf;
-
-  return android_run_on_all_threads(
-      [](void*) {
-        int tagged_addr_ctrl = prctl(PR_GET_TAGGED_ADDR_CTRL, 0, 0, 0, 0);
-        if (tagged_addr_ctrl < 0) {
-          return false;
-        }
-
-        tagged_addr_ctrl = (tagged_addr_ctrl & ~PR_MTE_TCF_MASK) | g_tcf;
-        if (prctl(PR_SET_TAGGED_ADDR_CTRL, tagged_addr_ctrl, 0, 0, 0) < 0) {
-          return false;
-        }
-        return true;
-      },
-      nullptr);
-}
-#endif
 
 bool DisableMemoryMitigations(void* arg, size_t arg_size) {
   if (arg || arg_size) {
@@ -74,13 +52,10 @@
 
 #ifdef USE_SCUDO
   scudo_malloc_set_zero_contents(0);
+#endif
 
-#ifdef ANDROID_EXPERIMENTAL_MTE
-  if (mte_supported() && set_tcf_on_all_threads(PR_MTE_TCF_NONE)) {
-    scudo_malloc_disable_memory_tagging();
-  }
-#endif
-#endif
+  HeapTaggingLevel level = M_HEAP_TAGGING_LEVEL_NONE;
+  SetHeapTaggingLevel(reinterpret_cast<void*>(&level), sizeof(level));
 
   return true;
 }
diff --git a/libc/platform/bionic/malloc.h b/libc/platform/bionic/malloc.h
index 16ef3a0..56badf0 100644
--- a/libc/platform/bionic/malloc.h
+++ b/libc/platform/bionic/malloc.h
@@ -85,8 +85,8 @@
   //   arg_size = sizeof(android_mallopt_leak_info_t)
   M_FREE_MALLOC_LEAK_INFO = 7,
 #define M_FREE_MALLOC_LEAK_INFO M_FREE_MALLOC_LEAK_INFO
-  // Change the heap tagging state. The program must be single threaded at the point when the
-  // android_mallopt function is called.
+  // Change the heap tagging state. May be called at any time including when
+  // multiple threads are running.
   //   arg = HeapTaggingLevel*
   //   arg_size = sizeof(HeapTaggingLevel)
   M_SET_HEAP_TAGGING_LEVEL = 8,
@@ -115,15 +115,17 @@
 };
 
 enum HeapTaggingLevel {
-  // Disable heap tagging. The program must use prctl(PR_SET_TAGGED_ADDR_CTRL) to disable memory tag
-  // checks before disabling heap tagging. Heap tagging may not be re-enabled after being disabled.
+  // Disable heap tagging and memory tag checks if supported. Heap tagging may not be re-enabled
+  // after being disabled.
   M_HEAP_TAGGING_LEVEL_NONE = 0,
   // Address-only tagging. Heap pointers have a non-zero tag in the most significant byte which is
   // checked in free(). Memory accesses ignore the tag.
   M_HEAP_TAGGING_LEVEL_TBI = 1,
-  // Enable heap tagging if supported, at a level appropriate for asynchronous memory tag checks.
+  // Enable heap tagging and asynchronous memory tag checks if supported. Disable stack trace
+  // collection.
   M_HEAP_TAGGING_LEVEL_ASYNC = 2,
-  // Enable heap tagging if supported, at a level appropriate for synchronous memory tag checks.
+  // Enable heap tagging and synchronous memory tag checks if supported. Enable stack trace
+  // collection.
   M_HEAP_TAGGING_LEVEL_SYNC = 3,
 };
 
diff --git a/tests/heap_tagging_level_test.cpp b/tests/heap_tagging_level_test.cpp
index 05123fd..1dbf455 100644
--- a/tests/heap_tagging_level_test.cpp
+++ b/tests/heap_tagging_level_test.cpp
@@ -77,21 +77,8 @@
 }
 
 #if defined(__BIONIC__) && defined(__aarch64__) && defined(ANDROID_EXPERIMENTAL_MTE)
-template <int SiCode> void CheckSiCode(int, siginfo_t* info, void*) {
-  if (info->si_code != SiCode) {
-    _exit(2);
-  }
-  _exit(1);
-}
-
-static bool SetTagCheckingLevel(int level) {
-  int tagged_addr_ctrl = prctl(PR_GET_TAGGED_ADDR_CTRL, 0, 0, 0, 0);
-  if (tagged_addr_ctrl < 0) {
-    return false;
-  }
-
-  tagged_addr_ctrl = (tagged_addr_ctrl & ~PR_MTE_TCF_MASK) | level;
-  return prctl(PR_SET_TAGGED_ADDR_CTRL, tagged_addr_ctrl, 0, 0, 0) == 0;
+void ExitWithSiCode(int, siginfo_t* info, void*) {
+  _exit(info->si_code);
 }
 #endif
 
@@ -108,30 +95,26 @@
   // mismatching tag before each allocation.
   EXPECT_EXIT(
       {
-        ScopedSignalHandler ssh(SIGSEGV, CheckSiCode<SEGV_MTEAERR>, SA_SIGINFO);
+        ScopedSignalHandler ssh(SIGSEGV, ExitWithSiCode, SA_SIGINFO);
         p[-1] = 42;
       },
-      testing::ExitedWithCode(1), "");
+      testing::ExitedWithCode(SEGV_MTEAERR), "");
 
-  EXPECT_TRUE(SetTagCheckingLevel(PR_MTE_TCF_SYNC));
+  EXPECT_TRUE(SetHeapTaggingLevel(M_HEAP_TAGGING_LEVEL_SYNC));
   EXPECT_EXIT(
       {
-        ScopedSignalHandler ssh(SIGSEGV, CheckSiCode<SEGV_MTESERR>, SA_SIGINFO);
+        ScopedSignalHandler ssh(SIGSEGV, ExitWithSiCode, SA_SIGINFO);
         p[-1] = 42;
       },
-      testing::ExitedWithCode(1), "");
+      testing::ExitedWithCode(SEGV_MTESERR), "");
 
-  EXPECT_TRUE(SetTagCheckingLevel(PR_MTE_TCF_NONE));
+  EXPECT_TRUE(SetHeapTaggingLevel(M_HEAP_TAGGING_LEVEL_NONE));
   volatile int oob ATTRIBUTE_UNUSED = p[-1];
 #endif
 }
 
 TEST(heap_tagging_level, none_pointers_untagged) {
 #if defined(__BIONIC__)
-#if defined(__aarch64__) && defined(ANDROID_EXPERIMENTAL_MTE)
-  EXPECT_TRUE(SetTagCheckingLevel(PR_MTE_TCF_NONE));
-#endif
-
   EXPECT_TRUE(SetHeapTaggingLevel(M_HEAP_TAGGING_LEVEL_NONE));
   std::unique_ptr<int[]> p = std::make_unique<int[]>(4);
   EXPECT_EQ(untag_address(p.get()), p.get());
@@ -146,10 +129,6 @@
     GTEST_SKIP() << "Kernel doesn't support tagged pointers.";
   }
 
-#if defined(ANDROID_EXPERIMENTAL_MTE)
-  EXPECT_TRUE(SetTagCheckingLevel(PR_MTE_TCF_NONE));
-#endif
-
   EXPECT_FALSE(SetHeapTaggingLevel(static_cast<HeapTaggingLevel>(12345)));
 
   if (mte_supported()) {
@@ -190,10 +169,6 @@
     GTEST_SKIP() << "requires MTE support";
   }
 
-#if defined(ANDROID_EXPERIMENTAL_MTE)
-  EXPECT_TRUE(SetTagCheckingLevel(PR_MTE_TCF_NONE));
-#endif
-
   EXPECT_TRUE(SetHeapTaggingLevel(M_HEAP_TAGGING_LEVEL_SYNC));
   EXPECT_TRUE(SetHeapTaggingLevel(M_HEAP_TAGGING_LEVEL_NONE));
 #else
