Restrict AudioAttribute usages for notifications

Only calls can use USAGE_NOTIFICATION_RINGTONE. Only alarms can
use USAGE_ALARMS. No notifications can play on the media stream.

If an app specifies and incorrect usage we will use USAGE_NOTIFICATION
instead.

Test: NotificationChannelExtractorTest
Flag: com.android.server.notification.restrict_audio_attributes_call DEV
Flag: com.android.server.notification.restrict_audio_attributes_alarm
DEV
Flag: com.android.server.notification.restrict_audio_attributes_media
DEV
Bug: 331793339

Change-Id: Iefb5b7225ee0a8a1da694dc76841990a13a34572
diff --git a/core/java/android/app/NotificationChannel.java b/core/java/android/app/NotificationChannel.java
index 7c803eb..193c524 100644
--- a/core/java/android/app/NotificationChannel.java
+++ b/core/java/android/app/NotificationChannel.java
@@ -434,6 +434,40 @@
     /**
      * @hide
      */
+    public NotificationChannel copy() {
+        NotificationChannel copy = new NotificationChannel(mId, mName, mImportance);
+        copy.setDescription(mDesc);
+        copy.setBypassDnd(mBypassDnd);
+        copy.setLockscreenVisibility(mLockscreenVisibility);
+        copy.setSound(mSound, mAudioAttributes);
+        copy.setLightColor(mLightColor);
+        copy.enableLights(mLights);
+        copy.setVibrationPattern(mVibrationPattern);
+        if (Flags.notificationChannelVibrationEffectApi()) {
+            copy.setVibrationEffect(mVibrationEffect);
+        }
+        copy.lockFields(mUserLockedFields);
+        copy.setUserVisibleTaskShown(mUserVisibleTaskShown);
+        copy.enableVibration(mVibrationEnabled);
+        copy.setShowBadge(mShowBadge);
+        copy.setDeleted(mDeleted);
+        copy.setGroup(mGroup);
+        copy.setBlockable(mBlockableSystem);
+        copy.setAllowBubbles(mAllowBubbles);
+        copy.setOriginalImportance(mOriginalImportance);
+        copy.setConversationId(mParentId, mConversationId);
+        copy.setDemoted(mDemoted);
+        copy.setImportantConversation(mImportantConvo);
+        copy.setDeletedTimeMs(mDeletedTime);
+        copy.setImportanceLockedByCriticalDeviceFunction(mImportanceLockedDefaultApp);
+        copy.setLastNotificationUpdateTimeMs(mLastNotificationUpdateTimeMs);
+
+        return copy;
+    }
+
+    /**
+     * @hide
+     */
     @TestApi
     public void lockFields(int field) {
         mUserLockedFields |= field;
diff --git a/core/java/android/app/notification.aconfig b/core/java/android/app/notification.aconfig
index a2cf672..0214d40 100644
--- a/core/java/android/app/notification.aconfig
+++ b/core/java/android/app/notification.aconfig
@@ -88,4 +88,25 @@
   namespace: "systemui"
   description: "Changes notification sort order to be by time within a section"
   bug: "330193582"
+}
+
+flag {
+  name: "restrict_audio_attributes_call"
+  namespace: "systemui"
+  description: "Only CallStyle notifs can use USAGE_NOTIFICATION_RINGTONE"
+  bug: "331793339"
+}
+
+flag {
+  name: "restrict_audio_attributes_alarm"
+  namespace: "systemui"
+  description: "Only alarm category notifs can use USAGE_ALARM"
+  bug: "331793339"
+}
+
+flag {
+  name: "restrict_audio_attributes_media"
+  namespace: "systemui"
+  description: "No notifs can use USAGE_UNKNOWN or USAGE_MEDIA"
+  bug: "331793339"
 }
\ No newline at end of file
diff --git a/core/tests/coretests/src/android/app/NotificationChannelTest.java b/core/tests/coretests/src/android/app/NotificationChannelTest.java
index 18209b5..504f98f 100644
--- a/core/tests/coretests/src/android/app/NotificationChannelTest.java
+++ b/core/tests/coretests/src/android/app/NotificationChannelTest.java
@@ -31,6 +31,7 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
+import android.annotation.FlaggedApi;
 import android.content.AttributionSource;
 import android.content.ContentProvider;
 import android.content.ContentResolver;
@@ -46,6 +47,7 @@
 import android.os.RemoteCallback;
 import android.os.RemoteException;
 import android.os.VibrationEffect;
