Merge "Rewrite Icons from the TCS." into rvc-dev
diff --git a/core/java/android/app/RemoteAction.java b/core/java/android/app/RemoteAction.java
index 1b13772..5a4244f 100644
--- a/core/java/android/app/RemoteAction.java
+++ b/core/java/android/app/RemoteAction.java
@@ -161,4 +161,4 @@
                     return new RemoteAction[size];
                 }
             };
-}
\ No newline at end of file
+}
diff --git a/core/java/android/service/textclassifier/TextClassifierService.java b/core/java/android/service/textclassifier/TextClassifierService.java
index 9dfbc28..93faa58 100644
--- a/core/java/android/service/textclassifier/TextClassifierService.java
+++ b/core/java/android/service/textclassifier/TextClassifierService.java
@@ -424,6 +424,18 @@
         return bundle.getParcelable(KEY_RESULT);
     }
 
+    /**
+     * Stores {@code response} in {@code bundle} under the key read by
+     * {@link #getResponse(Bundle)}, replacing any response already stored there.
+     *
+     * @param bundle the bundle to write the response into
+     * @param response the response to store
+     * @hide
+     */
+    public static <T extends Parcelable> void putResponse(Bundle bundle, T response) {
+        bundle.putParcelable(KEY_RESULT, response);
+    }
+
     /**
      * Callbacks for TextClassifierService results.
      *
diff --git a/core/java/android/view/textclassifier/ConversationAction.java b/core/java/android/view/textclassifier/ConversationAction.java
index e633404..bf0409d 100644
--- a/core/java/android/view/textclassifier/ConversationAction.java
+++ b/core/java/android/view/textclassifier/ConversationAction.java
@@ -206,6 +206,21 @@
         return mExtras;
     }
 
+    /**
+     * Creates a {@link Builder} initialized with this action's type, text reply, remote action,
+     * confidence score and extras.
+     *
+     * @hide
+     */
+    public Builder toBuilder() {
+        final Builder builder = new Builder(mType);
+        builder.setTextReply(mTextReply);
+        builder.setAction(mAction);
+        builder.setConfidenceScore(mScore);
+        builder.setExtras(mExtras);
+        return builder;
+    }
+
     /** Builder class to construct {@link ConversationAction}. */
     public static final class Builder {
         @Nullable
diff --git a/core/java/android/view/textclassifier/EntityConfidence.java b/core/java/android/view/textclassifier/EntityConfidence.java
index 4c12dda..b4313b7 100644
--- a/core/java/android/view/textclassifier/EntityConfidence.java
+++ b/core/java/android/view/textclassifier/EntityConfidence.java
@@ -88,6 +88,11 @@
         return 0;
     }
 
+    /** Returns a new, mutable copy of the entity-type to confidence-score mapping. */
+    public Map<String, Float> toMap() {
+        return new ArrayMap<>(mEntityConfidence);
+    }
+
     @Override
     public String toString() {
         return mEntityConfidence.toString();
diff --git a/core/java/android/view/textclassifier/TextClassification.java b/core/java/android/view/textclassifier/TextClassification.java
index 3aed32a..ab6dcb1 100644
--- a/core/java/android/view/textclassifier/TextClassification.java
+++ b/core/java/android/view/textclassifier/TextClassification.java
@@ -48,6 +48,7 @@
 import java.lang.annotation.RetentionPolicy;
 import java.time.ZonedDateTime;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Locale;
@@ -270,6 +271,26 @@
         return mExtras;
     }
 
