rust: Support MTE memtag_heap sanitizer

This CL adds support for the MTE memtag_heap sanitizer. This is
controlled via inclusion of an ELF note.

Bug: 170672854
Test: Heap MTE-enabled Rust test binary triggers MTE
Change-Id: I2619818785e86a94667d02b30d102c83456b7925
diff --git a/rust/Android.bp b/rust/Android.bp
index 0ee673d..cda2dbc 100644
--- a/rust/Android.bp
+++ b/rust/Android.bp
@@ -54,6 +54,7 @@
         "project_json_test.go",
         "protobuf_test.go",
         "rust_test.go",
+        "sanitize_test.go",
         "source_provider_test.go",
         "test_test.go",
         "vendor_snapshot_test.go",
diff --git a/rust/binary.go b/rust/binary.go
index 7c18730..6854e62 100644
--- a/rust/binary.go
+++ b/rust/binary.go
@@ -34,6 +34,7 @@
 type binaryInterface interface {
 	binary() bool
 	staticallyLinked() bool
+	testBinary() bool
 }
 
 type binaryDecorator struct {
@@ -168,3 +169,7 @@
 func (binary *binaryDecorator) staticallyLinked() bool {
 	return Bool(binary.Properties.Static_executable)
 }
+
+func (binary *binaryDecorator) testBinary() bool {
+	return false
+}
diff --git a/rust/rust_test.go b/rust/rust_test.go
index 80f693e..fe55d5f 100644
--- a/rust/rust_test.go
+++ b/rust/rust_test.go
@@ -442,3 +442,10 @@
 	m.Output("libwaldo.dylib.so.bloaty.csv")
 	m.Output("stripped/libwaldo.dylib.so.bloaty.csv")
 }
+
+func assertString(t *testing.T, got, expected string) {
+	t.Helper()
+	if got != expected {
+		t.Errorf("expected %q got %q", expected, got)
+	}
+}
diff --git a/rust/sanitize.go b/rust/sanitize.go
index 5b7597e..fdb342d 100644
--- a/rust/sanitize.go
+++ b/rust/sanitize.go
@@ -26,13 +26,28 @@
 	"android/soong/rust/config"
 )
 
+// TODO: When Rust has sanitizer-parity with CC, deduplicate this struct
 type SanitizeProperties struct {
 	// enable AddressSanitizer, HWAddressSanitizer, and others.
 	Sanitize struct {
 		Address   *bool `android:"arch_variant"`
 		Hwaddress *bool `android:"arch_variant"`
-		Fuzzer    *bool `android:"arch_variant"`
-		Never     *bool `android:"arch_variant"`
+
+		// Memory-tagging, only available on arm64
+		// if diag.memtag_heap is unset or false, enables async memory tagging
+		Memtag_heap *bool `android:"arch_variant"`
+		Fuzzer      *bool `android:"arch_variant"`
+		Never       *bool `android:"arch_variant"`
+
+		// Sanitizers to run in the diagnostic mode (as opposed to the release mode).
+		// Replaces abort() on error with a human-readable error message.
+		// Address and Thread sanitizers always run in diagnostic mode.
+		Diag struct {
+			// Memory-tagging, only available on arm64
+			// requires sanitize.memtag_heap: true
+			// if set, enables sync memory tagging
+			Memtag_heap *bool `android:"arch_variant"`
+		}
 	}
 	SanitizerEnabled bool `blueprint:"mutated"`
 	SanitizeDep      bool `blueprint:"mutated"`
@@ -99,7 +114,18 @@
 		return
 	}
 
+	// rust_test targets default to SYNC MemTag unless explicitly set to ASYNC (via diag: { memtag_heap: false }).
+	if binary, ok := ctx.RustModule().compiler.(binaryInterface); ok && binary.testBinary() {
+		if s.Memtag_heap == nil {
+			s.Memtag_heap = proptools.BoolPtr(true)
+		}
+		if s.Diag.Memtag_heap == nil {
+			s.Diag.Memtag_heap = proptools.BoolPtr(true)
+		}
+	}
+
 	var globalSanitizers []string
