Handle Model Loading/Unloading Broadcasts for Samsung in Isolated AICore

This change registers a callback with the remote service via
updateProcessingState to receive model loading/unloading updates and
then send these as broadcasts to system process.

Bug: 339688752
Test: cts in topic
Change-Id: I84ae64c5f5dfcbc7005a912b31861b3b64bf61f1
diff --git a/core/java/android/service/ondeviceintelligence/OnDeviceSandboxedInferenceService.java b/core/java/android/service/ondeviceintelligence/OnDeviceSandboxedInferenceService.java
index 29a6db6..8237b20 100644
--- a/core/java/android/service/ondeviceintelligence/OnDeviceSandboxedInferenceService.java
+++ b/core/java/android/service/ondeviceintelligence/OnDeviceSandboxedInferenceService.java
@@ -105,6 +105,21 @@
     public static final String SERVICE_INTERFACE =
             "android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService";
 
+    // TODO(339594686): make API
+    /**
+     * @hide
+     */
+    public static final String REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY =
+            "register_model_update_callback";
+    /**
+     * @hide
+     */
+    public static final String MODEL_LOADED_BUNDLE_KEY = "model_loaded";
+    /**
+     * @hide
+     */
+    public static final String MODEL_UNLOADED_BUNDLE_KEY = "model_unloaded";
+
     private IRemoteStorageService mRemoteStorageService;
 
     /**
diff --git a/core/res/res/values/config.xml b/core/res/res/values/config.xml
index cefc648..13b94d8 100644
--- a/core/res/res/values/config.xml
+++ b/core/res/res/values/config.xml
@@ -4704,6 +4704,13 @@
     <!-- The component name for the default system on-device sandboxed inference service. -->
     <string name="config_defaultOnDeviceSandboxedInferenceService" translatable="false"></string>
 
+    <!-- The broadcast intent name for notifying when the on-device model is loading  -->
+    <string name="config_onDeviceIntelligenceModelLoadedBroadcastKey" translatable="false"></string>
+
+    <!-- The broadcast intent name for notifying when the on-device model has been unloaded  -->
+    <string name="config_onDeviceIntelligenceModelUnloadedBroadcastKey" translatable="false"></string>
+
+
     <!-- Component name that accepts ACTION_SEND intents for requesting ambient context consent for
          wearable sensing. -->
     <string translatable="false" name="config_defaultWearableSensingConsentComponent"></string>
diff --git a/core/res/res/values/symbols.xml b/core/res/res/values/symbols.xml
index d058fb1..dcfbda0 100644
--- a/core/res/res/values/symbols.xml
+++ b/core/res/res/values/symbols.xml
@@ -3941,6 +3941,8 @@
   <java-symbol type="string" name="config_defaultWearableSensingService" />
   <java-symbol type="string" name="config_defaultOnDeviceIntelligenceService" />
   <java-symbol type="string" name="config_defaultOnDeviceSandboxedInferenceService" />
+  <java-symbol type="string" name="config_onDeviceIntelligenceModelLoadedBroadcastKey" />
+  <java-symbol type="string" name="config_onDeviceIntelligenceModelUnloadedBroadcastKey" />
   <java-symbol type="string" name="config_retailDemoPackage" />
   <java-symbol type="string" name="config_retailDemoPackageSignature" />
 
diff --git a/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceManagerService.java b/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceManagerService.java
index 99401a1..235e3cd 100644
--- a/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceManagerService.java
+++ b/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceManagerService.java
@@ -16,6 +16,10 @@
 
 package com.android.server.ondeviceintelligence;
 
+import static android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService.MODEL_LOADED_BUNDLE_KEY;
+import static android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService.MODEL_UNLOADED_BUNDLE_KEY;
+import static android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService.REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY;
+
 import static com.android.server.ondeviceintelligence.BundleUtil.sanitizeInferenceParams;
 import static com.android.server.ondeviceintelligence.BundleUtil.validatePfdReadOnly;
 import static com.android.server.ondeviceintelligence.BundleUtil.sanitizeStateParams;
@@ -41,6 +45,7 @@
 import android.app.ondeviceintelligence.OnDeviceIntelligenceException;
 import android.content.ComponentName;
 import android.content.Context;
+import android.content.Intent;
 import android.content.pm.PackageManager;
 import android.content.pm.ServiceInfo;
 import android.content.res.Resources;
@@ -105,12 +110,20 @@
     /** Handler message to {@link #resetTemporaryServices()} */
     private static final int MSG_RESET_TEMPORARY_SERVICE = 0;
 