+import android.platform.test.annotations.EnableFlags;
 import android.platform.test.annotations.Presubmit;
 import android.platform.test.flag.junit.SetFlagsRule;
 import android.provider.MediaStore.Audio.AudioColumns;
@@ -577,6 +579,40 @@
         assertNull(channel.getVibrationEffect());
     }
 
+    @Test
+    @EnableFlags({Flags.FLAG_RESTRICT_AUDIO_ATTRIBUTES_MEDIA,
+            Flags.FLAG_RESTRICT_AUDIO_ATTRIBUTES_CALL, Flags.FLAG_RESTRICT_AUDIO_ATTRIBUTES_ALARM})
+    public void testCopy() {
+        NotificationChannel original = new NotificationChannel("id", "name", 2);
+        original.setDescription("desc");
+        original.setBypassDnd(true);
+        original.setLockscreenVisibility(7);
+        original.setSound(Uri.EMPTY, new AudioAttributes.Builder().build());
+        original.setLightColor(5);
+        original.enableLights(false);
+        original.setVibrationPattern(new long[] {1, 9, 3});
+        if (Flags.notificationChannelVibrationEffectApi()) {
+            original.setVibrationEffect(VibrationEffect.createOneShot(100, 5));
+        }
+        original.lockFields(9999);
+        original.setUserVisibleTaskShown(true);
+        original.enableVibration(false);
+        original.setShowBadge(true);
+        original.setDeleted(false);
+        original.setGroup("group");
+        original.setBlockable(false);
+        original.setAllowBubbles(true);
+        original.setOriginalImportance(6);
+        original.setConversationId("parent", "convo");
+        original.setDemoted(false);
+        original.setImportantConversation(true);
+        original.setDeletedTimeMs(100);
+        original.setImportanceLockedByCriticalDeviceFunction(false);
+
+        NotificationChannel parcelCopy = writeToAndReadFromParcel(original);
+        assertThat(original.copy()).isEqualTo(parcelCopy);
+    }
+
     /** Backs up a given channel to an XML, and returns the channel read from the XML. */
     private NotificationChannel backUpAndRestore(NotificationChannel channel) throws Exception {
         TypedXmlSerializer serializer = Xml.newFastSerializer();
diff --git a/services/core/java/com/android/server/notification/NotificationChannelExtractor.java b/services/core/java/com/android/server/notification/NotificationChannelExtractor.java
index 2f60e42..bd73cb6 100644
--- a/services/core/java/com/android/server/notification/NotificationChannelExtractor.java
+++ b/services/core/java/com/android/server/notification/NotificationChannelExtractor.java
@@ -15,8 +15,16 @@
 */
 package com.android.server.notification;
 
+import static android.app.Flags.restrictAudioAttributesAlarm;
+import static android.app.Flags.restrictAudioAttributesCall;
+import static android.app.Flags.restrictAudioAttributesMedia;
+import static android.app.Notification.CATEGORY_ALARM;
+import static android.media.AudioAttributes.USAGE_NOTIFICATION;
+
+import android.app.Notification;
 import android.app.NotificationChannel;
 import android.content.Context;
+import android.media.AudioAttributes;
 import android.util.Slog;
 
 /**
@@ -50,6 +58,36 @@
                 record.getSbn().getShortcutId(), true, false);
         record.updateNotificationChannel(updatedChannel);
 
+        if (restrictAudioAttributesCall() || restrictAudioAttributesAlarm()
+                || restrictAudioAttributesMedia()) {
+            AudioAttributes attributes = record.getChannel().getAudioAttributes();
+            boolean updateAttributes =  false;
+            if (restrictAudioAttributesCall()
+                    && !record.getNotification().isStyle(Notification.CallStyle.class)
+                    && attributes.getUsage() == AudioAttributes.USAGE_NOTIFICATION_RINGTONE) {
+                updateAttributes = true;
+            }
+            if (restrictAudioAttributesAlarm()
+                    && record.getNotification().category != CATEGORY_ALARM
+                    && attributes.getUsage() == AudioAttributes.USAGE_ALARM) {
+                updateAttributes = true;
+            }
+
+            if (restrictAudioAttributesMedia()
+                    && (attributes.getUsage() == AudioAttributes.USAGE_UNKNOWN
+                    || attributes.getUsage() == AudioAttributes.USAGE_MEDIA)) {
+                updateAttributes = true;
+            }
+
+            if (updateAttributes) {
+                NotificationChannel clone = record.getChannel().copy();
+                clone.setSound(clone.getSound(), new AudioAttributes.Builder(attributes)
+                        .setUsage(USAGE_NOTIFICATION)
+                        .build());
+                record.updateNotificationChannel(clone);
+            }
+        }
+
         return null;
     }
 
diff --git a/services/core/java/com/android/server/notification/NotificationRecord.java b/services/core/java/com/android/server/notification/NotificationRecord.java
index a4464a1..97d2620 100644
--- a/services/core/java/com/android/server/notification/NotificationRecord.java
+++ b/services/core/java/com/android/server/notification/NotificationRecord.java
@@ -15,6 +15,9 @@
  */
 package com.android.server.notification;
 
+import static android.app.Flags.restrictAudioAttributesAlarm;
+import static android.app.Flags.restrictAudioAttributesCall;
+import static android.app.Flags.restrictAudioAttributesMedia;
 import static android.app.Flags.updateRankingTime;
 import static android.app.NotificationChannel.USER_LOCKED_IMPORTANCE;
 import static android.app.NotificationManager.IMPORTANCE_DEFAULT;
@@ -1159,6 +1162,11 @@
             mChannel = channel;
             calculateImportance();
             calculateUserSentiment();
+            mVibration = calculateVibration();
+            if (restrictAudioAttributesCall() || restrictAudioAttributesAlarm()
+                    || restrictAudioAttributesMedia()) {
+                mAttributes = channel.getAudioAttributes();
+            }
         }
     }
 