+	var globalSanitizersDiag []string
 
 	if ctx.Host() {
 		if !ctx.Windows() {
@@ -109,6 +135,7 @@
 		arches := ctx.Config().SanitizeDeviceArch()
 		if len(arches) == 0 || android.InList(ctx.Arch().ArchType.Name, arches) {
 			globalSanitizers = ctx.Config().SanitizeDevice()
+			globalSanitizersDiag = ctx.Config().SanitizeDeviceDiag()
 		}
 	}
 
@@ -123,6 +150,12 @@
 			}
 		}
 
+		if found, globalSanitizers = android.RemoveFromList("memtag_heap", globalSanitizers); found && s.Memtag_heap == nil {
+			if !ctx.Config().MemtagHeapDisabledForPath(ctx.ModuleDir()) {
+				s.Memtag_heap = proptools.BoolPtr(true)
+			}
+		}
+
 		if found, globalSanitizers = android.RemoveFromList("address", globalSanitizers); found && s.Address == nil {
 			s.Address = proptools.BoolPtr(true)
 		}
@@ -131,6 +164,27 @@
 			s.Fuzzer = proptools.BoolPtr(true)
 		}
 
+		// Global Diag Sanitizers
+		if found, globalSanitizersDiag = android.RemoveFromList("memtag_heap", globalSanitizersDiag); found &&
+			s.Diag.Memtag_heap == nil && Bool(s.Memtag_heap) {
+			s.Diag.Memtag_heap = proptools.BoolPtr(true)
+		}
+	}
+
+	// Enable Memtag for all components in the include paths (for Aarch64 only)
+	if ctx.Arch().ArchType == android.Arm64 {
+		if ctx.Config().MemtagHeapSyncEnabledForPath(ctx.ModuleDir()) {
+			if s.Memtag_heap == nil {
+				s.Memtag_heap = proptools.BoolPtr(true)
+			}
+			if s.Diag.Memtag_heap == nil {
+				s.Diag.Memtag_heap = proptools.BoolPtr(true)
+			}
+		} else if ctx.Config().MemtagHeapAsyncEnabledForPath(ctx.ModuleDir()) {
+			if s.Memtag_heap == nil {
+				s.Memtag_heap = proptools.BoolPtr(true)
+			}
+		}
 	}
 
 	// TODO:(b/178369775)
@@ -158,7 +212,12 @@
 		s.Address = nil
 	}
 
-	if ctx.Os() == android.Android && (Bool(s.Hwaddress) || Bool(s.Address)) {
+	// Memtag_heap is only implemented on AArch64.
+	if ctx.Arch().ArchType != android.Arm64 {
+		s.Memtag_heap = nil
+	}
+
+	if ctx.Os() == android.Android && (Bool(s.Hwaddress) || Bool(s.Address) || Bool(s.Memtag_heap)) {
 		sanitize.Properties.SanitizerEnabled = true
 	}
 }
@@ -198,6 +257,26 @@
 			return
 		}
 
+		if Bool(mod.sanitize.Properties.Sanitize.Memtag_heap) && mod.Binary() {
+			noteDep := "note_memtag_heap_async"
+			if Bool(mod.sanitize.Properties.Sanitize.Diag.Memtag_heap) {
+				noteDep = "note_memtag_heap_sync"
+			}
+			// If we're using snapshots, redirect to snapshot whenever possible
+			// TODO(b/178470649): clean manual snapshot redirections
+			snapshot := mctx.Provider(cc.SnapshotInfoProvider).(cc.SnapshotInfo)
+			if lib, ok := snapshot.StaticLibs[noteDep]; ok {
+				noteDep = lib
+			}
+			depTag := cc.StaticDepTag(true)
+			variations := append(mctx.Target().Variations(),
+				blueprint.Variation{Mutator: "link", Variation: "static"})
+			if mod.Device() {
+				variations = append(variations, mod.ImageVariation())
+			}
+			mctx.AddFarVariationDependencies(variations, depTag, noteDep)
+		}
+
 		variations := mctx.Target().Variations()
 		var depTag blueprint.DependencyTag
 		var deps []string
@@ -225,7 +304,9 @@
 			deps = []string{config.LibclangRuntimeLibrary(mod.toolchain(mctx), "hwasan")}
 		}
 