+    /** Handler message to clean up temporary broadcast keys. */
+    private static final int MSG_RESET_BROADCAST_KEYS = 1;
+
     /** Default value in absence of {@link DeviceConfig} override. */
     private static final boolean DEFAULT_SERVICE_ENABLED = true;
     private static final String NAMESPACE_ON_DEVICE_INTELLIGENCE = "ondeviceintelligence";
 
+    private static final String SYSTEM_PACKAGE = "android";
+
+
     private final Executor resourceClosingExecutor = Executors.newCachedThreadPool();
     private final Executor callbackExecutor = Executors.newCachedThreadPool();
+    private final Executor broadcastExecutor = Executors.newCachedThreadPool();
+
 
     private final Context mContext;
     protected final Object mLock = new Object();
@@ -123,10 +136,14 @@
     @GuardedBy("mLock")
     private String[] mTemporaryServiceNames;
 
+    @GuardedBy("mLock")
+    private String[] mTemporaryBroadcastKeys;
+    @GuardedBy("mLock")
+    private String mBroadcastPackageName;
+
     /**
      * Handler used to reset the temporary service names.
      */
-    @GuardedBy("mLock")
     private Handler mTemporaryHandler;
 
     public OnDeviceIntelligenceManagerService(Context context) {
@@ -482,6 +499,8 @@
                                     ensureRemoteIntelligenceServiceInitialized();
                                     mRemoteOnDeviceIntelligenceService.run(
                                             IOnDeviceIntelligenceService::notifyInferenceServiceConnected);
+                                    broadcastExecutor.execute(
+                                            () -> registerModelLoadingBroadcasts(service));
                                     service.registerRemoteStorageService(
                                             getIRemoteStorageService());
                                 } catch (RemoteException ex) {
@@ -493,6 +512,56 @@
         }
     }
 
+    private void registerModelLoadingBroadcasts(IOnDeviceSandboxedInferenceService service) {
+        String[] modelBroadcastKeys;
+        try {
+            modelBroadcastKeys = getBroadcastKeys();
+        } catch (Resources.NotFoundException e) {
+            Slog.d(TAG, "Skipping model broadcasts as broadcast intents configured.");
+            return;
+        }
+
+        Bundle bundle = new Bundle();
+        bundle.putBoolean(REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY, true);
+        try {
+            service.updateProcessingState(bundle, new IProcessingUpdateStatusCallback.Stub() {
+                @Override
+                public void onSuccess(PersistableBundle statusParams) {
+                    Binder.clearCallingIdentity();
+                    synchronized (mLock) {
+                        if (statusParams.containsKey(MODEL_LOADED_BUNDLE_KEY)) {
+                            String modelLoadedBroadcastKey = modelBroadcastKeys[0];
+                            if (modelLoadedBroadcastKey != null
+                                    && !modelLoadedBroadcastKey.isEmpty()) {
+                                final Intent intent = new Intent(modelLoadedBroadcastKey);
+                                intent.setPackage(mBroadcastPackageName);
+                                mContext.sendBroadcast(intent,
+                                        Manifest.permission.USE_ON_DEVICE_INTELLIGENCE);
+                            }
+                        } else if (statusParams.containsKey(MODEL_UNLOADED_BUNDLE_KEY)) {
+                            String modelUnloadedBroadcastKey = modelBroadcastKeys[1];
+                            if (modelUnloadedBroadcastKey != null
+                                    && !modelUnloadedBroadcastKey.isEmpty()) {
+                                final Intent intent = new Intent(modelUnloadedBroadcastKey);
+                                intent.setPackage(mBroadcastPackageName);
+                                mContext.sendBroadcast(intent,
+                                        Manifest.permission.USE_ON_DEVICE_INTELLIGENCE);
+                            }
+                        }
+                    }
+                }
+
+                @Override
+                public void onFailure(int errorCode, String errorMessage) {
+                    Slog.e(TAG, "Failed to register model loading callback with status code",
+                            new OnDeviceIntelligenceException(errorCode, errorMessage));
+                }
+            });
+        } catch (RemoteException e) {
+            Slog.e(TAG, "Failed to register model loading callback with status code", e);
+        }
+    }
+
     @NonNull
     private IRemoteStorageService.Stub getIRemoteStorageService() {
         return new IRemoteStorageService.Stub() {
@@ -629,6 +698,20 @@
                         R.string.config_defaultOnDeviceSandboxedInferenceService)};
     }
 
