Reland "[MTE] remap stacks with PROT_MTE when requested by dlopened library"

This reverts commit c20e1c2bdfc00b3fb7931da5e88ceac3fa4df4b2.

Reason for revert: This change was not the root cause of the test failure.

Change-Id: I7dcd9fc3cbac47703fa8ecd5aafd7e1c3ed87301
diff --git a/libc/bionic/heap_tagging.cpp b/libc/bionic/heap_tagging.cpp
index 230899a..6fdd1e0 100644
--- a/libc/bionic/heap_tagging.cpp
+++ b/libc/bionic/heap_tagging.cpp
@@ -57,6 +57,7 @@
         break;
       case M_HEAP_TAGGING_LEVEL_SYNC:
       case M_HEAP_TAGGING_LEVEL_ASYNC:
+        atomic_store(&globals->memtag, true);  // Read by __pthread_internal_remap_stack_with_mte().
         atomic_store(&globals->memtag_stack, __libc_shared_globals()->initial_memtag_stack);
         break;
       default:
@@ -116,6 +117,7 @@
           globals->heap_pointer_tag = static_cast<uintptr_t>(0xffull << UNTAG_SHIFT);
         }
         atomic_store(&globals->memtag_stack, false);
+        atomic_store(&globals->memtag, false);  // Cleared together with memtag_stack above.
       });
 
       if (heap_tagging_level != M_HEAP_TAGGING_LEVEL_TBI) {
diff --git a/libc/bionic/libc_init_dynamic.cpp b/libc/bionic/libc_init_dynamic.cpp
index c61810e..295484b 100644
--- a/libc/bionic/libc_init_dynamic.cpp
+++ b/libc/bionic/libc_init_dynamic.cpp
@@ -39,11 +39,12 @@
  *   all dynamic linking has been performed.
  */
 
+#include <elf.h>
 #include <stddef.h>
+#include <stdint.h>
 #include <stdio.h>
 #include <stdlib.h>
-#include <stdint.h>
-#include <elf.h>
+#include "bionic/pthread_internal.h"
 #include "libc_init_common.h"
 
 #include "private/bionic_defs.h"
@@ -59,6 +60,12 @@
   extern int __cxa_atexit(void (*)(void *), void *, void *);
 };
 
+// Called through __libc_shared_globals()->memtag_stack_dlopen_callback (set below) when a
+// dlopened library requests PROT_MTE stacks.
+static void memtag_stack_dlopen_callback() {
+  __pthread_internal_remap_stack_with_mte();
+}
+
 // Use an initializer so __libc_sysinfo will have a fallback implementation
 // while .preinit_array constructors run.
 #if defined(__i386__)
@@ -156,6 +163,10 @@
 
   __libc_init_mte_late();
 
+  // The linker contains its own static copy of libc, so it must call back into this libc via
+  // the shared globals; remapping from the linker's copy would not affect the process.
+  __libc_shared_globals()->memtag_stack_dlopen_callback = memtag_stack_dlopen_callback;
+
   exit(slingshot(args.argc - __libc_shared_globals()->initial_linker_arg_count,
                  args.argv + __libc_shared_globals()->initial_linker_arg_count,
                  args.envp));
diff --git a/libc/bionic/pthread_attr.cpp b/libc/bionic/pthread_attr.cpp
index de4cc9e..f6c0401 100644
--- a/libc/bionic/pthread_attr.cpp
+++ b/libc/bionic/pthread_attr.cpp
@@ -155,36 +155,6 @@
   return 0;
 }
 