diff --git a/services/tests/uiservicestests/src/com/android/server/notification/NotificationChannelExtractorTest.java b/services/tests/uiservicestests/src/com/android/server/notification/NotificationChannelExtractorTest.java
index 77ce2f0..ad25d76 100644
--- a/services/tests/uiservicestests/src/com/android/server/notification/NotificationChannelExtractorTest.java
+++ b/services/tests/uiservicestests/src/com/android/server/notification/NotificationChannelExtractorTest.java
@@ -16,26 +16,43 @@
 
 package com.android.server.notification;
 
+import static android.app.Notification.CATEGORY_ALARM;
 import static android.app.NotificationManager.IMPORTANCE_HIGH;
 import static android.app.NotificationManager.IMPORTANCE_LOW;
 
+import static android.media.AudioAttributes.USAGE_ALARM;
+import static android.media.AudioAttributes.USAGE_MEDIA;
+import static android.media.AudioAttributes.USAGE_NOTIFICATION;
+import static android.media.AudioAttributes.USAGE_NOTIFICATION_RINGTONE;
+import static android.media.AudioAttributes.USAGE_UNKNOWN;
+import static android.platform.test.flag.junit.SetFlagsRule.DefaultInitValueType.DEVICE_DEFAULT;
+import static com.google.common.truth.Truth.assertThat;
 import static junit.framework.Assert.assertEquals;
 import static junit.framework.Assert.assertNull;
 
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
 import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
+import android.app.Flags;
 import android.app.Notification;
 import android.app.NotificationChannel;
+import android.app.PendingIntent;
+import android.app.Person;
+import android.media.AudioAttributes;
+import android.net.Uri;
 import android.os.UserHandle;
+import android.platform.test.annotations.EnableFlags;
+import android.platform.test.flag.junit.SetFlagsRule;
 import android.provider.Settings;
 import android.service.notification.StatusBarNotification;
 
 import com.android.server.UiServiceTestCase;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
@@ -44,25 +61,34 @@
 
     @Mock RankingConfig mConfig;
 
+    @Rule
+    public final SetFlagsRule mSetFlagsRule = new SetFlagsRule(DEVICE_DEFAULT);
+
+    NotificationChannelExtractor mExtractor;
+
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
+
+        mExtractor = new NotificationChannelExtractor();
+        mExtractor.setConfig(mConfig);
+        mExtractor.initialize(mContext, null);
+    }
+
+    private NotificationRecord getRecord(NotificationChannel channel, Notification n) {
+        StatusBarNotification sbn = new StatusBarNotification("", "", 0, "", 0,
+                0, n, UserHandle.ALL, null, System.currentTimeMillis());
+        return new NotificationRecord(getContext(), sbn, channel);
     }
 
     @Test
