Handle Model Loading/Unloading Broadcasts for Samsung in Isolated AICore
This change registers a callback with the remote service via
updateProcessingState to receive model loading/unloading updates and
then send these as broadcasts to system process.
Bug: 339688752
Test: cts in topic
Change-Id: I84ae64c5f5dfcbc7005a912b31861b3b64bf61f1
diff --git a/core/java/android/service/ondeviceintelligence/OnDeviceSandboxedInferenceService.java b/core/java/android/service/ondeviceintelligence/OnDeviceSandboxedInferenceService.java
index 29a6db6..8237b20 100644
--- a/core/java/android/service/ondeviceintelligence/OnDeviceSandboxedInferenceService.java
+++ b/core/java/android/service/ondeviceintelligence/OnDeviceSandboxedInferenceService.java
@@ -105,6 +105,21 @@
public static final String SERVICE_INTERFACE =
"android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService";
+ // TODO(339594686): make API
+ /**
+ * @hide
+ */
+ public static final String REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY =
+ "register_model_update_callback";
+ /**
+ * @hide
+ */
+ public static final String MODEL_LOADED_BUNDLE_KEY = "model_loaded";
+ /**
+ * @hide
+ */
+ public static final String MODEL_UNLOADED_BUNDLE_KEY = "model_unloaded";
+
private IRemoteStorageService mRemoteStorageService;
/**
diff --git a/core/res/res/values/config.xml b/core/res/res/values/config.xml
index cefc648..13b94d8 100644
--- a/core/res/res/values/config.xml
+++ b/core/res/res/values/config.xml
@@ -4704,6 +4704,13 @@
<!-- The component name for the default system on-device sandboxed inference service. -->
<string name="config_defaultOnDeviceSandboxedInferenceService" translatable="false"></string>
+ <!-- The broadcast intent name for notifying when the on-device model is loading -->
+ <string name="config_onDeviceIntelligenceModelLoadedBroadcastKey" translatable="false"></string>
+
+ <!-- The broadcast intent name for notifying when the on-device model has been unloaded -->
+ <string name="config_onDeviceIntelligenceModelUnloadedBroadcastKey" translatable="false"></string>
+
+
<!-- Component name that accepts ACTION_SEND intents for requesting ambient context consent for
wearable sensing. -->
<string translatable="false" name="config_defaultWearableSensingConsentComponent"></string>
diff --git a/core/res/res/values/symbols.xml b/core/res/res/values/symbols.xml
index d058fb1..dcfbda0 100644
--- a/core/res/res/values/symbols.xml
+++ b/core/res/res/values/symbols.xml
@@ -3941,6 +3941,8 @@
<java-symbol type="string" name="config_defaultWearableSensingService" />
<java-symbol type="string" name="config_defaultOnDeviceIntelligenceService" />
<java-symbol type="string" name="config_defaultOnDeviceSandboxedInferenceService" />
+ <java-symbol type="string" name="config_onDeviceIntelligenceModelLoadedBroadcastKey" />
+ <java-symbol type="string" name="config_onDeviceIntelligenceModelUnloadedBroadcastKey" />
<java-symbol type="string" name="config_retailDemoPackage" />
<java-symbol type="string" name="config_retailDemoPackageSignature" />
diff --git a/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceManagerService.java b/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceManagerService.java
index 99401a1..235e3cd 100644
--- a/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceManagerService.java
+++ b/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceManagerService.java
@@ -16,6 +16,10 @@
package com.android.server.ondeviceintelligence;
+import static android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService.MODEL_LOADED_BUNDLE_KEY;
+import static android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService.MODEL_UNLOADED_BUNDLE_KEY;
+import static android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService.REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY;
+
import static com.android.server.ondeviceintelligence.BundleUtil.sanitizeInferenceParams;
import static com.android.server.ondeviceintelligence.BundleUtil.validatePfdReadOnly;
import static com.android.server.ondeviceintelligence.BundleUtil.sanitizeStateParams;
@@ -41,6 +45,7 @@
import android.app.ondeviceintelligence.OnDeviceIntelligenceException;
import android.content.ComponentName;
import android.content.Context;
+import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.pm.ServiceInfo;
import android.content.res.Resources;
@@ -105,12 +110,20 @@
/** Handler message to {@link #resetTemporaryServices()} */
private static final int MSG_RESET_TEMPORARY_SERVICE = 0;
+ /** Handler message to clean up temporary broadcast keys. */
+ private static final int MSG_RESET_BROADCAST_KEYS = 1;
+
/** Default value in absence of {@link DeviceConfig} override. */
private static final boolean DEFAULT_SERVICE_ENABLED = true;
private static final String NAMESPACE_ON_DEVICE_INTELLIGENCE = "ondeviceintelligence";
+ private static final String SYSTEM_PACKAGE = "android";
+
+
private final Executor resourceClosingExecutor = Executors.newCachedThreadPool();
private final Executor callbackExecutor = Executors.newCachedThreadPool();
+ private final Executor broadcastExecutor = Executors.newCachedThreadPool();
+
private final Context mContext;
protected final Object mLock = new Object();
@@ -123,10 +136,14 @@
@GuardedBy("mLock")
private String[] mTemporaryServiceNames;
+ @GuardedBy("mLock")
+ private String[] mTemporaryBroadcastKeys;
+ @GuardedBy("mLock")
+ private String mBroadcastPackageName;
+
/**
* Handler used to reset the temporary service names.
*/
- @GuardedBy("mLock")
private Handler mTemporaryHandler;
public OnDeviceIntelligenceManagerService(Context context) {
@@ -482,6 +499,8 @@
ensureRemoteIntelligenceServiceInitialized();
mRemoteOnDeviceIntelligenceService.run(
IOnDeviceIntelligenceService::notifyInferenceServiceConnected);
+ broadcastExecutor.execute(
+ () -> registerModelLoadingBroadcasts(service));
service.registerRemoteStorageService(
getIRemoteStorageService());
} catch (RemoteException ex) {
@@ -493,6 +512,56 @@
}
}
+ private void registerModelLoadingBroadcasts(IOnDeviceSandboxedInferenceService service) {
+ String[] modelBroadcastKeys;
+ try {
+ modelBroadcastKeys = getBroadcastKeys();
+ } catch (Resources.NotFoundException e) {
+ Slog.d(TAG, "Skipping model broadcasts as broadcast intents configured.");
+ return;
+ }
+
+ Bundle bundle = new Bundle();
+ bundle.putBoolean(REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY, true);
+ try {
+ service.updateProcessingState(bundle, new IProcessingUpdateStatusCallback.Stub() {
+ @Override
+ public void onSuccess(PersistableBundle statusParams) {
+ Binder.clearCallingIdentity();
+ synchronized (mLock) {
+ if (statusParams.containsKey(MODEL_LOADED_BUNDLE_KEY)) {
+ String modelLoadedBroadcastKey = modelBroadcastKeys[0];
+ if (modelLoadedBroadcastKey != null
+ && !modelLoadedBroadcastKey.isEmpty()) {
+ final Intent intent = new Intent(modelLoadedBroadcastKey);
+ intent.setPackage(mBroadcastPackageName);
+ mContext.sendBroadcast(intent,
+ Manifest.permission.USE_ON_DEVICE_INTELLIGENCE);
+ }
+ } else if (statusParams.containsKey(MODEL_UNLOADED_BUNDLE_KEY)) {
+ String modelUnloadedBroadcastKey = modelBroadcastKeys[1];
+ if (modelUnloadedBroadcastKey != null
+ && !modelUnloadedBroadcastKey.isEmpty()) {
+ final Intent intent = new Intent(modelUnloadedBroadcastKey);
+ intent.setPackage(mBroadcastPackageName);
+ mContext.sendBroadcast(intent,
+ Manifest.permission.USE_ON_DEVICE_INTELLIGENCE);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void onFailure(int errorCode, String errorMessage) {
+ Slog.e(TAG, "Failed to register model loading callback with status code",
+ new OnDeviceIntelligenceException(errorCode, errorMessage));
+ }
+ });
+ } catch (RemoteException e) {
+ Slog.e(TAG, "Failed to register model loading callback with status code", e);
+ }
+ }
+
@NonNull
private IRemoteStorageService.Stub getIRemoteStorageService() {
return new IRemoteStorageService.Stub() {
@@ -629,6 +698,20 @@
R.string.config_defaultOnDeviceSandboxedInferenceService)};
}
+ protected String[] getBroadcastKeys() throws Resources.NotFoundException {
+ // TODO 329240495 : Consider a small class with explicit field names for the two services
+ synchronized (mLock) {
+ if (mTemporaryBroadcastKeys != null && mTemporaryBroadcastKeys.length == 2) {
+ return mTemporaryBroadcastKeys;
+ }
+ }
+
+ return new String[]{mContext.getResources().getString(
+ R.string.config_onDeviceIntelligenceModelLoadedBroadcastKey),
+ mContext.getResources().getString(
+ R.string.config_onDeviceIntelligenceModelUnloadedBroadcastKey)};
+ }
+
@RequiresPermission(Manifest.permission.USE_ON_DEVICE_INTELLIGENCE)
public void setTemporaryServices(@NonNull String[] componentNames, int durationMs) {
Objects.requireNonNull(componentNames);
@@ -645,25 +728,26 @@
mRemoteOnDeviceIntelligenceService.unbind();
mRemoteOnDeviceIntelligenceService = null;
}
- if (mTemporaryHandler == null) {
- mTemporaryHandler = new Handler(Looper.getMainLooper(), null, true) {
- @Override
- public void handleMessage(Message msg) {
- if (msg.what == MSG_RESET_TEMPORARY_SERVICE) {
- synchronized (mLock) {
- resetTemporaryServices();
- }
- } else {
- Slog.wtf(TAG, "invalid handler msg: " + msg);
- }
- }
- };
- } else {
- mTemporaryHandler.removeMessages(MSG_RESET_TEMPORARY_SERVICE);
- }
if (durationMs != -1) {
- mTemporaryHandler.sendEmptyMessageDelayed(MSG_RESET_TEMPORARY_SERVICE, durationMs);
+ getTemporaryHandler().sendEmptyMessageDelayed(MSG_RESET_TEMPORARY_SERVICE,
+ durationMs);
+ }
+ }
+ }
+
+ @RequiresPermission(Manifest.permission.USE_ON_DEVICE_INTELLIGENCE)
+ public void setModelBroadcastKeys(@NonNull String[] broadcastKeys, String receiverPackageName,
+ int durationMs) {
+ Objects.requireNonNull(broadcastKeys);
+ enforceShellOnly(Binder.getCallingUid(), "setModelBroadcastKeys");
+ mContext.enforceCallingPermission(
+ Manifest.permission.USE_ON_DEVICE_INTELLIGENCE, TAG);
+ synchronized (mLock) {
+ mTemporaryBroadcastKeys = broadcastKeys;
+ mBroadcastPackageName = receiverPackageName;
+ if (durationMs != -1) {
+ getTemporaryHandler().sendEmptyMessageDelayed(MSG_RESET_BROADCAST_KEYS, durationMs);
}
}
}
@@ -751,4 +835,28 @@
}
}
}
+
+ private synchronized Handler getTemporaryHandler() {
+ if (mTemporaryHandler == null) {
+ mTemporaryHandler = new Handler(Looper.getMainLooper(), null, true) {
+ @Override
+ public void handleMessage(Message msg) {
+ if (msg.what == MSG_RESET_TEMPORARY_SERVICE) {
+ synchronized (mLock) {
+ resetTemporaryServices();
+ }
+ } else if (msg.what == MSG_RESET_BROADCAST_KEYS) {
+ synchronized (mLock) {
+ mTemporaryBroadcastKeys = null;
+ mBroadcastPackageName = SYSTEM_PACKAGE;
+ }
+ } else {
+ Slog.wtf(TAG, "invalid handler msg: " + msg);
+ }
+ }
+ };
+ }
+
+ return mTemporaryHandler;
+ }
}
diff --git a/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceShellCommand.java b/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceShellCommand.java
index a76d8a3..5744b5c 100644
--- a/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceShellCommand.java
+++ b/services/core/java/com/android/server/ondeviceintelligence/OnDeviceIntelligenceShellCommand.java
@@ -43,6 +43,8 @@
return setTemporaryServices();
case "get-services":
return getConfiguredServices();
+ case "set-model-broadcasts":
+ return setBroadcastKeys();
default:
return handleDefaultCommands(cmd);
}
@@ -62,12 +64,18 @@
pw.println(" To reset, call without any arguments.");
pw.println(" get-services To get the names of services that are currently being used.");
+ pw.println(
+ " set-model-broadcasts [ModelLoadedBroadcastKey] [ModelUnloadedBroadcastKey] "
+ + "[ReceiverPackageName] "
+ + "[DURATION] To set the names of broadcast intent keys that are to be "
+ + "emitted for cts tests.");
}
private int setTemporaryServices() {
final PrintWriter out = getOutPrintWriter();
final String intelligenceServiceName = getNextArg();
final String inferenceServiceName = getNextArg();
+
if (getRemainingArgsCount() == 0 && intelligenceServiceName == null
&& inferenceServiceName == null) {
mService.resetTemporaryServices();
@@ -79,7 +87,8 @@
Objects.requireNonNull(inferenceServiceName);
final int duration = Integer.parseInt(getNextArgRequired());
mService.setTemporaryServices(
- new String[]{intelligenceServiceName, inferenceServiceName}, duration);
+ new String[]{intelligenceServiceName, inferenceServiceName},
+ duration);
out.println("OnDeviceIntelligenceService temporarily set to " + intelligenceServiceName
+ " \n and \n OnDeviceTrustedInferenceService set to " + inferenceServiceName
+ " for " + duration + "ms");
@@ -93,4 +102,22 @@
+ " \n and \n OnDeviceTrustedInferenceService set to : " + services[1]);
return 0;
}
+
+ private int setBroadcastKeys() {
+ final PrintWriter out = getOutPrintWriter();
+ final String modelLoadedKey = getNextArgRequired();
+ final String modelUnloadedKey = getNextArgRequired();
+ final String receiverPackageName = getNextArg();
+
+ final int duration = Integer.parseInt(getNextArgRequired());
+ mService.setModelBroadcastKeys(
+ new String[]{modelLoadedKey, modelUnloadedKey}, receiverPackageName, duration);
+ out.println("OnDeviceIntelligence Model Loading broadcast keys temporarily set to "
+ + modelLoadedKey
+ + " \n and \n OnDeviceTrustedInferenceService set to " + modelUnloadedKey
+ + "\n and Package name set to : " + receiverPackageName
+ + " for " + duration + "ms");
+ return 0;
+ }
+
}
\ No newline at end of file