-		mctx.AddFarVariationDependencies(variations, depTag, deps...)
+		if len(deps) > 0 {
+			mctx.AddFarVariationDependencies(variations, depTag, deps...)
+		}
 	}
 }
 
@@ -241,6 +322,9 @@
 	case cc.Hwasan:
 		sanitize.Properties.Sanitize.Hwaddress = boolPtr(b)
 		sanitizerSet = true
+	case cc.Memtag_heap:
+		sanitize.Properties.Sanitize.Memtag_heap = boolPtr(b)
+		sanitizerSet = true
 	default:
 		panic(fmt.Errorf("setting unsupported sanitizerType %d", t))
 	}
@@ -300,6 +384,8 @@
 		return sanitize.Properties.Sanitize.Address
 	case cc.Hwasan:
 		return sanitize.Properties.Sanitize.Hwaddress
+	case cc.Memtag_heap:
+		return sanitize.Properties.Sanitize.Memtag_heap
 	default:
 		return nil
 	}
@@ -330,6 +416,8 @@
 			return false
 		}
 		return true
+	case cc.Memtag_heap:
+		return true
 	default:
 		return false
 	}
diff --git a/rust/sanitize_test.go b/rust/sanitize_test.go
new file mode 100644
index 0000000..d6a14b2
--- /dev/null
+++ b/rust/sanitize_test.go
@@ -0,0 +1,365 @@
+package rust
+
+import (
+	"fmt"
+	"strings"
+	"testing"
+
+	"android/soong/android"
+)
+
+type MemtagNoteType int
+
+const (
+	None MemtagNoteType = iota + 1
+	Sync
+	Async
+)
+
+func (t MemtagNoteType) str() string {
+	switch t {
+	case None:
+		return "none"
+	case Sync:
+		return "sync"
+	case Async:
+		return "async"
+	default:
+		panic("type_note_invalid")
+	}
+}
+
+func checkHasMemtagNote(t *testing.T, m android.TestingModule, expected MemtagNoteType) {
+	t.Helper()
+	note_async := "note_memtag_heap_async"
+	note_sync := "note_memtag_heap_sync"
+
+	found := None
+	implicits := m.Rule("rustc").Implicits
+	for _, lib := range implicits {
+		if strings.Contains(lib.Rel(), note_async) {
+			found = Async
+			break
+		} else if strings.Contains(lib.Rel(), note_sync) {
+			found = Sync
+			break
+		}
+	}
+
+	if found != expected {
+		t.Errorf("Wrong Memtag note in target %q: found %q, expected %q", m.Module().(*Module).Name(), found.str(), expected.str())
+	}
+}
+
+var prepareForTestWithMemtagHeap = android.GroupFixturePreparers(
+	android.FixtureModifyMockFS(func(fs android.MockFS) {
+		templateBp := `
+		rust_test {
+			name: "unset_test_%[1]s",
+			srcs: ["foo.rs"],
+		}
+
+		rust_test {
+			name: "no_memtag_test_%[1]s",
+			srcs: ["foo.rs"],
+			sanitize: { memtag_heap: false },
+		}
+
+		rust_test {
+			name: "set_memtag_test_%[1]s",
+			srcs: ["foo.rs"],
+			sanitize: { memtag_heap: true },
+		}
+
+		rust_test {
+			name: "set_memtag_set_async_test_%[1]s",
+			srcs: ["foo.rs"],
+			sanitize: { memtag_heap: true, diag: { memtag_heap: false }  },
+		}
+
+		rust_test {
+			name: "set_memtag_set_sync_test_%[1]s",
+			srcs: ["foo.rs"],
+			sanitize: { memtag_heap: true, diag: { memtag_heap: true }  },
+		}
+
+		rust_test {
+			name: "unset_memtag_set_sync_test_%[1]s",
+			srcs: ["foo.rs"],
+			sanitize: { diag: { memtag_heap: true }  },
+		}
+
+		rust_binary {
+			name: "unset_binary_%[1]s",
+			srcs: ["foo.rs"],
+		}
+
+		rust_binary {
+			name: "no_memtag_binary_%[1]s",
+			srcs: ["foo.rs"],
+			sanitize: { memtag_heap: false },
+		}
+
+		rust_binary {
+			name: "set_memtag_binary_%[1]s",
+			srcs: ["foo.rs"],
+			sanitize: { memtag_heap: true },
+		}
+
+		rust_binary {
+			name: "set_memtag_set_async_binary_%[1]s",
+			srcs: ["foo.rs"],
+			sanitize: { memtag_heap: true, diag: { memtag_heap: false }  },
+		}
+
+		rust_binary {
+			name: "set_memtag_set_sync_binary_%[1]s",
+			srcs: ["foo.rs"],
+			sanitize: { memtag_heap: true, diag: { memtag_heap: true }  },
+		}
+
+		rust_binary {
+			name: "unset_memtag_set_sync_binary_%[1]s",
+			srcs: ["foo.rs"],
+			sanitize: { diag: { memtag_heap: true }  },
+		}
+		`
+		subdirNoOverrideBp := fmt.Sprintf(templateBp, "no_override")
+		subdirOverrideDefaultDisableBp := fmt.Sprintf(templateBp, "override_default_disable")
+		subdirSyncBp := fmt.Sprintf(templateBp, "override_default_sync")
+		subdirAsyncBp := fmt.Sprintf(templateBp, "override_default_async")
+
+		fs.Merge(android.MockFS{
+			"subdir_no_override/Android.bp":              []byte(subdirNoOverrideBp),
+			"subdir_override_default_disable/Android.bp": []byte(subdirOverrideDefaultDisableBp),
+			"subdir_sync/Android.bp":                     []byte(subdirSyncBp),
+			"subdir_async/Android.bp":                    []byte(subdirAsyncBp),
+		})
+	}),
+	android.FixtureModifyProductVariables(func(variables android.FixtureProductVariables) {
+		variables.MemtagHeapExcludePaths = []string{"subdir_override_default_disable"}
+		// "subdir_override_default_disable" is listed in both the include paths and the exclude paths; the exclude path wins.
+		variables.MemtagHeapSyncIncludePaths = []string{"subdir_sync", "subdir_override_default_disable"}
+		variables.MemtagHeapAsyncIncludePaths = []string{"subdir_async", "subdir_override_default_disable"}
+	}),
+)
+
+func TestSanitizeMemtagHeap(t *testing.T) {
+	variant := "android_arm64_armv8-a"
+
+	result := android.GroupFixturePreparers(
+		prepareForRustTest,
+		prepareForTestWithMemtagHeap,
+	).RunTest(t)
+	ctx := result.TestContext
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_no_override", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_override_default_async", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_override_default_sync", variant), None)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_no_override", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_override_default_async", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_override_default_sync", variant), None)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_no_override", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_override_default_async", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_override_default_disable", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_no_override", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_override_default_async", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_override_default_disable", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_override_default_sync", variant), Async)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_no_override", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_override_default_async", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_override_default_disable", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_override_default_sync", variant), Async)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_no_override", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_no_override", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_override_default_async", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_override_default_sync", variant), Sync)
+}
+
+func TestSanitizeMemtagHeapWithSanitizeDevice(t *testing.T) {
+	variant := "android_arm64_armv8-a"
+
+	result := android.GroupFixturePreparers(
+		prepareForRustTest,
+		prepareForTestWithMemtagHeap,
+		android.FixtureModifyProductVariables(func(variables android.FixtureProductVariables) {
+			variables.SanitizeDevice = []string{"memtag_heap"}
+		}),
+	).RunTest(t)
+	ctx := result.TestContext
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_no_override", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_override_default_async", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_override_default_sync", variant), None)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_no_override", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_override_default_async", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_override_default_sync", variant), None)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_no_override", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_override_default_async", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_override_default_disable", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_no_override", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_override_default_async", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_override_default_disable", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_override_default_sync", variant), Async)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_no_override", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_override_default_async", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_override_default_disable", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_override_default_sync", variant), Async)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_override_default_sync", variant), Sync)
+
+	// sanitize: { diag: { memtag_heap: true } } combined with the global memtag_heap sanitizer results in Sync here.
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_override_default_async", variant), Sync)
+	// should sanitize: { diag: { memtag_heap: true } } result in Sync instead of None here?
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_no_override", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_override_default_async", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_override_default_sync", variant), Sync)
+}
+
+func TestSanitizeMemtagHeapWithSanitizeDeviceDiag(t *testing.T) {
+	variant := "android_arm64_armv8-a"
+
+	result := android.GroupFixturePreparers(
+		prepareForRustTest,
+		prepareForTestWithMemtagHeap,
+		android.FixtureModifyProductVariables(func(variables android.FixtureProductVariables) {
+			variables.SanitizeDevice = []string{"memtag_heap"}
+			variables.SanitizeDeviceDiag = []string{"memtag_heap"}
+		}),
+	).RunTest(t)
+	ctx := result.TestContext
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_no_override", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_override_default_async", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_binary_override_default_sync", variant), None)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_no_override", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_override_default_async", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("no_memtag_test_override_default_sync", variant), None)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_test_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_no_override", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_override_default_async", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_override_default_disable", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_binary_override_default_sync", variant), Async)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_no_override", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_override_default_async", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_override_default_disable", variant), Async)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_async_test_override_default_sync", variant), Async)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("set_memtag_set_sync_test_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_override_default_async", variant), Sync)
+	// should sanitize: { diag: { memtag_heap: true } } result in Sync instead of None here?
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_memtag_set_sync_test_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_override_default_disable", variant), None)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_binary_override_default_sync", variant), Sync)
+
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_no_override", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_override_default_async", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_override_default_disable", variant), Sync)
+	checkHasMemtagNote(t, ctx.ModuleForTests("unset_test_override_default_sync", variant), Sync)
+}
diff --git a/rust/test.go b/rust/test.go
index 56da509..bb877a9 100644
--- a/rust/test.go
+++ b/rust/test.go
@@ -196,3 +196,7 @@
 
 	return deps
 }