-    public void testExtractsUpdatedChannel() {
-        NotificationChannelExtractor extractor = new NotificationChannelExtractor();
-        extractor.setConfig(mConfig);
-        extractor.initialize(mContext, null);
-
+    public void testExtractsUpdatedConversationChannel() {
         NotificationChannel channel = new NotificationChannel("a", "a", IMPORTANCE_LOW);
-        final Notification.Builder builder = new Notification.Builder(getContext())
+        final Notification n = new Notification.Builder(getContext())
                 .setContentTitle("foo")
-                .setSmallIcon(android.R.drawable.sym_def_app_icon);
-        Notification n = builder.build();
-        StatusBarNotification sbn = new StatusBarNotification("", "", 0, "", 0,
-                0, n, UserHandle.ALL, null, System.currentTimeMillis());
-        NotificationRecord r = new NotificationRecord(getContext(), sbn, channel);
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .build();
+        NotificationRecord r = getRecord(channel, n);
 
         NotificationChannel updatedChannel =
                 new NotificationChannel("a", "", IMPORTANCE_HIGH);
@@ -70,26 +96,19 @@
                 any(), anyInt(), eq("a"), eq(null), eq(true), eq(false)))
                 .thenReturn(updatedChannel);
 
-        assertNull(extractor.process(r));
+        assertNull(mExtractor.process(r));
         assertEquals(updatedChannel, r.getChannel());
     }
 
     @Test
-    public void testInvalidShortcutFlagEnabled_looksUpCorrectChannel() {
-
-        NotificationChannelExtractor extractor = new NotificationChannelExtractor();
-        extractor.setConfig(mConfig);
-        extractor.initialize(mContext, null);
-
+    public void testInvalidShortcutFlagEnabled_looksUpCorrectNonChannel() {
         NotificationChannel channel = new NotificationChannel("a", "a", IMPORTANCE_LOW);
-        final Notification.Builder builder = new Notification.Builder(getContext())
+        final Notification n = new Notification.Builder(getContext())
                 .setContentTitle("foo")
                 .setStyle(new Notification.MessagingStyle("name"))
-                .setSmallIcon(android.R.drawable.sym_def_app_icon);
-        Notification n = builder.build();
-        StatusBarNotification sbn = new StatusBarNotification("", "", 0, "tag", 0,
-                0, n, UserHandle.ALL, null, System.currentTimeMillis());
-        NotificationRecord r = new NotificationRecord(getContext(), sbn, channel);
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .build();
+        NotificationRecord r = getRecord(channel, n);
 
         NotificationChannel updatedChannel =
                 new NotificationChannel("a", "", IMPORTANCE_HIGH);
@@ -98,26 +117,19 @@
                 eq(true), eq(false)))
                 .thenReturn(updatedChannel);
 
-        assertNull(extractor.process(r));
+        assertNull(mExtractor.process(r));
         assertEquals(updatedChannel, r.getChannel());
     }
 
     @Test
     public void testInvalidShortcutFlagDisabled_looksUpCorrectChannel() {
-
-        NotificationChannelExtractor extractor = new NotificationChannelExtractor();
-        extractor.setConfig(mConfig);
-        extractor.initialize(mContext, null);
-
         NotificationChannel channel = new NotificationChannel("a", "a", IMPORTANCE_LOW);
-        final Notification.Builder builder = new Notification.Builder(getContext())
+        final Notification n = new Notification.Builder(getContext())
                 .setContentTitle("foo")
                 .setStyle(new Notification.MessagingStyle("name"))
-                .setSmallIcon(android.R.drawable.sym_def_app_icon);
-        Notification n = builder.build();
-        StatusBarNotification sbn = new StatusBarNotification("", "", 0, "tag", 0,
-                0, n, UserHandle.ALL, null, System.currentTimeMillis());
-        NotificationRecord r = new NotificationRecord(getContext(), sbn, channel);
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .build();
+        NotificationRecord r = getRecord(channel, n);
 
         NotificationChannel updatedChannel =
                 new NotificationChannel("a", "", IMPORTANCE_HIGH);
@@ -125,7 +137,129 @@
                 any(), anyInt(), eq("a"), eq(null), eq(true), eq(false)))
                 .thenReturn(updatedChannel);
 
-        assertNull(extractor.process(r));
+        assertNull(mExtractor.process(r));
         assertEquals(updatedChannel, r.getChannel());
     }
