Basic support for MTE stack tagging.

Map all stacks (primary, thread, and sigaltstack) as PROT_MTE when the
binary requests it through the ELF note.

For reference, the note is produced by the following toolchain changes:
https://reviews.llvm.org/D118948
https://reviews.llvm.org/D119384
https://reviews.llvm.org/D119381

Bug: b/174878242
Test: fvp_mini with ToT LLVM (more tests in a separate change)

Change-Id: I04a4e21c966e7309b47b1f549a2919958d93a872
diff --git a/libc/bionic/heap_tagging.cpp b/libc/bionic/heap_tagging.cpp
index f84ba7b..d10d778 100644
--- a/libc/bionic/heap_tagging.cpp
+++ b/libc/bionic/heap_tagging.cpp
@@ -44,30 +44,37 @@
 #if !__has_feature(hwaddress_sanitizer)
   heap_tagging_level = __libc_shared_globals()->initial_heap_tagging_level;
 #endif
-  switch (heap_tagging_level) {
-    case M_HEAP_TAGGING_LEVEL_TBI:
-      __libc_globals.mutate([](libc_globals* globals) {
+
+  __libc_globals.mutate([](libc_globals* globals) {
+    switch (heap_tagging_level) {
+      case M_HEAP_TAGGING_LEVEL_TBI:
         // Arrange for us to set pointer tags to POINTER_TAG, check tags on
         // deallocation and untag when passing pointers to the allocator.
         globals->heap_pointer_tag = (reinterpret_cast<uintptr_t>(POINTER_TAG) << TAG_SHIFT) |
                                     (0xffull << CHECK_SHIFT) | (0xffull << UNTAG_SHIFT);
-      });
-#if defined(USE_SCUDO)
-      scudo_malloc_disable_memory_tagging();
-#endif  // USE_SCUDO
-      break;
-#if defined(USE_SCUDO)
-    case M_HEAP_TAGGING_LEVEL_SYNC:
-      scudo_malloc_set_track_allocation_stacks(1);
-      break;
+        break;
+      case M_HEAP_TAGGING_LEVEL_SYNC:
+      case M_HEAP_TAGGING_LEVEL_ASYNC:
+        atomic_store(&globals->memtag_stack, __libc_shared_globals()->initial_memtag_stack);
+        break;
+      default:
+        break;
+    };
+  });
 
+#if defined(USE_SCUDO)
+  switch (heap_tagging_level) {
+    case M_HEAP_TAGGING_LEVEL_TBI:
     case M_HEAP_TAGGING_LEVEL_NONE:
       scudo_malloc_disable_memory_tagging();
       break;
-#endif  // USE_SCUDO
+    case M_HEAP_TAGGING_LEVEL_SYNC:
+      scudo_malloc_set_track_allocation_stacks(1);
+      break;
     default:
       break;
   }
+#endif  // USE_SCUDO
 #endif  // aarch64
 }
 