+    /**
+     * Creates a {@link Builder} initialized with this classification's id, text, actions,
+     * entity confidence scores, legacy icon/label/intent/click listener and extras.
+     *
+     * @hide
+     */
+    public Builder toBuilder() {
+        final Builder builder = new Builder();
+        builder.setId(mId);
+        builder.setText(mText);
+        builder.addActions(mActions);
+        builder.setEntityConfidence(mEntityConfidence);
+        builder.setIcon(mLegacyIcon);
+        builder.setLabel(mLegacyLabel);
+        builder.setIntent(mLegacyIntent);
+        builder.setOnClickListener(mLegacyOnClickListener);
+        builder.setExtras(mExtras);
+        return builder;
+    }
+
     @Override
     public String toString() {
         return String.format(Locale.US,
@@ -323,7 +338,7 @@
      */
     public static final class Builder {
 
-        @NonNull private List<RemoteAction> mActions = new ArrayList<>();
+        @NonNull private final List<RemoteAction> mActions = new ArrayList<>();
         @NonNull private final Map<String, Float> mTypeScoreMap = new ArrayMap<>();
         @Nullable private String mText;
         @Nullable private Drawable mLegacyIcon;
@@ -332,8 +347,6 @@
         @Nullable private OnClickListener mLegacyOnClickListener;
         @Nullable private String mId;
         @Nullable private Bundle mExtras;
-        @NonNull private final ArrayList<Intent> mActionIntents = new ArrayList<>();
-        @Nullable private Bundle mForeignLanguageExtra;
 
         /**
          * Sets the classified text.
@@ -361,6 +374,29 @@
             return this;
         }
 
+        /**
+         * Replaces every entity type and confidence score in this builder with those held by
+         * {@code scores}.
+         */
+        Builder setEntityConfidence(EntityConfidence scores) {
+            // Read the new scores first so that a null scores argument fails before the existing
+            // scores are discarded.
+            final Map<String, Float> typeScores = scores.toMap();
+            mTypeScoreMap.clear();
+            mTypeScoreMap.putAll(typeScores);
+            return this;
+        }
+
+        /**
+         * Removes every entity type and confidence score previously set on this builder.
+         *
+         * @hide
+         */
+        public Builder clearEntityTypes() {
+            mTypeScoreMap.clear();
+            return this;
+        }
+
         /**
          * Adds an action that may be performed on the classified text. Actions should be added in
          * order of likelihood that the user will use them, with the most likely action being added
@@ -368,19 +393,40 @@
          */
         @NonNull
         public Builder addAction(@NonNull RemoteAction action) {
-            return addAction(action, null);
-        }
-
-        /**
-         * @param intent the intent in the remote action.
-         * @see #addAction(RemoteAction)
-         * @hide
-         */
-        @VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE)
-        public Builder addAction(RemoteAction action, @Nullable Intent intent) {
             Preconditions.checkArgument(action != null);
             mActions.add(action);
-            mActionIntents.add(intent);
+            return this;
+        }
+
+        /**
+         * Appends every action in {@code actions}, in iteration order, as if each had been passed
+         * to {@link #addAction(RemoteAction)}.
+         *
+         * @throws NullPointerException if {@code actions} is null
+         * @throws IllegalArgumentException if {@code actions} contains a null element, in which
+         *     case no action is added
+         * @hide
+         */
+        @NonNull
+        public Builder addActions(@NonNull Collection<RemoteAction> actions) {
+            Objects.requireNonNull(actions);
+            // Validate every element first, as addAction() does, so a bad element leaves the
+            // builder unchanged.
+            for (RemoteAction action : actions) {
+                Preconditions.checkArgument(action != null);
+            }
+            mActions.addAll(actions);
+            return this;
+        }
+
+        /**
+         * Removes every action previously added to this builder.
+         *
+         * @hide
+         */
+        @NonNull
+        public Builder clearActions() {
+            mActions.clear();
             return this;
         }
 
@@ -466,16 +493,6 @@
         }
 
         /**
-         * @see #setExtras(Bundle)
-         * @hide
-         */
-        @VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE)
-        public Builder setForeignLanguageExtra(@Nullable Bundle extra) {
-            mForeignLanguageExtra = extra;
-            return this;
-        }
-
-        /**
          * Builds and returns a {@link TextClassification} object.
          */
         @NonNull