+
+    @Test
+    @EnableFlags(Flags.FLAG_RESTRICT_AUDIO_ATTRIBUTES_CALL)
+    public void testAudioAttributes_callStyleCanUseCallUsage() {
+        NotificationChannel channel = new NotificationChannel("a", "a", IMPORTANCE_HIGH);
+        channel.setSound(Uri.EMPTY, new AudioAttributes.Builder()
+                .setUsage(USAGE_NOTIFICATION_RINGTONE)
+                .build());
+        final Notification n = new Notification.Builder(getContext())
+                .setContentTitle("foo")
+                .setStyle(Notification.CallStyle.forIncomingCall(
+                        new Person.Builder().setName("A Caller").build(),
+                        mock(PendingIntent.class),
+                        mock(PendingIntent.class)
+                ))
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .build();
+        NotificationRecord r = getRecord(channel, n);
+
+        assertThat(mExtractor.process(r)).isNull();
+        assertThat(r.getAudioAttributes().getUsage()).isEqualTo(USAGE_NOTIFICATION_RINGTONE);
+        assertThat(r.getChannel()).isEqualTo(channel);
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_RESTRICT_AUDIO_ATTRIBUTES_CALL)
+    public void testAudioAttributes_nonCallStyleCannotUseCallUsage() {
+        NotificationChannel channel = new NotificationChannel("a", "a", IMPORTANCE_HIGH);
+        channel.setSound(Uri.EMPTY, new AudioAttributes.Builder()
+                .setUsage(USAGE_NOTIFICATION_RINGTONE)
+                .build());
+        final Notification n = new Notification.Builder(getContext())
+                .setContentTitle("foo")
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .build();
+        NotificationRecord r = getRecord(channel, n);
+
+        assertThat(mExtractor.process(r)).isNull();
+        // instance updated
+        assertThat(r.getAudioAttributes().getUsage()).isEqualTo(USAGE_NOTIFICATION);
+        // in-memory channel unchanged
+        assertThat(channel.getAudioAttributes().getUsage()).isEqualTo(USAGE_NOTIFICATION_RINGTONE);
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_RESTRICT_AUDIO_ATTRIBUTES_ALARM)
+    public void testAudioAttributes_alarmCategoryCanUseAlarmUsage() {
+        NotificationChannel channel = new NotificationChannel("a", "a", IMPORTANCE_HIGH);
+        channel.setSound(Uri.EMPTY, new AudioAttributes.Builder()
+                .setUsage(USAGE_ALARM)
+                .build());
+        final Notification n = new Notification.Builder(getContext())
+                .setContentTitle("foo")
+                .setCategory(CATEGORY_ALARM)
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .build();
+        NotificationRecord r = getRecord(channel, n);
+
+        assertThat(mExtractor.process(r)).isNull();
+        assertThat(r.getAudioAttributes().getUsage()).isEqualTo(USAGE_ALARM);
+        assertThat(r.getChannel()).isEqualTo(channel);
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_RESTRICT_AUDIO_ATTRIBUTES_ALARM)
+    public void testAudioAttributes_nonAlarmCategoryCannotUseAlarmUsage() {
+        NotificationChannel channel = new NotificationChannel("a", "a", IMPORTANCE_HIGH);
+        channel.setSound(Uri.EMPTY, new AudioAttributes.Builder()
+                .setUsage(USAGE_ALARM)
+                .build());
+        final Notification n = new Notification.Builder(getContext())
+                .setContentTitle("foo")
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .build();
+        NotificationRecord r = getRecord(channel, n);
+
+        assertThat(mExtractor.process(r)).isNull();
+        // instance updated
+        assertThat(r.getAudioAttributes().getUsage()).isEqualTo(USAGE_NOTIFICATION);
+        // in-memory channel unchanged
+        assertThat(channel.getAudioAttributes().getUsage()).isEqualTo(USAGE_ALARM);
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_RESTRICT_AUDIO_ATTRIBUTES_MEDIA)
+    public void testAudioAttributes_noMediaUsage() {
+        NotificationChannel channel = new NotificationChannel("a", "a", IMPORTANCE_HIGH);
+        channel.setSound(Uri.EMPTY, new AudioAttributes.Builder()
+                .setUsage(USAGE_MEDIA)
+                .build());
+        final Notification n = new Notification.Builder(getContext())
+                .setContentTitle("foo")
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .build();
+        NotificationRecord r = getRecord(channel, n);
+
+        assertThat(mExtractor.process(r)).isNull();
+        // instance updated
+        assertThat(r.getAudioAttributes().getUsage()).isEqualTo(USAGE_NOTIFICATION);
+        // in-memory channel unchanged
+        assertThat(channel.getAudioAttributes().getUsage()).isEqualTo(USAGE_MEDIA);
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_RESTRICT_AUDIO_ATTRIBUTES_MEDIA)
+    public void testAudioAttributes_noUnknownUsage() {
+        NotificationChannel channel = new NotificationChannel("a", "a", IMPORTANCE_HIGH);
+        channel.setSound(Uri.EMPTY, new AudioAttributes.Builder()
+                .setUsage(USAGE_UNKNOWN)
+                .build());
+        final Notification n = new Notification.Builder(getContext())
+                .setContentTitle("foo")
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .build();
+        NotificationRecord r = getRecord(channel, n);
+
+        assertThat(mExtractor.process(r)).isNull();
+        // instance updated
+        assertThat(r.getAudioAttributes().getUsage()).isEqualTo(USAGE_NOTIFICATION);
+        // in-memory channel unchanged
+        assertThat(channel.getAudioAttributes().getUsage()).isEqualTo(USAGE_UNKNOWN);
+    }
 }