@@ -104,16 +111,21 @@
 
   switch (tag_level) {
     case M_HEAP_TAGGING_LEVEL_NONE:
-      if (heap_tagging_level == M_HEAP_TAGGING_LEVEL_TBI) {
-        __libc_globals.mutate([](libc_globals* globals) {
+      __libc_globals.mutate([](libc_globals* globals) {
+        if (heap_tagging_level == M_HEAP_TAGGING_LEVEL_TBI) {
           // Preserve the untag mask (we still want to untag pointers when passing them to the
           // allocator), but clear the fixed tag and the check mask, so that pointers are no longer
           // tagged and checks no longer happen.
           globals->heap_pointer_tag = static_cast<uintptr_t>(0xffull << UNTAG_SHIFT);
-        });
-      } else if (!set_tcf_on_all_threads(PR_MTE_TCF_NONE)) {
-        error_log("SetHeapTaggingLevel: set_tcf_on_all_threads failed");
-        return false;
+        }
+        atomic_store(&globals->memtag_stack, false);
+      });
+
+      if (heap_tagging_level != M_HEAP_TAGGING_LEVEL_TBI) {
+        if (!set_tcf_on_all_threads(PR_MTE_TCF_NONE)) {
+          error_log("SetHeapTaggingLevel: set_tcf_on_all_threads failed");
+          return false;
+        }
       }
 #if defined(USE_SCUDO)
       scudo_malloc_disable_memory_tagging();
diff --git a/libc/bionic/libc_init_static.cpp b/libc/bionic/libc_init_static.cpp
index 575da62..66aaeaa 100644
--- a/libc/bionic/libc_init_static.cpp
+++ b/libc/bionic/libc_init_static.cpp
@@ -259,17 +259,18 @@
 // M_HEAP_TAGGING_LEVEL_NONE, if MTE isn't enabled for this process we enable
 // M_HEAP_TAGGING_LEVEL_TBI.
 static HeapTaggingLevel __get_heap_tagging_level(const void* phdr_start, size_t phdr_ct,
-                                                 uintptr_t load_bias) {
+                                                 uintptr_t load_bias, bool* stack) {
+  unsigned note_val =
+      __get_memtag_note(reinterpret_cast<const ElfW(Phdr)*>(phdr_start), phdr_ct, load_bias);
+  *stack = note_val & NT_MEMTAG_STACK;
+
   HeapTaggingLevel level;
   if (get_environment_memtag_setting(&level)) return level;
 
-  unsigned note_val =
-      __get_memtag_note(reinterpret_cast<const ElfW(Phdr)*>(phdr_start), phdr_ct, load_bias);
-
   // Note, previously (in Android 12), any value outside of bits [0..3] resulted
   // in a check-fail. In order to be permissive of further extensions, we
-  // relaxed this restriction. For now, we still only support MTE heap.
-  if (!(note_val & NT_MEMTAG_HEAP)) return M_HEAP_TAGGING_LEVEL_TBI;
+  // relaxed this restriction.
+  if (!(note_val & (NT_MEMTAG_HEAP | NT_MEMTAG_STACK))) return M_HEAP_TAGGING_LEVEL_TBI;
 
   unsigned mode = note_val & NT_MEMTAG_LEVEL_MASK;
   switch (mode) {
@@ -295,8 +296,10 @@
 // This function is called from the linker before the main executable is relocated.
 __attribute__((no_sanitize("hwaddress", "memtag"))) void __libc_init_mte(const void* phdr_start,
                                                                          size_t phdr_ct,
-                                                                         uintptr_t load_bias) {
-  HeapTaggingLevel level = __get_heap_tagging_level(phdr_start, phdr_ct, load_bias);
+                                                                         uintptr_t load_bias,
+                                                                         void* stack_top) {
+  bool memtag_stack;
+  HeapTaggingLevel level = __get_heap_tagging_level(phdr_start, phdr_ct, load_bias, &memtag_stack);
 
   if (level == M_HEAP_TAGGING_LEVEL_SYNC || level == M_HEAP_TAGGING_LEVEL_ASYNC) {
     unsigned long prctl_arg = PR_TAGGED_ADDR_ENABLE | PR_MTE_TAG_SET_NONZERO;
@@ -308,6 +311,17 @@
     if (prctl(PR_SET_TAGGED_ADDR_CTRL, prctl_arg | PR_MTE_TCF_SYNC, 0, 0, 0) == 0 ||
         prctl(PR_SET_TAGGED_ADDR_CTRL, prctl_arg, 0, 0, 0) == 0) {
       __libc_shared_globals()->initial_heap_tagging_level = level;
+      __libc_shared_globals()->initial_memtag_stack = memtag_stack;
+
+      if (memtag_stack) {
+        void* page_start =
+            reinterpret_cast<void*>(PAGE_START(reinterpret_cast<uintptr_t>(stack_top)));
+        if (mprotect(page_start, PAGE_SIZE, PROT_READ | PROT_WRITE | PROT_MTE | PROT_GROWSDOWN)) {
+          async_safe_fatal("error: failed to set PROT_MTE on main thread stack: %s\n",
+                           strerror(errno));
+        }
+      }
+
       return;
     }
   }
@@ -319,7 +333,7 @@
   }
 }
 #else   // __aarch64__
-void __libc_init_mte(const void*, size_t, uintptr_t) {}
+void __libc_init_mte(const void*, size_t, uintptr_t, void*) {}
 #endif  // __aarch64__
 
 void __libc_init_profiling_handlers() {
@@ -331,11 +345,9 @@
   signal(BIONIC_SIGNAL_ART_PROFILER, SIG_IGN);
 }
 
-__noreturn static void __real_libc_init(void *raw_args,
-                                        void (*onexit)(void) __unused,
-                                        int (*slingshot)(int, char**, char**),
-                                        structors_array_t const * const structors,
-                                        bionic_tcb* temp_tcb) {
+__attribute__((no_sanitize("memtag"))) __noreturn static void __real_libc_init(
+    void* raw_args, void (*onexit)(void) __unused, int (*slingshot)(int, char**, char**),
+    structors_array_t const* const structors, bionic_tcb* temp_tcb) {
   BIONIC_STOP_UNWIND;
 
   // Initialize TLS early so system calls and errno work.
@@ -349,7 +361,7 @@
   __libc_init_main_thread_final();
   __libc_init_common();
   __libc_init_mte(reinterpret_cast<ElfW(Phdr)*>(getauxval(AT_PHDR)), getauxval(AT_PHNUM),
-                  /*load_bias = */ 0);
+                  /*load_bias = */ 0, /*stack_top = */ raw_args);
   __libc_init_scudo();
   __libc_init_profiling_handlers();
   __libc_init_fork_handler();
@@ -379,11 +391,9 @@
 //
 // The 'structors' parameter contains pointers to various initializer
 // arrays that must be run before the program's 'main' routine is launched.
-__attribute__((no_sanitize("hwaddress")))
-__noreturn void __libc_init(void* raw_args,
-                            void (*onexit)(void) __unused,
-                            int (*slingshot)(int, char**, char**),
-                            structors_array_t const * const structors) {
+__attribute__((no_sanitize("hwaddress", "memtag"))) __noreturn void __libc_init(
+    void* raw_args, void (*onexit)(void) __unused, int (*slingshot)(int, char**, char**),
+    structors_array_t const* const structors) {
   bionic_tcb temp_tcb = {};
 #if __has_feature(hwaddress_sanitizer)
   // Install main thread TLS early. It will be initialized later in __libc_init_main_thread. For now
diff --git a/libc/bionic/pthread_create.cpp b/libc/bionic/pthread_create.cpp
index 121b26f..417ce76 100644
--- a/libc/bionic/pthread_create.cpp
+++ b/libc/bionic/pthread_create.cpp
@@ -40,15 +40,16 @@
 
 #include <async_safe/log.h>
 
+#include "platform/bionic/macros.h"
+#include "platform/bionic/mte.h"
+#include "private/ErrnoRestorer.h"
 #include "private/ScopedRWLock.h"
 #include "private/bionic_constants.h"
 #include "private/bionic_defs.h"
 #include "private/bionic_globals.h"
-#include "platform/bionic/macros.h"
 #include "private/bionic_ssp.h"
 #include "private/bionic_systrace.h"
 #include "private/bionic_tls.h"
-#include "private/ErrnoRestorer.h"
 
 // x86 uses segment descriptors rather than a direct pointer to TLS.
 #if defined(__i386__)
@@ -88,7 +89,13 @@
 
 static void __init_alternate_signal_stack(pthread_internal_t* thread) {
   // Create and set an alternate signal stack.
-  void* stack_base = mmap(nullptr, SIGNAL_STACK_SIZE, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0);
+  int prot = PROT_READ | PROT_WRITE;
+#ifdef __aarch64__
+  if (atomic_load(&__libc_globals->memtag_stack)) {
+    prot |= PROT_MTE;
+  }
+#endif
+  void* stack_base = mmap(nullptr, SIGNAL_STACK_SIZE, prot, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
   if (stack_base != MAP_FAILED) {
     // Create a guard to catch stack overflows in signal handlers.
     if (mprotect(stack_base, PTHREAD_GUARD_SIZE, PROT_NONE) == -1) {
@@ -224,12 +231,19 @@
     return {};
   }
   const size_t writable_size = mmap_size - stack_guard_size - PTHREAD_GUARD_SIZE;
-  if (mprotect(space + stack_guard_size,
-               writable_size,
-               PROT_READ | PROT_WRITE) != 0) {
-    async_safe_format_log(ANDROID_LOG_WARN, "libc",
-                          "pthread_create failed: couldn't mprotect R+W %zu-byte thread mapping region: %s",
-                          writable_size, strerror(errno));
+  int prot = PROT_READ | PROT_WRITE;
+  const char* prot_str = "R+W";
+#ifdef __aarch64__
+  if (atomic_load(&__libc_globals->memtag_stack)) {
+    prot |= PROT_MTE;
+    prot_str = "R+W+MTE";
+  }
+#endif
+  if (mprotect(space + stack_guard_size, writable_size, prot) != 0) {
+    async_safe_format_log(
+        ANDROID_LOG_WARN, "libc",
+        "pthread_create failed: couldn't mprotect %s %zu-byte thread mapping region: %s", prot_str,
+        writable_size, strerror(errno));
     munmap(space, mmap_size);
     return {};
   }