Move memtag_stack out of libc_globals

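memtag_stack cannot live inside the WriteProtected libc_globals because
it is read and updated from a multithreaded context. Make it a
standalone atomic global, __libc_memtag_stack, instead.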

Test: atest memtag_stack_dlopen_test w/ MTE
Test: atest bionic-unit-tests w/ MTE
Test: atest bionic-unit-tests on _fullmte
Bug: 328256432
Change-Id: I39faa75f97fd5b3fb755a46e88346c17c0e9a8e2
diff --git a/libc/bionic/heap_tagging.cpp b/libc/bionic/heap_tagging.cpp
index b6c9fe5..c4347e8 100644
--- a/libc/bionic/heap_tagging.cpp
+++ b/libc/bionic/heap_tagging.cpp
@@ -58,7 +58,7 @@
       case M_HEAP_TAGGING_LEVEL_SYNC:
       case M_HEAP_TAGGING_LEVEL_ASYNC:
         atomic_store(&globals->memtag, true);
-        atomic_store(&globals->memtag_stack, __libc_shared_globals()->initial_memtag_stack);
+        atomic_store(&__libc_memtag_stack, __libc_shared_globals()->initial_memtag_stack);
         break;
       default:
         break;
@@ -113,7 +113,7 @@
           // tagged and checks no longer happen.
           globals->heap_pointer_tag = static_cast<uintptr_t>(0xffull << UNTAG_SHIFT);
         }
-        atomic_store(&globals->memtag_stack, false);
+        atomic_store(&__libc_memtag_stack, false);
         atomic_store(&globals->memtag, false);
       });
 
@@ -207,7 +207,7 @@
   // │memtag_handle_longjmp│                  │
   // └─────────────────────┘                  ▼
 #ifdef __aarch64__
-  if (__libc_globals->memtag_stack) {
+  if (atomic_load(&__libc_memtag_stack)) {
     size_t distance = reinterpret_cast<uintptr_t>(sp_dst) - reinterpret_cast<uintptr_t>(sp_src);
     if (distance > kUntagLimit) {
       async_safe_fatal(
diff --git a/libc/bionic/libc_init_common.cpp b/libc/bionic/libc_init_common.cpp
index 51f7ce9..944098f 100644
--- a/libc/bionic/libc_init_common.cpp
+++ b/libc/bionic/libc_init_common.cpp
@@ -57,6 +57,7 @@
 extern "C" void scudo_malloc_set_pattern_fill_contents(int);
 
 __LIBC_HIDDEN__ constinit WriteProtected<libc_globals> __libc_globals;
+__LIBC_HIDDEN__ constinit _Atomic(bool) __libc_memtag_stack;
 
 // Not public, but well-known in the BSDs.
 __BIONIC_WEAK_VARIABLE_FOR_NATIVE_BRIDGE
diff --git a/libc/bionic/malloc_common.cpp b/libc/bionic/malloc_common.cpp
index 3c4884b..9932e3e 100644
--- a/libc/bionic/malloc_common.cpp
+++ b/libc/bionic/malloc_common.cpp
@@ -353,7 +353,7 @@
       errno = EINVAL;
       return false;
     }
-    *reinterpret_cast<bool*>(arg) = atomic_load(&__libc_globals->memtag_stack);
+    *reinterpret_cast<bool*>(arg) = atomic_load(&__libc_memtag_stack);
     return true;
   }
   if (opcode == M_GET_DECAY_TIME_ENABLED) {
diff --git a/libc/bionic/malloc_common_dynamic.cpp b/libc/bionic/malloc_common_dynamic.cpp
index a6bf7a7..8858178 100644
--- a/libc/bionic/malloc_common_dynamic.cpp
+++ b/libc/bionic/malloc_common_dynamic.cpp
@@ -542,7 +542,7 @@
       errno = EINVAL;
       return false;
     }
-    *reinterpret_cast<bool*>(arg) = atomic_load(&__libc_globals->memtag_stack);
+    *reinterpret_cast<bool*>(arg) = atomic_load(&__libc_memtag_stack);
     return true;
   }
   if (opcode == M_GET_DECAY_TIME_ENABLED) {
diff --git a/libc/bionic/pthread_create.cpp b/libc/bionic/pthread_create.cpp
index 194db18..5bd4f16 100644
--- a/libc/bionic/pthread_create.cpp
+++ b/libc/bionic/pthread_create.cpp
@@ -91,7 +91,7 @@
   // Create and set an alternate signal stack.
   int prot = PROT_READ | PROT_WRITE;
 #ifdef __aarch64__
-  if (atomic_load(&__libc_globals->memtag_stack)) {
+  if (atomic_load(&__libc_memtag_stack)) {
     prot |= PROT_MTE;
   }
 #endif
@@ -237,7 +237,7 @@
   int prot = PROT_READ | PROT_WRITE;
   const char* prot_str = "R+W";
 #ifdef __aarch64__
-  if (atomic_load(&__libc_globals->memtag_stack)) {
+  if (atomic_load(&__libc_memtag_stack)) {
     prot |= PROT_MTE;
     prot_str = "R+W+MTE";
   }
diff --git a/libc/bionic/pthread_internal.cpp b/libc/bionic/pthread_internal.cpp
index bfe2f98..2342aff 100644
--- a/libc/bionic/pthread_internal.cpp
+++ b/libc/bionic/pthread_internal.cpp
@@ -179,10 +179,8 @@
 void __pthread_internal_remap_stack_with_mte() {
 #if defined(__aarch64__)
   // If process doesn't have MTE enabled, we don't need to do anything.
-  if (!__libc_globals->memtag) return;
-  bool prev = true;
-  __libc_globals.mutate(
-      [&prev](libc_globals* globals) { prev = atomic_exchange(&globals->memtag_stack, true); });
+  if (!atomic_load(&__libc_globals->memtag)) return;
+  bool prev = atomic_exchange(&__libc_memtag_stack, true);
   if (prev) return;
   uintptr_t lo, hi;
   __find_main_stack_limits(&lo, &hi);