Merge "Restart WearableSensingService process before sending connection" into main
diff --git a/core/java/android/app/wearable/WearableSensingManager.java b/core/java/android/app/wearable/WearableSensingManager.java
index 077f7b5..3b281e9 100644
--- a/core/java/android/app/wearable/WearableSensingManager.java
+++ b/core/java/android/app/wearable/WearableSensingManager.java
@@ -206,9 +206,10 @@
      * and kill the WearableSensingService process.
      *
      * <p>Before providing the secureWearableConnection, the system will restart the
-     * WearableSensingService process. Other method calls into WearableSensingService may be dropped
-     * during the restart. The caller is responsible for ensuring other method calls are queued
-     * until a success status is returned from the {@code statusConsumer}.
+     * WearableSensingService process if it has not been restarted since the last
+     * secureWearableConnection was provided. Other method calls into WearableSensingService may be
+     * dropped during the restart. The caller is responsible for ensuring other method calls are
+     * queued until a success status is returned from the {@code statusConsumer}.
      *
      * @param wearableConnection The connection to provide
      * @param executor Executor on which to run the consumer callback
diff --git a/core/java/android/service/wearable/IWearableSensingService.aidl b/core/java/android/service/wearable/IWearableSensingService.aidl
index 0556188..f67dcff 100644
--- a/core/java/android/service/wearable/IWearableSensingService.aidl
+++ b/core/java/android/service/wearable/IWearableSensingService.aidl
@@ -37,4 +37,6 @@
             in RemoteCallback detectionResultCallback, in RemoteCallback statusCallback);
     void stopDetection(in String packageName);
     void queryServiceStatus(in int[] eventTypes, in String packageName, in RemoteCallback callback);
+    /** Kills the process hosting this WearableSensingService. */
+    void killProcess();
 }
\ No newline at end of file
diff --git a/core/java/android/service/wearable/WearableSensingService.java b/core/java/android/service/wearable/WearableSensingService.java
index d25cff7..bb6e030 100644
--- a/core/java/android/service/wearable/WearableSensingService.java
+++ b/core/java/android/service/wearable/WearableSensingService.java
@@ -32,6 +32,7 @@
 import android.os.IBinder;
 import android.os.ParcelFileDescriptor;
 import android.os.PersistableBundle;
+import android.os.Process;
 import android.os.RemoteCallback;
 import android.os.SharedMemory;
 import android.service.ambientcontext.AmbientContextDetectionResult;
@@ -242,6 +243,13 @@
                     WearableSensingService.this.onQueryServiceStatus(
                             new HashSet<>(Arrays.asList(events)), packageName, consumer);
                 }
+
+                /** Kills the process hosting this service; the system uses this to restart it. */
+                @Override
+                public void killProcess() {
+                    Slog.d(TAG, "#killProcess");
+                    Process.killProcess(Process.myPid());
+                }
             };
         }
         Slog.w(TAG, "Incorrect service interface, returning null.");
diff --git a/services/core/java/com/android/server/wearable/RemoteWearableSensingService.java b/services/core/java/com/android/server/wearable/RemoteWearableSensingService.java
index 88d3daf..62a637e 100644
--- a/services/core/java/com/android/server/wearable/RemoteWearableSensingService.java
+++ b/services/core/java/com/android/server/wearable/RemoteWearableSensingService.java
@@ -19,6 +19,8 @@
 import static android.content.Context.BIND_FOREGROUND_SERVICE;
 import static android.content.Context.BIND_INCLUDE_CAPABILITIES;
 
+import android.app.wearable.Flags;
+import android.app.wearable.WearableSensingManager;
 import android.content.ComponentName;
 import android.content.Context;
 import android.content.Intent;
@@ -30,6 +32,7 @@
 import android.service.wearable.WearableSensingService;
 import android.util.Slog;
 
+import com.android.internal.annotations.GuardedBy;
 import com.android.internal.infra.ServiceConnector;
 
 import java.io.IOException;
@@ -40,6 +43,19 @@
             com.android.server.wearable.RemoteWearableSensingService.class.getSimpleName();
     private final static boolean DEBUG = false;
 
