diff --git a/javalib/src/android/system/virtualmachine/VirtualMachine.java b/javalib/src/android/system/virtualmachine/VirtualMachine.java
index 65ce7ea..ed2c2a1 100644
--- a/javalib/src/android/system/virtualmachine/VirtualMachine.java
+++ b/javalib/src/android/system/virtualmachine/VirtualMachine.java
@@ -36,6 +36,8 @@
 import android.system.virtualizationservice.VirtualMachineState;
 import android.util.JsonReader;
 
+import com.android.internal.annotations.GuardedBy;
+
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.FileNotFoundException;
@@ -52,6 +54,8 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Consumer;
 import java.util.zip.ZipFile;
 
 /**
@@ -92,6 +96,9 @@
         DELETED,
     }
 
+    /** Lock for internal synchronization. */
+    private final Object mLock = new Object();
+
     /** The package which owns this VM. */
     private final @NonNull String mPackageName;
 
@@ -135,9 +142,11 @@
     private @Nullable IVirtualMachine mVirtualMachine;
 
     /** The registered callback */
+    @GuardedBy("mLock")
     private @Nullable VirtualMachineCallback mCallback;
 
     /** The executor on which the callback will be executed */
+    @GuardedBy("mLock")
     private @Nullable Executor mCallbackExecutor;
 
     private @Nullable ParcelFileDescriptor mConsoleReader;
@@ -299,20 +308,37 @@
     public void setCallback(
             @NonNull @CallbackExecutor Executor executor,
             @NonNull VirtualMachineCallback callback) {
-        mCallbackExecutor = executor;
-        mCallback = callback;
+        synchronized (mLock) {
+            mCallback = callback;
+            mCallbackExecutor = executor;
+        }
     }
 
     /** Clears the currently registered callback. */
     public void clearCallback() {
-        // TODO(b/220730550): synchronize with the callers of the callback
-        mCallback = null;
-        mCallbackExecutor = null;
+        synchronized (mLock) {
+            mCallback = null;
+            mCallbackExecutor = null;
+        }
     }
 
-    /** Returns the currently registered callback. */
-    public @Nullable VirtualMachineCallback getCallback() {
-        return mCallback;
+    /** Executes a callback on the callback executor. */
+    private void executeCallback(Consumer<VirtualMachineCallback> fn) {
+        final VirtualMachineCallback callback;
+        final Executor executor;
+        synchronized (mLock) {
+            callback = mCallback;
+            executor = mCallbackExecutor;
+        }
+        if (callback == null || executor == null) {
+            return;
+        }
+        final long restoreToken = Binder.clearCallingIdentity();
+        try {
+            executor.execute(() -> fn.accept(callback));
+        } finally {
+            Binder.restoreCallingIdentity(restoreToken);
+        }
     }
 
     /**
@@ -376,14 +402,15 @@
             android.system.virtualizationservice.VirtualMachineConfig vmConfigParcel =
                     android.system.virtualizationservice.VirtualMachineConfig.appConfig(appConfig);
 
+            // The VM should only be observed to die once
+            AtomicBoolean onDiedCalled = new AtomicBoolean(false);
+
             IBinder.DeathRecipient deathRecipient = new IBinder.DeathRecipient() {
                 @Override
                 public void binderDied() {
-                    final VirtualMachineCallback cb = mCallback;
-                    if (cb != null) {
-                        // TODO(b/220730550): don't call if the VM already died
-                        cb.onDied(VirtualMachine.this, VirtualMachineCallback
-                                .DEATH_REASON_VIRTUALIZATIONSERVICE_DIED);
+                    if (onDiedCalled.compareAndSet(false, true)) {
+                        executeCallback((cb) -> cb.onDied(VirtualMachine.this,
+                                VirtualMachineCallback.DEATH_REASON_VIRTUALIZATIONSERVICE_DIED));
                     }
                 }
             };
@@ -393,80 +420,32 @@
                     new IVirtualMachineCallback.Stub() {
                         @Override
                         public void onPayloadStarted(int cid, ParcelFileDescriptor stream) {
-                            final VirtualMachineCallback cb = mCallback;
-                            if (cb == null) {
-                                return;
-                            }
-                            final long restoreToken = Binder.clearCallingIdentity();
-                            try {
-                                mCallbackExecutor.execute(
-                                        () -> cb.onPayloadStarted(VirtualMachine.this, stream));
-                            } finally {
-                                Binder.restoreCallingIdentity(restoreToken);
-                            }
+                            executeCallback(
+                                    (cb) -> cb.onPayloadStarted(VirtualMachine.this, stream));
                         }
-
                         @Override
                         public void onPayloadReady(int cid) {
-                            final VirtualMachineCallback cb = mCallback;
-                            if (cb == null) {
-                                return;
-                            }
-                            final long restoreToken = Binder.clearCallingIdentity();
-                            try {
-                                mCallbackExecutor.execute(
-                                        () -> cb.onPayloadReady(VirtualMachine.this));
-                            } finally {
-                                Binder.restoreCallingIdentity(restoreToken);
-                            }
+                            executeCallback((cb) -> cb.onPayloadReady(VirtualMachine.this));
                         }
-
                         @Override
                         public void onPayloadFinished(int cid, int exitCode) {
-                            final VirtualMachineCallback cb = mCallback;
-                            if (cb == null) {
-                                return;
-                            }
-                            final long restoreToken = Binder.clearCallingIdentity();
-                            try {
-                                mCallbackExecutor.execute(
-                                        () -> cb.onPayloadFinished(VirtualMachine.this, exitCode));
-                            } finally {
-                                Binder.restoreCallingIdentity(restoreToken);
-                            }
+                            executeCallback(
+                                    (cb) -> cb.onPayloadFinished(VirtualMachine.this, exitCode));
                         }
-
                         @Override
                         public void onError(int cid, int errorCode, String message) {
-                            final VirtualMachineCallback cb = mCallback;
-                            if (cb == null) {
-                                return;
-                            }
-                            final long restoreToken = Binder.clearCallingIdentity();
-                            try {
-                                mCallbackExecutor.execute(
-                                        () -> cb.onError(VirtualMachine.this, errorCode, message));
-                            } finally {
-                                Binder.restoreCallingIdentity(restoreToken);
-                            }
+                            executeCallback(
+                                    (cb) -> cb.onError(VirtualMachine.this, errorCode, message));
                         }
-
                         @Override
                         public void onDied(int cid, int reason) {
                             service.asBinder().unlinkToDeath(deathRecipient, 0);
-                            final VirtualMachineCallback cb = mCallback;
-                            if (cb == null) {
-                                return;
-                            }
-                            final long restoreToken = Binder.clearCallingIdentity();
-                            try {
-                                mCallbackExecutor.execute(
-                                        () -> cb.onDied(VirtualMachine.this, reason));
-                            } finally {
-                                Binder.restoreCallingIdentity(restoreToken);
+                            if (onDiedCalled.compareAndSet(false, true)) {
+                                executeCallback((cb) -> cb.onDied(VirtualMachine.this, reason));
                             }
                         }
-                    });
+                    }
+            );
             service.asBinder().linkToDeath(deathRecipient, 0);
             mVirtualMachine.start();
         } catch (IOException e) {