diff --git a/services/tests/uiservicestests/src/com/android/server/notification/NotificationManagerServiceTest.java b/services/tests/uiservicestests/src/com/android/server/notification/NotificationManagerServiceTest.java
index 87f0773..ecc1730 100755
--- a/services/tests/uiservicestests/src/com/android/server/notification/NotificationManagerServiceTest.java
+++ b/services/tests/uiservicestests/src/com/android/server/notification/NotificationManagerServiceTest.java
@@ -72,6 +72,8 @@
 import static android.content.pm.PackageManager.FEATURE_WATCH;
 import static android.content.pm.PackageManager.PERMISSION_DENIED;
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
+import static android.media.AudioAttributes.USAGE_MEDIA;
+import static android.media.AudioAttributes.USAGE_NOTIFICATION;
 import static android.os.Build.VERSION_CODES.O_MR1;
 import static android.os.Build.VERSION_CODES.P;
 import static android.os.Flags.FLAG_ALLOW_PRIVATE_PROFILE;
@@ -186,6 +188,7 @@
 import android.graphics.Bitmap;
 import android.graphics.Color;
 import android.graphics.drawable.Icon;
+import android.media.AudioAttributes;
 import android.media.AudioManager;
 import android.media.session.MediaSession;
 import android.net.Uri;
@@ -14858,6 +14861,33 @@
         assertThat(posted.getRankingTimeMs()).isEqualTo(posted.getSbn().getPostTime());
     }
 
+    @Test
+    @EnableFlags(android.app.Flags.FLAG_RESTRICT_AUDIO_ATTRIBUTES_MEDIA)
+    public void testRestrictAudioAttributes_listenersGetCorrectAttributes() throws Exception {
+        NotificationChannel sound = new NotificationChannel("a", "a", IMPORTANCE_DEFAULT);
+        sound.setSound(Uri.EMPTY, new AudioAttributes.Builder().setUsage(USAGE_MEDIA).build());
+        mBinderService.createNotificationChannels(mPkg, new ParceledListSlice(
+                Arrays.asList(sound)));
+
+        Notification n = new Notification.Builder(mContext, "a")
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .build();
+        StatusBarNotification sbn = new StatusBarNotification(mPkg, mPkg, 9, null, mUid, 0,
+                n, UserHandle.getUserHandleForUid(mUid), null, 0);
+
+        mBinderService.enqueueNotificationWithTag(mPkg, mPkg, sbn.getTag(),
+                sbn.getId(), sbn.getNotification(), sbn.getUserId());
+        waitForIdle();
+
+        ArgumentCaptor<NotificationRecord> captor =
+                ArgumentCaptor.forClass(NotificationRecord.class);
+        verify(mListeners, times(1)).prepareNotifyPostedLocked(
+                captor.capture(), any(), anyBoolean());
+
+        assertThat(captor.getValue().getChannel().getAudioAttributes().getUsage())
+                .isEqualTo(USAGE_NOTIFICATION);
+    }
+
     private NotificationRecord createAndPostCallStyleNotification(String packageName,
             UserHandle userHandle, String testName) throws Exception {
         Person person = new Person.Builder().setName("caller").build();