+    private final Object mSecureWearableConnectionLock = new Object();
+
+    // mNextSecureWearableConnectionContext will only be non-null when we are waiting for the
+    // WearableSensingService process to restart. It will be set to null after it is passed into
+    // WearableSensingService.
+    @GuardedBy("mSecureWearableConnectionLock")
+    private SecureWearableConnectionContext mNextSecureWearableConnectionContext;
+
+    // Whether the running WearableSensingService process has already been given a secure
+    // wearable connection. Reset to false when the process dies with no connection pending.
+    @GuardedBy("mSecureWearableConnectionLock")
+    private boolean mSecureWearableConnectionProvided = false;
+
     RemoteWearableSensingService(Context context, ComponentName serviceName,
             int userId) {
         super(context, new Intent(
@@ -66,18 +82,97 @@
     public void provideSecureWearableConnection(
             ParcelFileDescriptor secureWearableConnection, RemoteCallback callback) {
         if (DEBUG) {
-            Slog.i(TAG, "Providing secure wearable connection.");
+            Slog.i(TAG, "#provideSecureWearableConnection");
         }
-        var unused = post(
-                service -> {
-                    service.provideSecureWearableConnection(secureWearableConnection, callback);
-                    try {
-                        // close the local fd after it has been sent to the WSS process
-                        secureWearableConnection.close();
-                    } catch (IOException ex) {
-                        Slog.w(TAG, "Unable to close the local parcelFileDescriptor.", ex);
-                    }
-                });
+        if (!Flags.enableRestartWssProcess()) {
+            Slog.d(
+                    TAG,
+                    "FLAG_ENABLE_RESTART_WSS_PROCESS is disabled. Do not attempt to restart the"
+                        + " WearableSensingService process");
+            provideSecureWearableConnectionInternal(secureWearableConnection, callback);
+            return;
+        }
+        synchronized (mSecureWearableConnectionLock) {
+            if (mNextSecureWearableConnectionContext != null) {
+                // A process restart is in progress, #binderDied is about to be called. Replace
+                // the previous mNextSecureWearableConnectionContext with the current one
+                Slog.i(
+                        TAG,
+                        "A new wearable connection is provided before the process restart triggered"
+                            + " by the previous connection is complete. Discarding the previous"
+                            + " connection.");
+                if (Flags.enableProvideWearableConnectionApi()) {
+                    WearableSensingManagerPerUserService.notifyStatusCallback(
+                            mNextSecureWearableConnectionContext.mStatusCallback,
+                            WearableSensingManager.STATUS_CHANNEL_ERROR);
+                }
+                // The discarded connection is never sent to WearableSensingService, so close
+                // its local fd here to avoid leaking it.
+                try {
+                    mNextSecureWearableConnectionContext.mSecureWearableConnection.close();
+                } catch (IOException ex) {
+                    Slog.w(TAG, "Unable to close the discarded parcelFileDescriptor.", ex);
+                }
+                mNextSecureWearableConnectionContext =
+                        new SecureWearableConnectionContext(secureWearableConnection, callback);
+                return;
+            }
+            if (!mSecureWearableConnectionProvided) {
+                // The current process has no connection yet, so it need not be restarted.
+                provideSecureWearableConnectionInternal(secureWearableConnection, callback);
+                mSecureWearableConnectionProvided = true;
+                return;
+            }
+            mNextSecureWearableConnectionContext =
+                    new SecureWearableConnectionContext(secureWearableConnection, callback);
+            // Killing the process causes the binder to die. #binderDied will then be triggered
+            killWearableSensingServiceProcess();
+        }
+    }
+
+    /** Posts the connection to the service, then closes the local fd once it has been sent. */
+    private void provideSecureWearableConnectionInternal(
+            ParcelFileDescriptor secureWearableConnection, RemoteCallback callback) {
+        Slog.d(TAG, "Providing secure wearable connection.");
+        var unused =
+                post(
+                        service -> {
+                            service.provideSecureWearableConnection(
+                                    secureWearableConnection, callback);
+                            try {
+                                // close the local fd after it has been sent to the WSS process
+                                secureWearableConnection.close();
+                            } catch (IOException ex) {
+                                Slog.w(TAG, "Unable to close the local parcelFileDescriptor.", ex);
+                            }
+                        });
+    }
+
+    /**
+     * Called when the WearableSensingService process dies. If a connection is waiting for the
+     * restart, provides it, which binds to a new process; otherwise records that the next
+     * connection can be provided without restarting the process.
+     */
+    @Override
+    public void binderDied() {
+        super.binderDied();
+        synchronized (mSecureWearableConnectionLock) {
+            if (mNextSecureWearableConnectionContext != null) {
+                // This will call #post, which will recreate the process and bind to it
+                provideSecureWearableConnectionInternal(
+                        mNextSecureWearableConnectionContext.mSecureWearableConnection,
+                        mNextSecureWearableConnectionContext.mStatusCallback);
+                mNextSecureWearableConnectionContext = null;
+            } else {
+                mSecureWearableConnectionProvided = false;
+                Slog.w(TAG, "Binder died but there is no secure wearable connection to provide.");
+            }
+        }
+    }
+
+    /** Kills the WearableSensingService process. */
+    public void killWearableSensingServiceProcess() {
+        var unused = post(service -> service.killProcess());
     }
 
     /**
@@ -176,4 +271,16 @@
                                         packageName,
                                         statusCallback));
     }
+
+    /** A secure wearable connection and its status callback, pending a process restart. */
+    private static class SecureWearableConnectionContext {
+        final ParcelFileDescriptor mSecureWearableConnection;
+        final RemoteCallback mStatusCallback;
+
+        SecureWearableConnectionContext(
+                ParcelFileDescriptor secureWearableConnection, RemoteCallback statusCallback) {
+            this.mSecureWearableConnection = secureWearableConnection;
+            this.mStatusCallback = statusCallback;
+        }
+    }
 }
diff --git a/services/core/java/com/android/server/wearable/WearableSensingManagerPerUserService.java b/services/core/java/com/android/server/wearable/WearableSensingManagerPerUserService.java
index 0e8b82f..9ba4433 100644
--- a/services/core/java/com/android/server/wearable/WearableSensingManagerPerUserService.java
+++ b/services/core/java/com/android/server/wearable/WearableSensingManagerPerUserService.java
@@ -44,6 +44,7 @@
 
 import java.io.IOException;
 import java.io.PrintWriter;
+import java.util.concurrent.atomic.AtomicReference;
 
 /**
  * Per-user manager service for managing sensing {@link AmbientContextEvent}s on Wearables.
@@ -183,11 +184,11 @@
         }
         synchronized (mSecureChannelLock) {
             if (mSecureChannel != null) {
-                // TODO(b/321012559): Kill the WearableSensingService process if it has not been
-                // killed from onError
                 mSecureChannel.close();
             }
             try {
+                final AtomicReference<WearableSensingSecureChannel> currentSecureChannelRef =
+                        new AtomicReference<>();
                 mSecureChannel =
                         WearableSensingSecureChannel.create(
                                 getContext().getSystemService(CompanionDeviceManager.class),
@@ -206,8 +207,17 @@
 
                                     @Override
                                     public void onError() {
-                                        // TODO(b/321012559): Kill the WearableSensingService
-                                        // process if mSecureChannel has not been reassigned
+                                        if (Flags.enableRestartWssProcess()) {
+                                            synchronized (mSecureChannelLock) {
+                                                if (mSecureChannel != null
+                                                        && mSecureChannel
+                                                                == currentSecureChannelRef.get()) {
+                                                    mRemoteService
+                                                            .killWearableSensingServiceProcess();
+                                                    mSecureChannel = null;
+                                                }
+                                            }
+                                        }
                                         if (Flags.enableProvideWearableConnectionApi()) {
                                             notifyStatusCallback(
                                                     callback,
@@ -215,6 +225,7 @@
                                         }
                                     }
                                 });
+                currentSecureChannelRef.set(mSecureChannel);
             } catch (IOException ex) {
                 Slog.e(TAG, "Unable to create the secure channel.", ex);
                 if (Flags.enableProvideWearableConnectionApi()) {