-static uintptr_t __get_main_stack_startstack() {
-  FILE* fp = fopen("/proc/self/stat", "re");
-  if (fp == nullptr) {
-    async_safe_fatal("couldn't open /proc/self/stat: %m");
-  }
-
-  char line[BUFSIZ];
-  if (fgets(line, sizeof(line), fp) == nullptr) {
-    async_safe_fatal("couldn't read /proc/self/stat: %m");
-  }
-
-  fclose(fp);
-
-  // See man 5 proc. There's no reason comm can't contain ' ' or ')',
-  // so we search backwards for the end of it. We're looking for this field:
-  //
-  //  startstack %lu (28) The address of the start (i.e., bottom) of the stack.
-  uintptr_t startstack = 0;
-  const char* end_of_comm = strrchr(line, ')');
-  if (sscanf(end_of_comm + 1, " %*c "
-             "%*d %*d %*d %*d %*d "
-             "%*u %*u %*u %*u %*u %*u %*u "
-             "%*d %*d %*d %*d %*d %*d "
-             "%*u %*u %*d %*u %*u %*u %" SCNuPTR, &startstack) != 1) {
-    async_safe_fatal("couldn't parse /proc/self/stat");
-  }
-
-  return startstack;
-}
-
 static int __pthread_attr_getstack_main_thread(void** stack_base, size_t* stack_size) {
   ErrnoRestorer errno_restorer;
 
@@ -198,28 +168,11 @@
   if (stack_limit.rlim_cur == RLIM_INFINITY) {
     stack_limit.rlim_cur = 8 * 1024 * 1024;
   }
-
-  // Ask the kernel where our main thread's stack started.
-  uintptr_t startstack = __get_main_stack_startstack();
-
-  // Hunt for the region that contains that address.
-  FILE* fp = fopen("/proc/self/maps", "re");
-  if (fp == nullptr) {
-    async_safe_fatal("couldn't open /proc/self/maps: %m");
-  }
-  char line[BUFSIZ];
-  while (fgets(line, sizeof(line), fp) != nullptr) {
-    uintptr_t lo, hi;
-    if (sscanf(line, "%" SCNxPTR "-%" SCNxPTR, &lo, &hi) == 2) {
-      if (lo <= startstack && startstack <= hi) {
-        *stack_size = stack_limit.rlim_cur;
-        *stack_base = reinterpret_cast<void*>(hi - *stack_size);
-        fclose(fp);
-        return 0;
-      }
-    }
-  }
-  async_safe_fatal("stack not found in /proc/self/maps");
+  uintptr_t lo, hi;  // Only hi is needed below.
+  __find_main_stack_limits(&lo, &hi);  // Mapping that contains the main thread's stack.
+  *stack_size = stack_limit.rlim_cur;
+  *stack_base = reinterpret_cast<void*>(hi - *stack_size);  // rlimit-sized window ending at hi.
+  return 0;
 }
 
 __BIONIC_WEAK_FOR_NATIVE_BRIDGE
diff --git a/libc/bionic/pthread_internal.cpp b/libc/bionic/pthread_internal.cpp
index 6a7ee2f..bfe2f98 100644
--- a/libc/bionic/pthread_internal.cpp
+++ b/libc/bionic/pthread_internal.cpp
@@ -40,6 +40,7 @@
 #include "private/ErrnoRestorer.h"
 #include "private/ScopedRWLock.h"
 #include "private/bionic_futex.h"
+#include "private/bionic_globals.h"
 #include "private/bionic_tls.h"
 
 static pthread_internal_t* g_thread_list = nullptr;
@@ -119,6 +120,102 @@
   return nullptr;
 }
 