+
+func (test *testDecorator) testBinary() bool {
+	return true
+}
diff --git a/rust/vendor_snapshot_test.go b/rust/vendor_snapshot_test.go
index 60ddb65..bfa6f36 100644
--- a/rust/vendor_snapshot_test.go
+++ b/rust/vendor_snapshot_test.go
@@ -562,6 +562,7 @@
 					"libvendor",
 					"libvndk",
 					"libclang_rt.builtins-aarch64-android",
+					"note_memtag_heap_sync",
 				],
 				shared_libs: [
 					"libvendor_available",
@@ -853,6 +854,20 @@
 		},
 	}
 
+	// Test sanitizers use the snapshot libraries
+	rust_binary {
+		name: "memtag_binary",
+		srcs: ["vendor/bin.rs"],
+		vendor: true,
+		compile_multilib: "64",
+		sanitize: {
+			memtag_heap: true,
+			diag: {
+				memtag_heap: true,
+			}
+		},
+	}
+
 	// old snapshot module which has to be ignored
 	vendor_snapshot_binary {
 		name: "bin",
@@ -880,11 +895,25 @@
 			},
 		},
 	}
+
+	vendor_snapshot_static {
+		name: "note_memtag_heap_sync",
+		vendor: true,
+		target_arch: "arm64",
+		version: "30",
+		arch: {
+			arm64: {
+				src: "note_memtag_heap_sync.a",
+			},
+		},
+	}
+
 `
 
 	mockFS := android.MockFS{
 		"framework/Android.bp":                          []byte(frameworkBp),
 		"framework/bin.rs":                              nil,
+		"note_memtag_heap_sync.a":                       nil,
 		"vendor/Android.bp":                             []byte(vendorProprietaryBp),
 		"vendor/bin":                                    nil,
 		"vendor/bin32":                                  nil,
@@ -993,4 +1022,9 @@
 	if android.InList(binaryVariant, binVariants) {
 		t.Errorf("bin must not have variant %#v, but it does", sharedVariant)
 	}
+
+	memtagStaticLibs := ctx.ModuleForTests("memtag_binary", "android_vendor.30_arm64_armv8-a").Module().(*Module).Properties.AndroidMkStaticLibs
+	if g, w := memtagStaticLibs, []string{"libclang_rt.builtins-aarch64-android.vendor", "note_memtag_heap_sync.vendor"}; !reflect.DeepEqual(g, w) {
+		t.Errorf("wanted memtag_binary AndroidMkStaticLibs %q, got %q", w, g)
+	}
 }