Merge "MediaRouter: make setControlCategories synchronous"
diff --git a/media/java/android/media/MediaRouter2.java b/media/java/android/media/MediaRouter2.java
index 3e6f4c0..bad0ef4 100644
--- a/media/java/android/media/MediaRouter2.java
+++ b/media/java/android/media/MediaRouter2.java
@@ -23,6 +23,7 @@
 import android.annotation.CallbackExecutor;
 import android.annotation.IntDef;
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.content.Context;
 import android.content.Intent;
 import android.os.Bundle;
@@ -51,7 +52,6 @@
  * @hide
  */
 public class MediaRouter2 {
-
     /** @hide */
     @Retention(SOURCE)
     @IntDef(value = {
@@ -102,13 +102,11 @@
             new CopyOnWriteArrayList<>();
 
     private final String mPackageName;
+    @GuardedBy("sLock")
     private final Map<String, MediaRoute2Info> mRoutes = new HashMap<>();
 
-    //TODO: Use a lock for this to cover the below use case
-    // mRouter.setControlCategories(...);
-    // routes = mRouter.getRoutes();
-    // The current implementation returns empty list
-    private volatile List<String> mControlCategories = Collections.emptyList();
+    @GuardedBy("sLock")
+    private List<String> mControlCategories = Collections.emptyList();
 
     private MediaRoute2Info mSelectedRoute;
     @GuardedBy("sLock")
@@ -117,7 +115,9 @@
     private Client2 mClient;
 
     final Handler mHandler;
-    volatile List<MediaRoute2Info> mFilteredRoutes = Collections.emptyList();
+    @GuardedBy("sLock")
+    private boolean mShouldUpdateRoutes;
+    private volatile List<MediaRoute2Info> mFilteredRoutes = Collections.emptyList();
 
     /**
      * Gets an instance of the media router associated with the context.
@@ -171,8 +171,7 @@
     /**
      * Registers a callback to discover routes and to receive events when they change.
      * <p>
-     * If you register the same callback twice or more, the previous arguments will be overwritten
-     * with the new arguments.
+     * If you register the same callback twice or more, the duplicate registrations are ignored.
      * </p>
      */
     public void registerCallback(@NonNull @CallbackExecutor Executor executor,
@@ -180,18 +179,10 @@
         Objects.requireNonNull(executor, "executor must not be null");
         Objects.requireNonNull(callback, "callback must not be null");
 
-        CallbackRecord record;
-        // This is required to prevent adding the same callback twice.
-        synchronized (mCallbackRecords) {
-            final int index = findCallbackRecordIndexLocked(callback);
-            if (index < 0) {
-                record = new CallbackRecord(callback);
-                mCallbackRecords.add(record);
-            } else {
-                record = mCallbackRecords.get(index);
-            }
-            record.mExecutor = executor;
-            record.mFlags = flags;
+        CallbackRecord record = new CallbackRecord(callback, executor, flags);
+        if (!mCallbackRecords.addIfAbsent(record)) {
+            Log.w(TAG, "Ignoring the same callback");
+            return;
         }
 
         synchronized (sLock) {
@@ -206,8 +197,6 @@
                 }
             }
         }
-        //TODO: Is it thread-safe?
-        record.notifyRoutes();
 
         //TODO: Update discovery request here.
     }
@@ -222,23 +211,20 @@
     public void unregisterCallback(@NonNull Callback callback) {
         Objects.requireNonNull(callback, "callback must not be null");
 
-        synchronized (mCallbackRecords) {
-            final int index = findCallbackRecordIndexLocked(callback);
-            if (index < 0) {
-                Log.w(TAG, "Ignoring to remove unknown callback. " + callback);
-                return;
-            }
-            mCallbackRecords.remove(index);
-            synchronized (sLock) {
-                if (mCallbackRecords.size() == 0 && mClient != null) {
-                    try {
-                        mMediaRouterService.unregisterClient2(mClient);
-                    } catch (RemoteException ex) {
-                        Log.e(TAG, "Unable to unregister media router.", ex);
-                    }
-                    //TODO: Clean up mRoutes. (onHandler?)
-                    mClient = null;
+        if (!mCallbackRecords.remove(new CallbackRecord(callback, null, 0))) {
+            Log.w(TAG, "Ignoring unknown callback");
+            return;
+        }
+
+        synchronized (sLock) {
+            if (mCallbackRecords.size() == 0 && mClient != null) {
+                try {
+                    mMediaRouterService.unregisterClient2(mClient);
+                } catch (RemoteException ex) {
+                    Log.e(TAG, "Unable to unregister media router.", ex);
                 }
+                //TODO: Clean up mRoutes. (onHandler?)
+                mClient = null;
             }
         }
     }
@@ -246,26 +232,52 @@
     //TODO(b/139033746): Rename "Control Category" when it's finalized.
     /**
      * Sets the control categories of the application.
-     * Routes that support at least one of the given control categories only exists and are handled
+     * Routes that support at least one of the given control categories are handled
      * by the media router.
      */
     public void setControlCategories(@NonNull Collection<String> controlCategories) {
         Objects.requireNonNull(controlCategories, "control categories must not be null");
 
-        // To ensure invoking callbacks correctly according to control categories
-        mHandler.sendMessage(obtainMessage(MediaRouter2::setControlCategoriesOnHandler,
-                MediaRouter2.this, new ArrayList<>(controlCategories)));
+        List<String> newControlCategories = new ArrayList<>(controlCategories);
+
+        synchronized (sLock) {
+            mShouldUpdateRoutes = true;
+
+            // invoke callbacks due to control categories change
+            handleControlCategoriesChangedLocked(newControlCategories);
+            if (mClient != null) {
+                try {
+                    mMediaRouterService.setControlCategories(mClient, mControlCategories);
+                } catch (RemoteException ex) {
+                    Log.e(TAG, "Unable to set control categories.", ex);
+                }
+            }
+        }
     }
 
     /**
      * Gets the unmodifiable list of {@link MediaRoute2Info routes} currently
      * known to the media router.
+     * The returned list may reflect changes whose callbacks have not been invoked yet.
      *
      * @return the list of routes that support at least one of the control categories set by
      * the application
      */
     @NonNull
     public List<MediaRoute2Info> getRoutes() {
+        synchronized (sLock) {
+            if (mShouldUpdateRoutes) {
+                mShouldUpdateRoutes = false;
+
+                List<MediaRoute2Info> filteredRoutes = new ArrayList<>();
+                for (MediaRoute2Info route : mRoutes.values()) {
+                    if (route.supportsControlCategory(mControlCategories)) {
+                        filteredRoutes.add(route);
+                    }
+                }
+                mFilteredRoutes = Collections.unmodifiableList(filteredRoutes);
+            }
+        }
         return mFilteredRoutes;
     }
 
@@ -379,43 +391,16 @@
         }
     }
 
-    @GuardedBy("mCallbackRecords")
-    private int findCallbackRecordIndexLocked(Callback callback) {
-        final int count = mCallbackRecords.size();
-        for (int i = 0; i < count; i++) {
-            CallbackRecord callbackRecord = mCallbackRecords.get(i);
-            if (callbackRecord.mCallback == callback) {
-                return i;
-            }
-        }
-        return -1;
-    }
-
-    private void setControlCategoriesOnHandler(List<String> newControlCategories) {
-        List<String> prevControlCategories = mControlCategories;
+    private void handleControlCategoriesChangedLocked(List<String> newControlCategories) {
         List<MediaRoute2Info> addedRoutes = new ArrayList<>();
         List<MediaRoute2Info> removedRoutes = new ArrayList<>();
-        List<MediaRoute2Info> filteredRoutes = new ArrayList<>();
 
+        List<String> prevControlCategories = mControlCategories;
         mControlCategories = newControlCategories;
-        Client2 client;
-        synchronized (sLock) {
-            client = mClient;
-        }
-        if (client != null) {
-            try {
-                mMediaRouterService.setControlCategories(client, mControlCategories);
-            } catch (RemoteException ex) {
-                Log.e(TAG, "Unable to set control categories.", ex);
-            }
-        }
 
         for (MediaRoute2Info route : mRoutes.values()) {
             boolean preSupported = route.supportsControlCategory(prevControlCategories);
             boolean postSupported = route.supportsControlCategory(newControlCategories);
-            if (postSupported) {
-                filteredRoutes.add(route);
-            }
             if (preSupported == postSupported) {
                 continue;
             }
@@ -425,13 +410,14 @@
                 addedRoutes.add(route);
             }
         }
-        mFilteredRoutes = Collections.unmodifiableList(filteredRoutes);
 
         if (removedRoutes.size() > 0) {
-            notifyRoutesRemoved(removedRoutes);
+            mHandler.sendMessage(obtainMessage(MediaRouter2::notifyRoutesRemoved,
+                    MediaRouter2.this, removedRoutes));
         }
         if (addedRoutes.size() > 0) {
-            notifyRoutesAdded(addedRoutes);
+            mHandler.sendMessage(obtainMessage(MediaRouter2::notifyRoutesAdded,
+                    MediaRouter2.this, addedRoutes));
         }
     }
 
@@ -441,42 +427,49 @@
         //  2) Call onRouteSelected(system_route, reason_fallback) if previously selected route
         //     does not exist anymore. => We may need 'boolean MediaRoute2Info#isSystemRoute()'.
         List<MediaRoute2Info> addedRoutes = new ArrayList<>();
-        for (MediaRoute2Info route : routes) {
-            mRoutes.put(route.getUniqueId(), route);
-            if (route.supportsControlCategory(mControlCategories)) {
-                addedRoutes.add(route);
+        synchronized (sLock) {
+            for (MediaRoute2Info route : routes) {
+                mRoutes.put(route.getUniqueId(), route);
+                if (route.supportsControlCategory(mControlCategories)) {
+                    addedRoutes.add(route);
+                }
             }
+            mShouldUpdateRoutes = true;
         }
         if (addedRoutes.size() > 0) {
-            refreshFilteredRoutes();
             notifyRoutesAdded(addedRoutes);
         }
     }
 
     void removeRoutesOnHandler(List<MediaRoute2Info> routes) {
         List<MediaRoute2Info> removedRoutes = new ArrayList<>();
-        for (MediaRoute2Info route : routes) {
-            mRoutes.remove(route.getUniqueId());
-            if (route.supportsControlCategory(mControlCategories)) {
-                removedRoutes.add(route);
+        synchronized (sLock) {
+            for (MediaRoute2Info route : routes) {
+                mRoutes.remove(route.getUniqueId());
+                if (route.supportsControlCategory(mControlCategories)) {
+                    removedRoutes.add(route);
+                }
             }
+            mShouldUpdateRoutes = true;
         }
         if (removedRoutes.size() > 0) {
-            refreshFilteredRoutes();
             notifyRoutesRemoved(removedRoutes);
         }
     }
 
     void changeRoutesOnHandler(List<MediaRoute2Info> routes) {
         List<MediaRoute2Info> changedRoutes = new ArrayList<>();
-        for (MediaRoute2Info route : routes) {
-            mRoutes.put(route.getUniqueId(), route);
-            if (route.supportsControlCategory(mControlCategories)) {
-                changedRoutes.add(route);
+        synchronized (sLock) {
+            for (MediaRoute2Info route : routes) {
+                mRoutes.put(route.getUniqueId(), route);
+                if (route.supportsControlCategory(mControlCategories)) {
+                    changedRoutes.add(route);
+                }
             }
+            // mRoutes now holds new route objects; invalidate the list cached by getRoutes().
+            mShouldUpdateRoutes = true;
         }
         if (changedRoutes.size() > 0) {
-            refreshFilteredRoutes();
             notifyRoutesChanged(changedRoutes);
         }
     }
@@ -500,17 +491,6 @@
         notifyRouteSelected(route, reason, controlHints);
     }
 
-    private void refreshFilteredRoutes() {
-        List<MediaRoute2Info> filteredRoutes = new ArrayList<>();
-
-        for (MediaRoute2Info route : mRoutes.values()) {
-            if (route.supportsControlCategory(mControlCategories)) {
-                filteredRoutes.add(route);
-            }
-        }
-        mFilteredRoutes = Collections.unmodifiableList(filteredRoutes);
-    }
-
     private void notifyRoutesAdded(List<MediaRoute2Info> routes) {
         for (CallbackRecord record: mCallbackRecords) {
             record.mExecutor.execute(
@@ -544,13 +524,16 @@
      */
     public static class Callback {
         /**
-         * Called when routes are added.
+         * Called when routes are added. Whenever you register a callback, this will
+         * be invoked with known routes.
+         *
          * @param routes the list of routes that have been added. It's never empty.
          */
         public void onRoutesAdded(@NonNull List<MediaRoute2Info> routes) {}
 
         /**
          * Called when routes are removed.
+         *
          * @param routes the list of routes that have been removed. It's never empty.
          */
         public void onRoutesRemoved(@NonNull List<MediaRoute2Info> routes) {}
@@ -569,6 +552,7 @@
 
         /**
          * Called when a route is selected. Exactly one route can be selected at a time.
+         *
          * @param route the selected route.
          * @param reason the reason why the route is selected.
          * @param controlHints An optional bundle of provider-specific arguments which may be
@@ -587,16 +571,26 @@
         public Executor mExecutor;
         public int mFlags;
 
-        CallbackRecord(@NonNull Callback callback) {
+        CallbackRecord(@NonNull Callback callback, @Nullable Executor executor, int flags) {
             mCallback = callback;
+            mExecutor = executor;
+            mFlags = flags;
         }
 
-        void notifyRoutes() {
-            final List<MediaRoute2Info> routes = mFilteredRoutes;
-            // notify only when bound to media router service.
-            if (routes.size() > 0) {
-                mExecutor.execute(() -> mCallback.onRoutesAdded(routes));
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj) {
+                return true;
             }
+            if (!(obj instanceof CallbackRecord)) {
+                return false;
+            }
+            return mCallback == ((CallbackRecord) obj).mCallback;
+        }
+
+        @Override
+        public int hashCode() {
+            return mCallback.hashCode();
         }
     }
 
diff --git a/media/java/android/media/MediaRouter2Manager.java b/media/java/android/media/MediaRouter2Manager.java
index d56dd11..502538d 100644
--- a/media/java/android/media/MediaRouter2Manager.java
+++ b/media/java/android/media/MediaRouter2Manager.java
@@ -57,7 +57,7 @@
     private Client mClient;
     private final IMediaRouterService mMediaRouterService;
     final Handler mHandler;
-    final List<CallbackRecord> mCallbackRecords = new CopyOnWriteArrayList<>();
+    final CopyOnWriteArrayList<CallbackRecord> mCallbackRecords = new CopyOnWriteArrayList<>();
 
     private final Object mRoutesLock = new Object();
     @GuardedBy("mRoutesLock")
@@ -99,14 +99,10 @@
         Objects.requireNonNull(executor, "executor must not be null");
         Objects.requireNonNull(callback, "callback must not be null");
 
-        CallbackRecord callbackRecord;
-        synchronized (mCallbackRecords) {
-            if (findCallbackRecordIndexLocked(callback) >= 0) {
-                Log.w(TAG, "Ignoring to add the same callback twice.");
-                return;
-            }
-            callbackRecord = new CallbackRecord(executor, callback);
-            mCallbackRecords.add(callbackRecord);
+        CallbackRecord callbackRecord = new CallbackRecord(executor, callback);
+        if (!mCallbackRecords.addIfAbsent(callbackRecord)) {
+            Log.w(TAG, "Ignoring to add the same callback twice.");
+            return;
         }
 
         synchronized (sLock) {
@@ -118,8 +114,6 @@
                 } catch (RemoteException ex) {
                     Log.e(TAG, "Unable to register media router manager.", ex);
                 }
-            } else {
-                callbackRecord.notifyRoutes();
             }
         }
     }
@@ -132,36 +126,23 @@
     public void unregisterCallback(@NonNull Callback callback) {
         Objects.requireNonNull(callback, "callback must not be null");
 
-        synchronized (mCallbackRecords) {
-            final int index = findCallbackRecordIndexLocked(callback);
-            if (index < 0) {
-                Log.w(TAG, "Ignore removing unknown callback. " + callback);
-                return;
-            }
-            mCallbackRecords.remove(index);
-            synchronized (sLock) {
-                if (mCallbackRecords.size() == 0 && mClient != null) {
-                    try {
-                        mMediaRouterService.unregisterManager(mClient);
-                    } catch (RemoteException ex) {
-                        Log.e(TAG, "Unable to unregister media router manager", ex);
-                    }
-                    //TODO: clear mRoutes?
-                    mClient = null;
-                }
-            }
+        if (!mCallbackRecords.remove(new CallbackRecord(null, callback))) {
+            Log.w(TAG, "Ignore removing unknown callback. " + callback);
+            return;
         }
-    }
 
-    @GuardedBy("mCallbackRecords")
-    private int findCallbackRecordIndexLocked(Callback callback) {
-        final int count = mCallbackRecords.size();
-        for (int i = 0; i < count; i++) {
-            if (mCallbackRecords.get(i).mCallback == callback) {
-                return i;
+        synchronized (sLock) {
+            if (mCallbackRecords.size() == 0 && mClient != null) {
+                try {
+                    mMediaRouterService.unregisterManager(mClient);
+                } catch (RemoteException ex) {
+                    Log.e(TAG, "Unable to unregister media router manager", ex);
+                }
+                //TODO: clear mRoutes?
+                mClient = null;
+                mControlCategoryMap.clear();
             }
         }
-        return -1;
     }
 
     //TODO: Use cache not to create array. For now, it's unclear when to purge the cache.
@@ -187,7 +168,6 @@
                 }
             }
         }
-        //TODO: Should we cache this?
         return routes;
     }
 
@@ -342,10 +322,14 @@
     }
 
     void updateControlCategories(String packageName, List<String> categories) {
-        mControlCategoryMap.put(packageName, categories);
+        List<String> prevCategories = mControlCategoryMap.put(packageName, categories);
+        if ((prevCategories == null && categories.size() == 0)
+                || Objects.equals(categories, prevCategories)) {
+            return;
+        }
         for (CallbackRecord record : mCallbackRecords) {
             record.mExecutor.execute(
-                    () -> record.mCallback.onControlCategoriesChanged(packageName));
+                    () -> record.mCallback.onControlCategoriesChanged(packageName, categories));
         }
     }
 
@@ -386,8 +370,10 @@
          * Called when the control categories of an app is changed.
          *
          * @param packageName the package name of the application
+         * @param controlCategories the list of control categories set by an application.
          */
-        public void onControlCategoriesChanged(@NonNull String packageName) {}
+        public void onControlCategoriesChanged(@NonNull String packageName,
+                @NonNull List<String> controlCategories) {}
     }
 
     final class CallbackRecord {
@@ -399,14 +385,20 @@
             mCallback = callback;
         }
 
-        void notifyRoutes() {
-            List<MediaRoute2Info> routes;
-            synchronized (mRoutesLock) {
-                routes = new ArrayList<>(mRoutes.values());
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj) {
+                return true;
             }
-            if (routes.size() > 0) {
-                mExecutor.execute(() -> mCallback.onRoutesAdded(routes));
+            if (!(obj instanceof CallbackRecord)) {
+                return false;
             }
+            return mCallback ==  ((CallbackRecord) obj).mCallback;
+        }
+
+        @Override
+        public int hashCode() {
+            return mCallback.hashCode();
         }
     }
 
diff --git a/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2Test.java b/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2Test.java
index 2c60d6b..3266285 100644
--- a/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2Test.java
+++ b/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2Test.java
@@ -16,7 +16,15 @@
 
 package com.android.mediaroutertest;
 
+import static com.android.mediaroutertest.MediaRouterManagerTest.CATEGORIES_ALL;
+import static com.android.mediaroutertest.MediaRouterManagerTest.CATEGORIES_SPECIAL;
+import static com.android.mediaroutertest.MediaRouterManagerTest.ROUTE_ID_SPECIAL_CATEGORY;
+import static com.android.mediaroutertest.MediaRouterManagerTest.ROUTE_ID_VARIABLE_VOLUME;
+import static com.android.mediaroutertest.MediaRouterManagerTest.SYSTEM_PROVIDER_ID;
+
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 
 import android.content.Context;
 import android.media.MediaRoute2Info;
@@ -24,20 +32,37 @@
 import android.support.test.InstrumentationRegistry;
 import android.support.test.filters.SmallTest;
 import android.support.test.runner.AndroidJUnit4;
+import android.text.TextUtils;
 
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Predicate;
+
 @RunWith(AndroidJUnit4.class)
 @SmallTest
 public class MediaRouter2Test {
+    private static final String TAG = "MediaRouter2Test";
     Context mContext;
+    private MediaRouter2 mRouter2;
+    private Executor mExecutor;
+
+    private static final int TIMEOUT_MS = 5000;
 
     @Before
     public void setUp() throws Exception {
         mContext = InstrumentationRegistry.getTargetContext();
+        mRouter2 = MediaRouter2.getInstance(mContext);
+        mExecutor = Executors.newSingleThreadExecutor();
     }
 
     @After
@@ -50,4 +75,95 @@
         MediaRoute2Info initiallySelectedRoute = router.getSelectedRoute();
         assertNotNull(initiallySelectedRoute);
     }
+
+    /**
+     * Tests if we get proper routes for application that has special control category.
+     */
+    @Test
+    public void testGetRoutes() throws Exception {
+        Map<String, MediaRoute2Info> routes = waitAndGetRoutes(CATEGORIES_SPECIAL);
+
+        assertEquals(1, routes.size());
+        assertNotNull(routes.get(ROUTE_ID_SPECIAL_CATEGORY));
+    }
+
+    @Test
+    public void testControlVolumeWithRouter() throws Exception {
+        Map<String, MediaRoute2Info> routes = waitAndGetRoutes(CATEGORIES_ALL);
+
+        MediaRoute2Info volRoute = routes.get(ROUTE_ID_VARIABLE_VOLUME);
+        assertNotNull(volRoute);
+
+        int originalVolume = volRoute.getVolume();
+        int deltaVolume = (originalVolume == volRoute.getVolumeMax() ? -1 : 1);
+
+        awaitOnRouteChanged(
+                () -> mRouter2.requestUpdateVolume(volRoute, deltaVolume),
+                ROUTE_ID_VARIABLE_VOLUME,
+                (route -> route.getVolume() == originalVolume + deltaVolume));
+
+        awaitOnRouteChanged(
+                () -> mRouter2.requestSetVolume(volRoute, originalVolume),
+                ROUTE_ID_VARIABLE_VOLUME,
+                (route -> route.getVolume() == originalVolume));
+    }
+
+
+    // Helper for getting routes easily
+    static Map<String, MediaRoute2Info> createRouteMap(List<MediaRoute2Info> routes) {
+        Map<String, MediaRoute2Info> routeMap = new HashMap<>();
+        for (MediaRoute2Info route : routes) {
+            // intentionally not using route.getUniqueId() for convenience.
+            routeMap.put(route.getId(), route);
+        }
+        return routeMap;
+    }
+
+    Map<String, MediaRoute2Info> waitAndGetRoutes(List<String> controlCategories)
+            throws Exception {
+        CountDownLatch latch = new CountDownLatch(1);
+
+        // A dummy callback is required to send control category info.
+        MediaRouter2.Callback routerCallback = new MediaRouter2.Callback() {
+            @Override
+            public void onRoutesAdded(List<MediaRoute2Info> routes) {
+                for (int i = 0; i < routes.size(); i++) {
+                    //TODO: use isSystem() or similar method when it's ready
+                    if (!TextUtils.equals(routes.get(i).getProviderId(), SYSTEM_PROVIDER_ID)) {
+                        latch.countDown();
+                    }
+                }
+            }
+        };
+
+        mRouter2.setControlCategories(controlCategories);
+        mRouter2.registerCallback(mExecutor, routerCallback);
+        try {
+            latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS);
+            return createRouteMap(mRouter2.getRoutes());
+        } finally {
+            mRouter2.unregisterCallback(routerCallback);
+        }
+    }
+
+    void awaitOnRouteChanged(Runnable task, String routeId,
+            Predicate<MediaRoute2Info> predicate) throws Exception {
+        CountDownLatch latch = new CountDownLatch(1);
+        MediaRouter2.Callback callback = new MediaRouter2.Callback() {
+            @Override
+            public void onRoutesChanged(List<MediaRoute2Info> changed) {
+                MediaRoute2Info route = createRouteMap(changed).get(routeId);
+                if (route != null && predicate.test(route)) {
+                    latch.countDown();
+                }
+            }
+        };
+        mRouter2.registerCallback(mExecutor, callback);
+        try {
+            task.run();
+            assertTrue(latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
+        } finally {
+            mRouter2.unregisterCallback(callback);
+        }
+    }
 }
diff --git a/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouterManagerTest.java b/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouterManagerTest.java
index c70ad8d..b380aff 100644
--- a/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouterManagerTest.java
+++ b/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouterManagerTest.java
@@ -64,6 +64,9 @@
     public static final String ROUTE_ID_SPECIAL_CATEGORY = "route_special_category";
     public static final String ROUTE_NAME_SPECIAL_CATEGORY = "Special Category Route";
 
+    public static final String SYSTEM_PROVIDER_ID =
+            "com.android.server.media/.SystemMediaRoute2Provider";
+
     public static final int VOLUME_MAX = 100;
     public static final String ROUTE_ID_FIXED_VOLUME = "route_fixed_volume";
     public static final String ROUTE_NAME_FIXED_VOLUME = "Fixed Volume Route";
@@ -78,10 +81,7 @@
     public static final String CATEGORY_SPECIAL =
             "com.android.mediarouteprovider.CATEGORY_SPECIAL";
 
-    // system routes
-    private static final String DEFAULT_ROUTE_ID = "DEFAULT_ROUTE";
     private static final String CATEGORY_LIVE_AUDIO = "android.media.intent.category.LIVE_AUDIO";
-    private static final String CATEGORY_LIVE_VIDEO = "android.media.intent.category.LIVE_VIDEO";
 
     private static final int TIMEOUT_MS = 5000;
 
@@ -93,10 +93,9 @@
 
     private final List<MediaRouter2Manager.Callback> mManagerCallbacks = new ArrayList<>();
     private final List<MediaRouter2.Callback> mRouterCallbacks = new ArrayList<>();
-    private Map<String, MediaRoute2Info> mRoutes;
 
-    private static final List<String> CATEGORIES_ALL = new ArrayList();
-    private static final List<String> CATEGORIES_SPECIAL = new ArrayList();
+    public static final List<String> CATEGORIES_ALL = new ArrayList();
+    public static final List<String> CATEGORIES_SPECIAL = new ArrayList();
     private static final List<String> CATEGORIES_LIVE_AUDIO = new ArrayList<>();
 
     static {
@@ -109,7 +108,6 @@
         CATEGORIES_LIVE_AUDIO.add(CATEGORY_LIVE_AUDIO);
     }
 
-
     @Before
     public void setUp() throws Exception {
         mContext = InstrumentationRegistry.getTargetContext();
@@ -118,10 +116,6 @@
         //TODO: If we need to support thread pool executors, change this to thread pool executor.
         mExecutor = Executors.newSingleThreadExecutor();
         mPackageName = mContext.getPackageName();
-
-        // ensure media router 2 client
-        addRouterCallback(new MediaRouter2.Callback());
-        mRoutes = waitAndGetRoutesWithManager(CATEGORIES_ALL);
     }
 
     @After
@@ -168,6 +162,9 @@
     @Test
     public void testOnRoutesRemoved() throws Exception {
         CountDownLatch latch = new CountDownLatch(1);
+        Map<String, MediaRoute2Info> routes = waitAndGetRoutesWithManager(CATEGORIES_ALL);
+
+        addRouterCallback(new MediaRouter2.Callback());
         addManagerCallback(new MediaRouter2Manager.Callback() {
             @Override
             public void onRoutesRemoved(List<MediaRoute2Info> routes) {
@@ -182,7 +179,7 @@
 
         //TODO: Figure out a more proper way to test.
         // (Control requests shouldn't be used in this way.)
-        mRouter2.sendControlRequest(mRoutes.get(ROUTE_ID2), new Intent(ACTION_REMOVE_ROUTE));
+        mRouter2.sendControlRequest(routes.get(ROUTE_ID2), new Intent(ACTION_REMOVE_ROUTE));
         assertTrue(latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
     }
 
@@ -198,23 +195,15 @@
     }
 
     /**
-     * Tests if we get proper routes for application that has special control category.
-     */
-    @Test
-    public void testGetRoutes() throws Exception {
-        Map<String, MediaRoute2Info> routes = waitAndGetRoutes(CATEGORIES_SPECIAL);
-
-        assertEquals(1, routes.size());
-        assertNotNull(routes.get(ROUTE_ID_SPECIAL_CATEGORY));
-    }
-
-    /**
      * Tests if MR2.Callback.onRouteSelected is called when a route is selected from MR2Manager.
      */
     @Test
     public void testRouterOnRouteSelected() throws Exception {
+        Map<String, MediaRoute2Info> routes = waitAndGetRoutesWithManager(CATEGORIES_ALL);
+
         CountDownLatch latch = new CountDownLatch(1);
 
+        addManagerCallback(new MediaRouter2Manager.Callback());
         addRouterCallback(new MediaRouter2.Callback() {
             @Override
             public void onRouteSelected(MediaRoute2Info route, int reason, Bundle controlHints) {
@@ -224,12 +213,16 @@
             }
         });
 
-        MediaRoute2Info routeToSelect = mRoutes.get(ROUTE_ID1);
+        MediaRoute2Info routeToSelect = routes.get(ROUTE_ID1);
         assertNotNull(routeToSelect);
 
-        mManager.selectRoute(mPackageName, routeToSelect);
+        try {
+            mManager.selectRoute(mPackageName, routeToSelect);
 
-        assertTrue(latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
+            assertTrue(latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
+        } finally {
+            mManager.unselectRoute(mPackageName);
+        }
     }
 
     /**
@@ -239,7 +232,9 @@
     @Test
     public void testManagerOnRouteSelected() throws Exception {
         CountDownLatch latch = new CountDownLatch(1);
+        Map<String, MediaRoute2Info> routes = waitAndGetRoutesWithManager(CATEGORIES_ALL);
 
+        addRouterCallback(new MediaRouter2.Callback());
         addManagerCallback(new MediaRouter2Manager.Callback() {
             @Override
             public void onRouteSelected(String packageName, MediaRoute2Info route) {
@@ -250,12 +245,15 @@
             }
         });
 
-        MediaRoute2Info routeToSelect = mRoutes.get(ROUTE_ID1);
+        MediaRoute2Info routeToSelect = routes.get(ROUTE_ID1);
         assertNotNull(routeToSelect);
 
-        mManager.selectRoute(mPackageName, routeToSelect);
-
-        assertTrue(latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
+        try {
+            mManager.selectRoute(mPackageName, routeToSelect);
+            assertTrue(latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
+        } finally {
+            mManager.unselectRoute(mPackageName);
+        }
     }
 
     /**
@@ -263,13 +261,16 @@
      */
     @Test
     public void testSingleProviderSelect() throws Exception {
+        Map<String, MediaRoute2Info> routes = waitAndGetRoutesWithManager(CATEGORIES_ALL);
+        addRouterCallback(new MediaRouter2.Callback());
+
         awaitOnRouteChangedManager(
-                () -> mManager.selectRoute(mPackageName, mRoutes.get(ROUTE_ID1)),
+                () -> mManager.selectRoute(mPackageName, routes.get(ROUTE_ID1)),
                 ROUTE_ID1,
                 route -> TextUtils.equals(route.getClientPackageName(), mPackageName));
 
         awaitOnRouteChangedManager(
-                () -> mManager.selectRoute(mPackageName, mRoutes.get(ROUTE_ID2)),
+                () -> mManager.selectRoute(mPackageName, routes.get(ROUTE_ID2)),
                 ROUTE_ID2,
                 route -> TextUtils.equals(route.getClientPackageName(), mPackageName));
 
@@ -280,27 +281,10 @@
     }
 
     @Test
-    public void testControlVolumeWithRouter() throws Exception {
-        Map<String, MediaRoute2Info> routes = waitAndGetRoutes(CATEGORIES_ALL);
-
-        MediaRoute2Info volRoute = routes.get(ROUTE_ID_VARIABLE_VOLUME);
-        int originalVolume = volRoute.getVolume();
-        int deltaVolume = (originalVolume == volRoute.getVolumeMax() ? -1 : 1);
-
-        awaitOnRouteChanged(
-                () -> mRouter2.requestUpdateVolume(volRoute, deltaVolume),
-                ROUTE_ID_VARIABLE_VOLUME,
-                (route -> route.getVolume() == originalVolume + deltaVolume));
-
-        awaitOnRouteChanged(
-                () -> mRouter2.requestSetVolume(volRoute, originalVolume),
-                ROUTE_ID_VARIABLE_VOLUME,
-                (route -> route.getVolume() == originalVolume));
-    }
-
-    @Test
     public void testControlVolumeWithManager() throws Exception {
-        MediaRoute2Info volRoute = mRoutes.get(ROUTE_ID_VARIABLE_VOLUME);
+        Map<String, MediaRoute2Info> routes = waitAndGetRoutesWithManager(CATEGORIES_ALL);
+        MediaRoute2Info volRoute = routes.get(ROUTE_ID_VARIABLE_VOLUME);
+
         int originalVolume = volRoute.getVolume();
         int deltaVolume = (originalVolume == volRoute.getVolumeMax() ? -1 : 1);
 
@@ -317,39 +301,16 @@
 
     @Test
     public void testVolumeHandling() throws Exception {
-        MediaRoute2Info fixedVolumeRoute = mRoutes.get(ROUTE_ID_FIXED_VOLUME);
-        MediaRoute2Info variableVolumeRoute = mRoutes.get(ROUTE_ID_VARIABLE_VOLUME);
+        Map<String, MediaRoute2Info> routes = waitAndGetRoutesWithManager(CATEGORIES_ALL);
+
+        MediaRoute2Info fixedVolumeRoute = routes.get(ROUTE_ID_FIXED_VOLUME);
+        MediaRoute2Info variableVolumeRoute = routes.get(ROUTE_ID_VARIABLE_VOLUME);
 
         assertEquals(PLAYBACK_VOLUME_FIXED, fixedVolumeRoute.getVolumeHandling());
         assertEquals(PLAYBACK_VOLUME_VARIABLE, variableVolumeRoute.getVolumeHandling());
         assertEquals(VOLUME_MAX, variableVolumeRoute.getVolumeMax());
     }
 
-    @Test
-    public void testDefaultRoute() throws Exception {
-        Map<String, MediaRoute2Info> routes = waitAndGetRoutes(CATEGORIES_LIVE_AUDIO);
-
-        assertNotNull(routes.get(DEFAULT_ROUTE_ID));
-    }
-
-    Map<String, MediaRoute2Info> waitAndGetRoutes(List<String> controlCategories) throws Exception {
-        CountDownLatch latch = new CountDownLatch(1);
-        MediaRouter2.Callback callback = new MediaRouter2.Callback() {
-            @Override
-            public void onRoutesAdded(List<MediaRoute2Info> added) {
-                if (added.size() > 0) latch.countDown();
-            }
-        };
-        mRouter2.setControlCategories(controlCategories);
-        mRouter2.registerCallback(mExecutor, callback);
-        try {
-            assertTrue(latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
-            return createRouteMap(mRouter2.getRoutes());
-        } finally {
-            mRouter2.unregisterCallback(callback);
-        }
-    }
-
     Map<String, MediaRoute2Info> waitAndGetRoutesWithManager(List<String> controlCategories)
             throws Exception {
         CountDownLatch latch = new CountDownLatch(2);
@@ -359,13 +320,17 @@
         MediaRouter2Manager.Callback managerCallback = new MediaRouter2Manager.Callback() {
             @Override
             public void onRoutesAdded(List<MediaRoute2Info> routes) {
-                if (routes.size() > 0) {
-                    latch.countDown();
+                for (int i = 0; i < routes.size(); i++) {
+                    //TODO: use isSystem() or similar method when it's ready
+                    if (!TextUtils.equals(routes.get(i).getProviderId(), SYSTEM_PROVIDER_ID)) {
+                        latch.countDown();
+                        break;
+                    }
                 }
             }
 
             @Override
-            public void onControlCategoriesChanged(String packageName) {
+            public void onControlCategoriesChanged(String packageName, List<String> categories) {
                 if (TextUtils.equals(mPackageName, packageName)) {
                     latch.countDown();
                 }
@@ -375,7 +340,7 @@
         mRouter2.setControlCategories(controlCategories);
         mRouter2.registerCallback(mExecutor, routerCallback);
         try {
-            assertTrue(latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
+            latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS);
             return createRouteMap(mManager.getAvailableRoutes(mPackageName));
         } finally {
             mRouter2.unregisterCallback(routerCallback);
@@ -383,27 +348,6 @@
         }
     }
 
-    void awaitOnRouteChanged(Runnable task, String routeId,
-            Predicate<MediaRoute2Info> predicate) throws Exception {
-        CountDownLatch latch = new CountDownLatch(1);
-        MediaRouter2.Callback callback = new MediaRouter2.Callback() {
-            @Override
-            public void onRoutesChanged(List<MediaRoute2Info> changed) {
-                MediaRoute2Info route = createRouteMap(changed).get(routeId);
-                if (route != null && predicate.test(route)) {
-                    latch.countDown();
-                }
-            }
-        };
-        mRouter2.registerCallback(mExecutor, callback);
-        try {
-            task.run();
-            assertTrue(latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
-        } finally {
-            mRouter2.unregisterCallback(callback);
-        }
-    }
-
     void awaitOnRouteChangedManager(Runnable task, String routeId,
             Predicate<MediaRoute2Info> predicate) throws Exception {
         CountDownLatch latch = new CountDownLatch(1);
diff --git a/services/core/java/com/android/server/media/MediaRouter2ServiceImpl.java b/services/core/java/com/android/server/media/MediaRouter2ServiceImpl.java
index 9fcee50..e7b8860 100644
--- a/services/core/java/com/android/server/media/MediaRouter2ServiceImpl.java
+++ b/services/core/java/com/android/server/media/MediaRouter2ServiceImpl.java
@@ -183,8 +183,9 @@
     }
 
     public void setControlCategories(@NonNull IMediaRouter2Client client,
-            @Nullable List<String> categories) {
+            @NonNull List<String> categories) {
         Objects.requireNonNull(client, "client must not be null");
+        Objects.requireNonNull(categories, "categories must not be null");
 
         final long token = Binder.clearCallingIdentity();
         try {
@@ -390,8 +391,11 @@
 
     private void setControlCategoriesLocked(Client2Record clientRecord, List<String> categories) {
         if (clientRecord != null) {
-            clientRecord.mControlCategories = categories;
+            if (clientRecord.mControlCategories.equals(categories)) {
+                return;
+            }
 
+            clientRecord.mControlCategories = categories;
             clientRecord.mUserRecord.mHandler.sendMessage(
                     obtainMessage(UserHandler::updateClientUsage,
                             clientRecord.mUserRecord.mHandler, clientRecord));