+// Returns the "startstack" field of /proc/self/stat. Aborts if it can't be read or parsed.
+static uintptr_t __get_main_stack_startstack() {
+  FILE* fp = fopen("/proc/self/stat", "re");
+  if (fp == nullptr) {
+    async_safe_fatal("couldn't open /proc/self/stat: %m");
+  }
+
+  char line[BUFSIZ];
+  if (fgets(line, sizeof(line), fp) == nullptr) {
+    async_safe_fatal("couldn't read /proc/self/stat: %m");
+  }
+
+  fclose(fp);
+
+  // See man 5 proc. There's no reason comm can't contain ' ' or ')',
+  // so we search backwards for the end of it. We're looking for this field:
+  //
+  //  startstack %lu (28) The address of the start (i.e., bottom) of the stack.
+  uintptr_t startstack = 0;
+  const char* end_of_comm = strrchr(line, ')');
+  // strrchr() returns nullptr if the line has no ')'; don't do arithmetic on it.
+  if (end_of_comm == nullptr ||
+      sscanf(end_of_comm + 1,
+             " %*c "
+             "%*d %*d %*d %*d %*d "
+             "%*u %*u %*u %*u %*u %*u %*u "
+             "%*d %*d %*d %*d %*d %*d "
+             "%*u %*u %*d %*u %*u %*u %" SCNuPTR,
+             &startstack) != 1) {
+    async_safe_fatal("couldn't parse /proc/self/stat");
+  }
+
+  return startstack;
+}
+
+// Finds the /proc/self/maps region containing the main thread's stack (as reported by the
+// startstack field of /proc/self/stat) and stores its bounds in *low and *high. Aborts on failure.
+void __find_main_stack_limits(uintptr_t* low, uintptr_t* high) {
+  // Ask the kernel where our main thread's stack started.
+  uintptr_t startstack = __get_main_stack_startstack();
+
+  // Hunt for the region that contains that address.
+  FILE* fp = fopen("/proc/self/maps", "re");
+  if (fp == nullptr) {
+    async_safe_fatal("couldn't open /proc/self/maps: %m");
+  }
+  char line[BUFSIZ];
+  while (fgets(line, sizeof(line), fp) != nullptr) {
+    uintptr_t lo, hi;
+    if (sscanf(line, "%" SCNxPTR "-%" SCNxPTR, &lo, &hi) == 2) {
+      // Ranges in /proc/self/maps are half-open [lo, hi): hi belongs to the next mapping.
+      if (lo <= startstack && startstack < hi) {
+        *low = lo;
+        *high = hi;
+        fclose(fp);
+        return;
+      }
+    }
+  }
+  async_safe_fatal("stack not found in /proc/self/maps");
+}
+
+// Remaps the main thread's stack and the stacks of all other live threads with PROT_MTE.
+// Does nothing unless MTE is enabled for the process and memtag_stack is still false (this
+// call sets it to true). Compiles to an empty function on non-aarch64 targets.
+void __pthread_internal_remap_stack_with_mte() {
+#if defined(__aarch64__)
+  // If process doesn't have MTE enabled, we don't need to do anything.
+  if (!__libc_globals->memtag) return;
+  // Atomically set memtag_stack. If it was already true (from initial_memtag_stack at startup,
+  // or from an earlier call), the stacks are presumed to be PROT_MTE already.
+  bool prev = true;
+  __libc_globals.mutate(
+      [&prev](libc_globals* globals) { prev = atomic_exchange(&globals->memtag_stack, true); });
+  if (prev) return;
+  uintptr_t lo, hi;
+  __find_main_stack_limits(&lo, &hi);
+
+  if (mprotect(reinterpret_cast<void*>(lo), hi - lo,
+               PROT_READ | PROT_WRITE | PROT_MTE | PROT_GROWSDOWN)) {
+    async_safe_fatal("error: failed to set PROT_MTE on main thread: %m");
+  }
+  // Creation lock, then list lock: intended to match android_run_on_all_threads(); keep in sync.
+  ScopedWriteLock creation_locker(&g_thread_creation_lock);
+  ScopedReadLock list_locker(&g_thread_list_lock);
+  for (pthread_internal_t* t = g_thread_list; t != nullptr; t = t->next) {
+    // Skip terminating threads, and the main thread (its stack was remapped above).
+    if (t->terminating || t->is_main()) continue;
+    if (mprotect(t->mmap_base_unguarded, t->mmap_size_unguarded,
+                 PROT_READ | PROT_WRITE | PROT_MTE)) {
+      async_safe_fatal("error: failed to set PROT_MTE on thread %d: %m", t->tid);
+    }
+  }
+#endif
+}
+
 bool android_run_on_all_threads(bool (*func)(void*), void* arg) {
   // Take the locks in this order to avoid inversion (pthread_create ->
   // __pthread_internal_add).
diff --git a/libc/bionic/pthread_internal.h b/libc/bionic/pthread_internal.h
index 3b9e6a4..091f711 100644
--- a/libc/bionic/pthread_internal.h
+++ b/libc/bionic/pthread_internal.h
@@ -178,6 +178,7 @@
   bionic_tls* bionic_tls;
 
   int errno_value;
+  bool is_main() const { return start_routine == nullptr; }
 };
 
 struct ThreadMapping {
@@ -207,6 +208,8 @@
 __LIBC_HIDDEN__ pid_t __pthread_internal_gettid(pthread_t pthread_id, const char* caller);
 __LIBC_HIDDEN__ void __pthread_internal_remove(pthread_internal_t* thread);
 __LIBC_HIDDEN__ void __pthread_internal_remove_and_free(pthread_internal_t* thread);
+// Finds the /proc/self/maps region containing the main thread's stack. Aborts on failure.
+__LIBC_HIDDEN__ void __find_main_stack_limits(uintptr_t* low, uintptr_t* high);
 
 static inline __always_inline bionic_tcb* __get_bionic_tcb() {
   return reinterpret_cast<bionic_tcb*>(&__get_tls()[MIN_TLS_SLOT]);
@@ -266,6 +269,10 @@
 __LIBC_HIDDEN__ extern void __bionic_atfork_run_child();
 __LIBC_HIDDEN__ extern void __bionic_atfork_run_parent();
 
+// Re-maps the stacks of all existing threads with PROT_MTE, and makes subsequently launched
+// threads use PROT_MTE stacks too. No-op if the process doesn't have MTE enabled.
+__LIBC_HIDDEN__ void __pthread_internal_remap_stack_with_mte();
+
 extern "C" bool android_run_on_all_threads(bool (*func)(void*), void* arg);
 
 extern pthread_rwlock_t g_thread_creation_lock;