[MTE] split heap and stack MTE initialization

Bug: 315182011
Test: On both an MTE-enabled and non-MTE-enabled device:
Test: atest libprocinfo_test bionic-unit-tests bionic-unit-tests-static CtsGwpAsanTestCases gwp_asan_unittest debuggerd_test memtag_stack_dlopen_test
Change-Id: Idaacce32eb7c569237714171e96dc7392c565983
diff --git a/libc/bionic/libc_init_static.cpp b/libc/bionic/libc_init_static.cpp
index 7c46113..b54ef85 100644
--- a/libc/bionic/libc_init_static.cpp
+++ b/libc/bionic/libc_init_static.cpp
@@ -344,19 +344,12 @@
 // This function is called from the linker before the main executable is relocated.
 __attribute__((no_sanitize("hwaddress", "memtag"))) void __libc_init_mte(
     const memtag_dynamic_entries_t* memtag_dynamic_entries, const void* phdr_start, size_t phdr_ct,
-    uintptr_t load_bias, void* stack_top) {
+    uintptr_t load_bias) {
   bool memtag_stack = false;
   HeapTaggingLevel level =
       __get_tagging_level(memtag_dynamic_entries, phdr_start, phdr_ct, load_bias, &memtag_stack);
-  // initial_memtag_stack is used by the linker (in linker.cpp) to communicate than any library
-  // linked by this executable enables memtag-stack.
-  // memtag_stack is also set for static executables if they request memtag stack via the note,
-  // in which case it will differ from initial_memtag_stack.
-  if (__libc_shared_globals()->initial_memtag_stack || memtag_stack) {
-    memtag_stack = true;
-    __libc_shared_globals()->initial_memtag_stack_abi = true;
-    __get_bionic_tcb()->tls_slot(TLS_SLOT_STACK_MTE) = __allocate_stack_mte_ringbuffer(0, nullptr);
-  }
+  if (memtag_stack) __libc_shared_globals()->initial_memtag_stack_abi = true;
+
   if (int64_t timed_upgrade = __get_memtag_upgrade_secs()) {
     if (level == M_HEAP_TAGGING_LEVEL_ASYNC) {
       async_safe_format_log(ANDROID_LOG_INFO, "libc",
@@ -380,15 +373,7 @@
     if (prctl(PR_SET_TAGGED_ADDR_CTRL, prctl_arg | PR_MTE_TCF_SYNC, 0, 0, 0) == 0 ||
         prctl(PR_SET_TAGGED_ADDR_CTRL, prctl_arg, 0, 0, 0) == 0) {
       __libc_shared_globals()->initial_heap_tagging_level = level;
-      __libc_shared_globals()->initial_memtag_stack = memtag_stack;
 
-      if (memtag_stack) {
-        void* pg_start =
-            reinterpret_cast<void*>(page_start(reinterpret_cast<uintptr_t>(stack_top)));
-        if (mprotect(pg_start, page_size(), PROT_READ | PROT_WRITE | PROT_MTE | PROT_GROWSDOWN)) {
-          async_safe_fatal("error: failed to set PROT_MTE on main thread stack: %m");
-        }
-      }
       struct sigaction action = {};
       action.sa_flags = SA_SIGINFO | SA_RESTART;
       action.sa_sigaction = __enable_mte_signal_handler;
@@ -404,11 +389,34 @@
   }
   // We did not enable MTE, so we do not need to arm the upgrade timer.
   __libc_shared_globals()->heap_tagging_upgrade_timer_sec = 0;
-  // We also didn't enable memtag_stack.
-  __libc_shared_globals()->initial_memtag_stack = false;
 }
+
+// Sets up stack MTE if requested: allocates the stack history buffer and, when MTE is
+// enabled, marks the main thread stack PROT_MTE. For dynamic executables, this has to be
+// called after loading all DT_NEEDED libraries, in case one of them needs stack MTE.
+__attribute__((no_sanitize("hwaddress", "memtag"))) void __libc_init_mte_stack(void* stack_top) {
+  if (!__libc_shared_globals()->initial_memtag_stack_abi) {
+    return;
+  }
+
+  // Even if the device doesn't support MTE, we have to allocate stack
+  // history buffers for code compiled for stack MTE. That is because the
+  // codegen expects a buffer to be present in TLS_SLOT_STACK_MTE either
+  // way.
+  __get_bionic_tcb()->tls_slot(TLS_SLOT_STACK_MTE) = __allocate_stack_mte_ringbuffer(0, nullptr);
+
+  if (__libc_mte_enabled()) {
+    __libc_shared_globals()->initial_memtag_stack = true;
+    void* pg_start = reinterpret_cast<void*>(page_start(reinterpret_cast<uintptr_t>(stack_top)));
+    if (mprotect(pg_start, page_size(), PROT_READ | PROT_WRITE | PROT_MTE | PROT_GROWSDOWN)) {
+      async_safe_fatal("error: failed to set PROT_MTE on main thread stack: %m");
+    }
+  }
+}
+
 #else   // __aarch64__
-void __libc_init_mte(const memtag_dynamic_entries_t*, const void*, size_t, uintptr_t, void*) {}
+void __libc_init_mte(const memtag_dynamic_entries_t*, const void*, size_t, uintptr_t) {}
+void __libc_init_mte_stack(void*) {}
 #endif  // __aarch64__
 
 void __libc_init_profiling_handlers() {
@@ -436,7 +444,8 @@
   __libc_init_common();
   __libc_init_mte(/*memtag_dynamic_entries=*/nullptr,
                   reinterpret_cast<ElfW(Phdr)*>(getauxval(AT_PHDR)), getauxval(AT_PHNUM),
-                  /*load_bias = */ 0, /*stack_top = */ raw_args);
+                  /*load_bias = */ 0);
+  __libc_init_mte_stack(/*stack_top = */ raw_args);
   __libc_init_scudo();
   __libc_init_profiling_handlers();
   __libc_init_fork_handler();
@@ -511,3 +520,8 @@
   static libc_shared_globals globals;
   return &globals;
 }
+
+__LIBC_HIDDEN__ bool __libc_mte_enabled() {
+  HeapTaggingLevel lvl = __libc_shared_globals()->initial_heap_tagging_level;
+  return lvl == M_HEAP_TAGGING_LEVEL_SYNC || lvl == M_HEAP_TAGGING_LEVEL_ASYNC;
+}