diff --git a/core/tests/coretests/src/android/view/textclassifier/ConversationActionTest.java b/core/tests/coretests/src/android/view/textclassifier/ConversationActionTest.java
new file mode 100644
index 0000000..6b62635
--- /dev/null
+++ b/core/tests/coretests/src/android/view/textclassifier/ConversationActionTest.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.view.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.PendingIntent;
+import android.app.RemoteAction;
+import android.content.Context;
+import android.content.Intent;
+import android.graphics.drawable.Icon;
+import android.os.Bundle;
+
+import androidx.test.InstrumentationRegistry;
+import androidx.test.filters.SmallTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/** Tests for {@link ConversationAction}. */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class ConversationActionTest {
+
+    @Test
+    public void toBuilder() {
+        final Context context = InstrumentationRegistry.getTargetContext();
+        final PendingIntent intent = PendingIntent.getActivity(context, 0, new Intent(), 0);
+        final Icon icon = Icon.createWithData(new byte[]{0}, 0, 1);
+        final Bundle extras = new Bundle();
+        extras.putInt("key", 5);
+        final ConversationAction convAction =
+                new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+                        .setAction(new RemoteAction(icon, "title", "descr", intent))
+                        .setConfidenceScore(0.5f)
+                        .setExtras(extras)
+                        .build();
+
+        final ConversationAction fromBuilder = convAction.toBuilder().build();
+
+        assertThat(fromBuilder.getType()).isEqualTo(convAction.getType());
+        assertThat(fromBuilder.getAction()).isEqualTo(convAction.getAction());
+        assertThat(fromBuilder.getConfidenceScore()).isEqualTo(convAction.getConfidenceScore());
+        // Bundle uses reference equality, so compare the stored value instead of the Bundle.
+        assertThat(fromBuilder.getExtras().getInt("key")).isEqualTo(5);
+        assertThat(fromBuilder.getTextReply()).isEqualTo(convAction.getTextReply());
+    }
+
+    @Test
+    public void toBuilder_textReply() {
+        final ConversationAction convAction =
+                new ConversationAction.Builder(ConversationAction.TYPE_TEXT_REPLY)
+                        .setTextReply(":P")
+                        .build();
+
+        final ConversationAction fromBuilder = convAction.toBuilder().build();
+
+        assertThat(fromBuilder.getType()).isEqualTo(convAction.getType());
+        assertThat(fromBuilder.getTextReply()).isEqualTo(convAction.getTextReply());
+    }
+}
diff --git a/core/tests/coretests/src/android/view/textclassifier/TextClassificationTest.java b/core/tests/coretests/src/android/view/textclassifier/TextClassificationTest.java
index 39ededa..cf742b0 100644
--- a/core/tests/coretests/src/android/view/textclassifier/TextClassificationTest.java
+++ b/core/tests/coretests/src/android/view/textclassifier/TextClassificationTest.java
@@ -57,6 +57,7 @@
     static {
         BUNDLE.putString(BUNDLE_KEY, BUNDLE_VALUE);
     }
+    private static final float EPSILON = 1e-7f;
 
     public Icon generateTestIcon(int width, int height, int colorValue) {
         final int numPixels = width * height;
@@ -128,8 +129,8 @@
         assertEquals(2, result.getEntityCount());
         assertEquals(TextClassifier.TYPE_PHONE, result.getEntity(0));
         assertEquals(TextClassifier.TYPE_ADDRESS, result.getEntity(1));
-        assertEquals(0.7f, result.getConfidenceScore(TextClassifier.TYPE_PHONE), 1e-7f);
-        assertEquals(0.3f, result.getConfidenceScore(TextClassifier.TYPE_ADDRESS), 1e-7f);
+        assertEquals(0.7f, result.getConfidenceScore(TextClassifier.TYPE_PHONE), EPSILON);
+        assertEquals(0.3f, result.getConfidenceScore(TextClassifier.TYPE_ADDRESS), EPSILON);
 
         // Extras
         assertEquals(BUNDLE_VALUE, result.getExtras().getString(BUNDLE_KEY));
@@ -226,4 +227,50 @@
         assertEquals(1, resultSystemTcMetadata.getUserId());
         assertFalse(resultSystemTcMetadata.useDefaultTextClassifier());
     }
