Make memtag_handle_longjmp precise

We would get the SP inside of memtag_handle_longjmp, which could prevent
us from detecting the case where a longjmp is going into a function that
had already returned. This changes makes the behaviour more predictable.

Change-Id: I75bf931c8f4129a2f38001156b7bbe0b54a726ee
diff --git a/libc/arch-arm64/bionic/setjmp.S b/libc/arch-arm64/bionic/setjmp.S
index 178c4c8..c408998 100644
--- a/libc/arch-arm64/bionic/setjmp.S
+++ b/libc/arch-arm64/bionic/setjmp.S
@@ -201,6 +201,7 @@
   bic x2, x2, #1
   ldr x0, [x0, #(_JB_X30_SP  * 8 + 8)]
   eor x0, x0, x2
+  add x1, sp, #16
   bl memtag_handle_longjmp
 
   mov x1, x19 // Restore 'value'.
diff --git a/libc/bionic/heap_tagging.cpp b/libc/bionic/heap_tagging.cpp
index 0c1e506..b6c9fe5 100644
--- a/libc/bionic/heap_tagging.cpp
+++ b/libc/bionic/heap_tagging.cpp
@@ -192,17 +192,29 @@
 #endif  // __aarch64__
 
 extern "C" __LIBC_HIDDEN__ __attribute__((no_sanitize("memtag"))) void memtag_handle_longjmp(
-    void* sp_dst __unused) {
+    void* sp_dst __unused, void* sp_src __unused) {
+  // A usual longjmp looks like this, where sp_dst was the LR in the call to setlongjmp (i.e.
+  // the SP of the frame calling setlongjmp).
+  // ┌─────────────────────┐                  │
+  // │                     │                  │
+  // ├─────────────────────┤◄──────── sp_dst  │ stack
+  // │         ...         │                  │ grows
+  // ├─────────────────────┤                  │ to lower
+  // │         ...         │                  │ addresses
+  // ├─────────────────────┤◄──────── sp_src  │
+  // │siglongjmp           │                  │
+  // ├─────────────────────┤                  │
+  // │memtag_handle_longjmp│                  │
+  // └─────────────────────┘                  ▼
 #ifdef __aarch64__
   if (__libc_globals->memtag_stack) {
-    void* sp = __builtin_frame_address(0);
-    size_t distance = reinterpret_cast<uintptr_t>(sp_dst) - reinterpret_cast<uintptr_t>(sp);
+    size_t distance = reinterpret_cast<uintptr_t>(sp_dst) - reinterpret_cast<uintptr_t>(sp_src);
     if (distance > kUntagLimit) {
       async_safe_fatal(
-          "memtag_handle_longjmp: stack adjustment too large! %p -> %p, distance %zx > %zx\n", sp,
-          sp_dst, distance, kUntagLimit);
+          "memtag_handle_longjmp: stack adjustment too large! %p -> %p, distance %zx > %zx\n",
+          sp_src, sp_dst, distance, kUntagLimit);
     } else {
-      untag_memory(sp, sp_dst);
+      untag_memory(sp_src, sp_dst);
     }
   }
 #endif  // __aarch64__