+    protected String[] getBroadcastKeys() throws Resources.NotFoundException {
+        // TODO 329240495 : Consider a small class with explicit field names for the two services
+        synchronized (mLock) {
+            if (mTemporaryBroadcastKeys != null && mTemporaryBroadcastKeys.length == 2) {
+                return mTemporaryBroadcastKeys;
+            }
+        }
+
+        return new String[]{mContext.getResources().getString(
+                R.string.config_onDeviceIntelligenceModelLoadedBroadcastKey),
+                mContext.getResources().getString(
+                        R.string.config_onDeviceIntelligenceModelUnloadedBroadcastKey)};
+    }
+
     @RequiresPermission(Manifest.permission.USE_ON_DEVICE_INTELLIGENCE)
     public void setTemporaryServices(@NonNull String[] componentNames, int durationMs) {
         Objects.requireNonNull(componentNames);
@@ -645,25 +728,26 @@
                 mRemoteOnDeviceIntelligenceService.unbind();
                 mRemoteOnDeviceIntelligenceService = null;
             }
-            if (mTemporaryHandler == null) {
-                mTemporaryHandler = new Handler(Looper.getMainLooper(), null, true) {
-                    @Override
-                    public void handleMessage(Message msg) {
-                        if (msg.what == MSG_RESET_TEMPORARY_SERVICE) {
-                            synchronized (mLock) {
-                                resetTemporaryServices();
-                            }
-                        } else {
-                            Slog.wtf(TAG, "invalid handler msg: " + msg);
-                        }
-                    }
-                };
-            } else {
-                mTemporaryHandler.removeMessages(MSG_RESET_TEMPORARY_SERVICE);
-            }
 
             if (durationMs != -1) {
-                mTemporaryHandler.sendEmptyMessageDelayed(MSG_RESET_TEMPORARY_SERVICE, durationMs);
+                getTemporaryHandler().sendEmptyMessageDelayed(MSG_RESET_TEMPORARY_SERVICE,
+                        durationMs);
+            }
+        }
+    }
+
+    @RequiresPermission(Manifest.permission.USE_ON_DEVICE_INTELLIGENCE)
+    public void setModelBroadcastKeys(@NonNull String[] broadcastKeys, String receiverPackageName,
+            int durationMs) {
+        Objects.requireNonNull(broadcastKeys);
+        enforceShellOnly(Binder.getCallingUid(), "setModelBroadcastKeys");
+        mContext.enforceCallingPermission(
+                Manifest.permission.USE_ON_DEVICE_INTELLIGENCE, TAG);
+        synchronized (mLock) {
+            mTemporaryBroadcastKeys = broadcastKeys;
+            mBroadcastPackageName = receiverPackageName;
+            if (durationMs != -1) {
+                getTemporaryHandler().sendEmptyMessageDelayed(MSG_RESET_BROADCAST_KEYS, durationMs);
             }
         }
     }