+
+    /** Verifies that {@link TextClassification#toBuilder()} carries over every field. */
+    @Test
+    public void testToBuilder() {
+        final Context context = InstrumentationRegistry.getInstrumentation().getContext();
+        final Icon icon1 = generateTestIcon(5, 5, Color.RED);
+        final Icon icon2 = generateTestIcon(2, 10, Color.BLUE);
+        final TextClassification classification = new TextClassification.Builder()
+                .setId("id")
+                .setText("text")
+                .setIcon(icon1.loadDrawable(context))
+                .setLabel("label")
+                .setIntent(new Intent("action"))
+                .setOnClickListener(view -> { })
+                .addAction(new RemoteAction(icon1, "title1", "desc1",
+                          PendingIntent.getActivity(context, 0, new Intent("action1"), 0)))
+                .addAction(new RemoteAction(icon2, "title2", "desc2",
+                          PendingIntent.getActivity(context, 0, new Intent("action2"), 0)))
+                .setEntityType(TextClassifier.TYPE_EMAIL, 0.5f)
+                .setEntityType(TextClassifier.TYPE_PHONE, 0.4f)
+                .setExtras(BUNDLE)
+                .build();
+
+        final TextClassification fromBuilder = classification.toBuilder().build();
+
+        assertEquals(classification.getId(), fromBuilder.getId());
+        assertEquals(classification.getText(), fromBuilder.getText());
+        assertEquals(classification.getIcon(), fromBuilder.getIcon());
+        assertEquals(classification.getLabel(), fromBuilder.getLabel());
+        assertEquals(classification.getIntent(), fromBuilder.getIntent());
+        assertEquals(classification.getOnClickListener(), fromBuilder.getOnClickListener());
+        // Bundle uses reference equality, so compare the stored value instead of the Bundle.
+        assertEquals(BUNDLE_VALUE, fromBuilder.getExtras().getString(BUNDLE_KEY));
+        assertEquals(classification.getActions(), fromBuilder.getActions());
+        assertEquals(classification.getEntityCount(), fromBuilder.getEntityCount());
+        assertEquals(classification.getEntity(0), fromBuilder.getEntity(0));
+        assertEquals(classification.getEntity(1), fromBuilder.getEntity(1));
+        assertEquals(
+                classification.getConfidenceScore(TextClassifier.TYPE_EMAIL),
+                fromBuilder.getConfidenceScore(TextClassifier.TYPE_EMAIL),
+                EPSILON);
+        assertEquals(
+                classification.getConfidenceScore(TextClassifier.TYPE_PHONE),
+                fromBuilder.getConfidenceScore(TextClassifier.TYPE_PHONE),
+                EPSILON);
+    }
 }
diff --git a/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java b/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java
index 9a5b020..5657c74 100644
--- a/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java
+++ b/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java
@@ -19,14 +19,18 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.UserIdInt;
+import android.app.RemoteAction;
 import android.content.ComponentName;
 import android.content.Context;
 import android.content.Intent;
 import android.content.ServiceConnection;
 import android.content.pm.PackageManager;
+import android.graphics.drawable.Icon;
+import android.net.Uri;
 import android.os.Binder;
 import android.os.Bundle;
 import android.os.IBinder;
+import android.os.Parcelable;
 import android.os.Process;
 import android.os.RemoteException;
 import android.os.UserHandle;
@@ -39,6 +43,7 @@
 import android.util.ArrayMap;
 import android.util.Slog;
 import android.util.SparseArray;
+import android.view.textclassifier.ConversationAction;
 import android.view.textclassifier.ConversationActions;
 import android.view.textclassifier.SelectionEvent;
 import android.view.textclassifier.SystemTextClassifierMetadata;
@@ -69,6 +74,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.Queue;
+import java.util.stream.Collectors;
 
 /**
  * A manager for TextClassifier services.
@@ -203,7 +209,7 @@
                 request.getSystemTextClassifierMetadata(),
                 /* verifyCallingPackage= */ true,
                 /* attemptToBind= */ true,
-                service -> service.onClassifyText(sessionId, request, callback),
+                service -> service.onClassifyText(sessionId, request, wrap(callback)),
                 "onClassifyText",
                 callback);
     }
@@ -289,7 +295,8 @@
                 request.getSystemTextClassifierMetadata(),
                 /* verifyCallingPackage= */ true,
                 /* attemptToBind= */ true,
-                service -> service.onSuggestConversationActions(sessionId, request, callback),
+                service -> service.onSuggestConversationActions(
+                        sessionId, request, wrap(callback)),
                 "onSuggestConversationActions",
                 callback);
     }
@@ -464,6 +471,14 @@
         }
     }
 
+    /**
+     * Wraps {@code original} so that responses pass through {@link CallbackWrapper} before
+     * reaching it.
+     */
+    private static ITextClassifierCallback wrap(ITextClassifierCallback original) {
+        return new CallbackWrapper(original);
+    }
+
     private void onTextClassifierServicePackageOverrideChanged(String overriddenPackage) {
         synchronized (mLock) {
             final int size = mUserStates.size();
@@ -1004,4 +1015,132 @@
             onTextClassifierServicePackageOverrideChanged(currentServicePackageOverride);
         }
     }