@@ -751,4 +835,28 @@
             }
         }
     }
+
+    private synchronized Handler getTemporaryHandler() {
+        if (mTemporaryHandler == null) {
+            mTemporaryHandler = new Handler(Looper.getMainLooper(), null, true) {
+                @Override
+                public void handleMessage(Message msg) {
+                    if (msg.what == MSG_RESET_TEMPORARY_SERVICE) {
+                        synchronized (mLock) {
+                            resetTemporaryServices();
+                        }
+                    } else if (msg.what == MSG_RESET_BROADCAST_KEYS) {
+                        synchronized (mLock) {
+                            mTemporaryBroadcastKeys = null;
+                            mBroadcastPackageName = SYSTEM_PACKAGE;
+                        }
+                    } else {
+                        Slog.wtf(TAG, "invalid handler msg: " + msg);
+                    }
+                }
+            };
+        }
+
+        return mTemporaryHandler;
+    }
 }
diff --git a/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceShellCommand.java b/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceShellCommand.java
index a76d8a3..5744b5c 100644
--- a/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceShellCommand.java
+++ b/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceShellCommand.java
@@ -43,6 +43,8 @@
                 return setTemporaryServices();
             case "get-services":
                 return getConfiguredServices();
+            case "set-model-broadcasts":
+                return setBroadcastKeys();
             default:
                 return handleDefaultCommands(cmd);
         }
@@ -62,12 +64,18 @@
         pw.println("    To reset, call without any arguments.");
 
         pw.println("  get-services To get the names of services that are currently being used.");
+        pw.println(
+                "  set-model-broadcasts [ModelLoadedBroadcastKey] [ModelUnloadedBroadcastKey] "
+                        + "[ReceiverPackageName] "
+                        + "[DURATION] To set the names of broadcast intent keys that are to be "
+                        + "emitted for cts tests.");
     }
 
     private int setTemporaryServices() {
         final PrintWriter out = getOutPrintWriter();
         final String intelligenceServiceName = getNextArg();
         final String inferenceServiceName = getNextArg();
+
         if (getRemainingArgsCount() == 0 && intelligenceServiceName == null
                 && inferenceServiceName == null) {
             mService.resetTemporaryServices();
@@ -79,7 +87,8 @@
         Objects.requireNonNull(inferenceServiceName);
         final int duration = Integer.parseInt(getNextArgRequired());
         mService.setTemporaryServices(
-                new String[]{intelligenceServiceName, inferenceServiceName}, duration);
+                new String[]{intelligenceServiceName, inferenceServiceName},
+                duration);
         out.println("OnDeviceIntelligenceService temporarily set to " + intelligenceServiceName
                 + " \n and \n OnDeviceTrustedInferenceService set to " + inferenceServiceName
                 + " for " + duration + "ms");
@@ -93,4 +102,22 @@
                 + " \n and \n OnDeviceTrustedInferenceService set to : " + services[1]);
         return 0;
     }
+
+    private int setBroadcastKeys() {
+        final PrintWriter out = getOutPrintWriter();
+        final String modelLoadedKey = getNextArgRequired();
+        final String modelUnloadedKey = getNextArgRequired();
+        final String receiverPackageName = getNextArg();
+
+        final int duration = Integer.parseInt(getNextArgRequired());
+        mService.setModelBroadcastKeys(
+                new String[]{modelLoadedKey, modelUnloadedKey}, receiverPackageName, duration);
+        out.println("OnDeviceIntelligence Model Loading broadcast keys temporarily set to "
+                + modelLoadedKey
+                + " \n and \n OnDeviceTrustedInferenceService set to " + modelUnloadedKey
+                + "\n and Package name set to : " + receiverPackageName
+                + " for " + duration + "ms");
+        return 0;
+    }
+
 }
\ No newline at end of file