+
+    /**
+     * Wraps an ITextClassifierCallback and modifies the response to it where necessary.
+     *
+     * <p>Actions whose icon is of type {@link Icon#TYPE_RESOURCE} in {@link TextClassification}
+     * and {@link ConversationActions} responses have that icon replaced by a content-URI icon
+     * before the response is forwarded to the wrapped callback.
+     */
+    private static final class CallbackWrapper extends ITextClassifierCallback.Stub {
+
+        private final ITextClassifierCallback mWrapped;
+
+        CallbackWrapper(ITextClassifierCallback wrapped) {
+            mWrapped = Objects.requireNonNull(wrapped);
+        }
+
+        /**
+         * Rewrites icons in {@code result} and forwards it to the wrapped callback. If the result
+         * cannot be read or rewritten, the wrapped callback receives {@link #onFailure()} instead
+         * so that it is never left without a response.
+         */
+        @Override
+        public void onSuccess(Bundle result) {
+            try {
+                rewriteIcons(result);
+            } catch (RuntimeException e) {
+                // Fail closed: never forward a response whose icons could not be sanitized.
+                Slog.e(LOG_TAG, "Failed to rewrite icons in text classifier response", e);
+                onFailure();
+                return;
+            }
+            try {
+                mWrapped.onSuccess(result);
+            } catch (RemoteException e) {
+                Slog.e(LOG_TAG, "Callback error", e);
+            }
+        }
+
+        @Override
+        public void onFailure() {
+            try {
+                mWrapped.onFailure();
+            } catch (RemoteException e) {
+                Slog.e(LOG_TAG, "Callback error", e);
+            }
+        }
+
+        /** Rewrites, in place, any icons in the response stored in {@code result}. */
+        private static void rewriteIcons(Bundle result) {
+            final Parcelable parcelled = TextClassifierService.getResponse(result);
+            if (parcelled instanceof TextClassification) {
+                rewriteTextClassificationIcons(result, (TextClassification) parcelled);
+            } else if (parcelled instanceof ConversationActions) {
+                rewriteConversationActionsIcons(result, (ConversationActions) parcelled);
+            }
+            // Any other response type is forwarded unchanged.
+        }
+
+        private static void rewriteTextClassificationIcons(
+                Bundle result, TextClassification classification) {
+            final boolean rewrite = classification.getActions().stream()
+                    .anyMatch(CallbackWrapper::shouldRewriteIcon);
+            if (!rewrite) {
+                return;
+            }
+            TextClassifierService.putResponse(
+                    result,
+                    classification.toBuilder()
+                            .clearActions()
+                            .addActions(classification.getActions()
+                                    .stream()
+                                    .map(CallbackWrapper::validAction)
+                                    .collect(Collectors.toList()))
+                            .build());
+        }
+
+        private static void rewriteConversationActionsIcons(
+                Bundle result, ConversationActions convActions) {
+            final boolean rewrite = convActions.getConversationActions().stream()
+                    .anyMatch(convAction -> shouldRewriteIcon(convAction.getAction()));
+            if (!rewrite) {
+                return;
+            }
+            TextClassifierService.putResponse(
+                    result,
+                    new ConversationActions(
+                            convActions.getConversationActions()
+                                    .stream()
+                                    .map(convAction -> convAction.toBuilder()
+                                            .setAction(validAction(convAction.getAction()))
+                                            .build())
+                                    .collect(Collectors.toList()),
+                            convActions.getId()));
+        }
+
+        /**
+         * Returns {@code action} unchanged if its icon does not need rewriting; otherwise returns
+         * a copy whose icon is replaced by {@link #changeIcon(Icon)}.
+         */
+        @Nullable
+        private static RemoteAction validAction(@Nullable RemoteAction action) {
+            if (!shouldRewriteIcon(action)) {
+                return action;
+            }
+
+            final RemoteAction newAction = new RemoteAction(
+                    changeIcon(action.getIcon()),
+                    action.getTitle(),
+                    action.getContentDescription(),
+                    action.getActionIntent());
+            newAction.setEnabled(action.isEnabled());
+            newAction.setShouldShowIcon(action.shouldShowIcon());
+            return newAction;
+        }
+
+        private static boolean shouldRewriteIcon(@Nullable RemoteAction action) {
+            // Rewrite resource icons so that they (1) do not leak package names and (2) are
+            // renderable in the client process.
+            return action != null && action.getIcon().getType() == Icon.TYPE_RESOURCE;
+        }
+
+        /** Changes icon of type=RESOURCES to icon of type=URI. */
+        private static Icon changeIcon(Icon icon) {
+            final Uri uri = IconsUriHelper.getInstance()
+                    .getContentUri(icon.getResPackage(), icon.getResId());
+            return Icon.createWithContentUri(uri);
+        }
+    }
 }