Merge "MediaSession2: Add caller to the callback methods"
diff --git a/packages/MediaComponents/src/com/android/media/MediaBrowser2Impl.java b/packages/MediaComponents/src/com/android/media/MediaBrowser2Impl.java
index 9190dfc..24293ab 100644
--- a/packages/MediaComponents/src/com/android/media/MediaBrowser2Impl.java
+++ b/packages/MediaComponents/src/com/android/media/MediaBrowser2Impl.java
@@ -19,6 +19,7 @@
 import android.content.Context;
 import android.media.MediaBrowser2;
 import android.media.MediaBrowser2.BrowserCallback;
+import android.media.MediaController2;
 import android.media.MediaItem2;
 import android.media.SessionToken2;
 import android.media.update.MediaBrowser2Provider;
@@ -44,6 +45,10 @@
         mCallback = callback;
     }
 
+    @Override MediaBrowser2 getInstance() {
+        return (MediaBrowser2) super.getInstance();
+    }
+
     @Override
     public void getLibraryRoot_impl(Bundle rootHints) {
         final IMediaSession2 binder = getSessionBinder();
@@ -183,39 +188,39 @@
     public void onGetLibraryRootDone(
             final Bundle rootHints, final String rootMediaId, final Bundle rootExtra) {
         getCallbackExecutor().execute(() -> {
-            mCallback.onGetLibraryRootDone(rootHints, rootMediaId, rootExtra);
+            mCallback.onGetLibraryRootDone(getInstance(), rootHints, rootMediaId, rootExtra);
         });
     }
 
     public void onGetItemDone(String mediaId, MediaItem2 item) {
         getCallbackExecutor().execute(() -> {
-            mCallback.onGetItemDone(mediaId, item);
+            mCallback.onGetItemDone(getInstance(), mediaId, item);
         });
     }
 
     public void onGetChildrenDone(String parentId, int page, int pageSize, List<MediaItem2> result,
             Bundle extras) {
         getCallbackExecutor().execute(() -> {
-            mCallback.onGetChildrenDone(parentId, page, pageSize, result, extras);
+            mCallback.onGetChildrenDone(getInstance(), parentId, page, pageSize, result, extras);
         });
     }
 
     public void onSearchResultChanged(String query, int itemCount, Bundle extras) {
         getCallbackExecutor().execute(() -> {
-            mCallback.onSearchResultChanged(query, itemCount, extras);
+            mCallback.onSearchResultChanged(getInstance(), query, itemCount, extras);
         });
     }
 
     public void onGetSearchResultDone(String query, int page, int pageSize, List<MediaItem2> result,
             Bundle extras) {
         getCallbackExecutor().execute(() -> {
-            mCallback.onGetSearchResultDone(query, page, pageSize, result, extras);
+            mCallback.onGetSearchResultDone(getInstance(), query, page, pageSize, result, extras);
         });
     }
 
     public void onChildrenChanged(final String parentId, int itemCount, final Bundle extras) {
         getCallbackExecutor().execute(() -> {
-            mCallback.onChildrenChanged(parentId, itemCount, extras);
+            mCallback.onChildrenChanged(getInstance(), parentId, itemCount, extras);
         });
     }
 }
diff --git a/packages/MediaComponents/src/com/android/media/MediaController2Impl.java b/packages/MediaComponents/src/com/android/media/MediaController2Impl.java
index f6cad09..6214afd 100644
--- a/packages/MediaComponents/src/com/android/media/MediaController2Impl.java
+++ b/packages/MediaComponents/src/com/android/media/MediaController2Impl.java
@@ -226,7 +226,7 @@
             }
         }
         mCallbackExecutor.execute(() -> {
-            mCallback.onDisconnected();
+            mCallback.onDisconnected(mInstance);
         });
     }
 
@@ -623,7 +623,7 @@
             if (!mInstance.isConnected()) {
                 return;
             }
-            mCallback.onPlaybackStateChanged(state);
+            mCallback.onPlaybackStateChanged(mInstance, state);
         });
     }
 
@@ -635,7 +635,7 @@
             if (!mInstance.isConnected()) {
                 return;
             }
-            mCallback.onPlaylistParamsChanged(params);
+            mCallback.onPlaylistParamsChanged(mInstance, params);
         });
     }
 
@@ -647,7 +647,7 @@
             if (!mInstance.isConnected()) {
                 return;
             }
-            mCallback.onPlaybackInfoChanged(info);
+            mCallback.onPlaybackInfoChanged(mInstance, info);
         });
     }
 
@@ -666,7 +666,7 @@
                 if (!mInstance.isConnected()) {
                     return;
                 }
-                mCallback.onPlaylistChanged(playlist);
+                mCallback.onPlaylistChanged(mInstance, playlist);
             });
         }
     }
@@ -721,7 +721,7 @@
                 // Note: We may trigger ControllerCallbacks with the initial values
                 // But it's hard to define the order of the controller callbacks
                 // Only notify about the
-                mCallback.onConnected(allowedCommands);
+                mCallback.onConnected(mInstance, allowedCommands);
             });
         } finally {
             if (close) {
@@ -739,13 +739,13 @@
         }
         mCallbackExecutor.execute(() -> {
             // TODO(jaewan): Double check if the controller exists.
-            mCallback.onCustomCommand(command, args, receiver);
+            mCallback.onCustomCommand(mInstance, command, args, receiver);
         });
     }
 
     void onCustomLayoutChanged(final List<CommandButton> layout) {
         mCallbackExecutor.execute(() -> {
-            mCallback.onCustomLayoutChanged(layout);
+            mCallback.onCustomLayoutChanged(mInstance, layout);
         });
     }
 
diff --git a/packages/MediaComponents/src/com/android/media/MediaSession2Stub.java b/packages/MediaComponents/src/com/android/media/MediaSession2Stub.java
index 0d0fc68..7ff3ae3 100644
--- a/packages/MediaComponents/src/com/android/media/MediaSession2Stub.java
+++ b/packages/MediaComponents/src/com/android/media/MediaSession2Stub.java
@@ -231,7 +231,8 @@
                 // instead of pending them.
                 mConnectingControllers.add(ControllerInfoImpl.from(controllerInfo).getId());
             }
-            CommandGroup allowedCommands = session.getCallback().onConnect(controllerInfo);
+            CommandGroup allowedCommands = session.getCallback().onConnect(
+                    session.getInstance(), controllerInfo);
             // Don't reject connection for the request from trusted app.
             // Otherwise server will fail to retrieve session's information to dispatch
             // media keys to.
@@ -342,7 +343,8 @@
             // TODO(jaewan): Sanity check.
             Command command = new Command(
                     session.getContext(), MediaSession2.COMMAND_CODE_SET_VOLUME);
-            boolean accepted = session.getCallback().onCommandRequest(controller, command);
+            boolean accepted = session.getCallback().onCommandRequest(session.getInstance(),
+                    controller, command);
             if (!accepted) {
                 // Don't run rejected command.
                 if (DEBUG) {
@@ -377,7 +379,8 @@
             // TODO(jaewan): Sanity check.
             Command command = new Command(
                     session.getContext(), MediaSession2.COMMAND_CODE_SET_VOLUME);
-            boolean accepted = session.getCallback().onCommandRequest(controller, command);
+            boolean accepted = session.getCallback().onCommandRequest(session.getInstance(),
+                    controller, command);
             if (!accepted) {
                 // Don't run rejected command.
                 if (DEBUG) {
@@ -416,7 +419,8 @@
             }
             // TODO(jaewan): Sanity check.
             Command command = new Command(session.getContext(), commandCode);
-            boolean accepted = session.getCallback().onCommandRequest(controller, command);
+            boolean accepted = session.getCallback().onCommandRequest(session.getInstance(),
+                    controller, command);
             if (!accepted) {
                 // Don't run rejected command.
                 if (DEBUG) {
@@ -488,7 +492,8 @@
             if (getControllerIfAble(caller, command) == null) {
                 return;
             }
-            session.getCallback().onCustomCommand(controller, command, args, receiver);
+            session.getCallback().onCustomCommand(session.getInstance(),
+                    controller, command, args, receiver);
         });
     }
 
@@ -506,7 +511,8 @@
                     caller, MediaSession2.COMMAND_CODE_PREPARE_FROM_URI) == null) {
                 return;
             }
-            session.getCallback().onPrepareFromUri(controller, uri, extras);
+            session.getCallback().onPrepareFromUri(session.getInstance(),
+                    controller, uri, extras);
         });
     }
 
@@ -524,7 +530,8 @@
                     caller, MediaSession2.COMMAND_CODE_PREPARE_FROM_SEARCH) == null) {
                 return;
             }
-            session.getCallback().onPrepareFromSearch(controller, query, extras);
+            session.getCallback().onPrepareFromSearch(session.getInstance(),
+                    controller, query, extras);
         });
     }
 
@@ -542,7 +549,8 @@
                     caller, MediaSession2.COMMAND_CODE_PREPARE_FROM_MEDIA_ID) == null) {
                 return;
             }
-            session.getCallback().onPrepareFromMediaId(controller, mediaId, extras);
+            session.getCallback().onPrepareFromMediaId(session.getInstance(),
+                    controller, mediaId, extras);
         });
     }
 
@@ -560,7 +568,7 @@
                     caller, MediaSession2.COMMAND_CODE_PLAY_FROM_URI) == null) {
                 return;
             }
-            session.getCallback().onPlayFromUri(controller, uri, extras);
+            session.getCallback().onPlayFromUri(session.getInstance(), controller, uri, extras);
         });
     }
 
@@ -578,7 +586,8 @@
                     caller, MediaSession2.COMMAND_CODE_PLAY_FROM_SEARCH) == null) {
                 return;
             }
-            session.getCallback().onPlayFromSearch(controller, query, extras);
+            session.getCallback().onPlayFromSearch(session.getInstance(),
+                    controller, query, extras);
         });
     }
 
@@ -595,7 +604,8 @@
             if (session == null) {
                 return;
             }
-            session.getCallback().onPlayFromMediaId(controller, mediaId, extras);
+            session.getCallback().onPlayFromMediaId(session.getInstance(),
+                    controller, mediaId, extras);
         });
     }
 
@@ -616,7 +626,8 @@
                 return;
             }
             Rating2 rating = Rating2Impl.fromBundle(session.getContext(), ratingBundle);
-            session.getCallback().onSetRating(controller, mediaId, rating);
+            session.getCallback().onSetRating(session.getInstance(),
+                    controller, mediaId, rating);
         });
     }
 
@@ -637,7 +648,8 @@
             if (getControllerIfAble(caller, MediaSession2.COMMAND_CODE_BROWSER) == null) {
                 return;
             }
-            LibraryRoot root = session.getCallback().onGetLibraryRoot(controller, rootHints);
+            LibraryRoot root = session.getCallback().onGetLibraryRoot(session.getInstance(),
+                    controller, rootHints);
             try {
                 caller.onGetLibraryRootDone(rootHints,
                         root == null ? null : root.getRootId(),
@@ -668,7 +680,8 @@
             if (getControllerIfAble(caller, MediaSession2.COMMAND_CODE_BROWSER) == null) {
                 return;
             }
-            MediaItem2 result = session.getCallback().onGetItem(controller, mediaId);
+            MediaItem2 result = session.getCallback().onGetItem(session.getInstance(),
+                    controller, mediaId);
             try {
                 caller.onGetItemDone(mediaId, result == null ? null : result.toBundle());
             } catch (RemoteException e) {
@@ -703,7 +716,7 @@
             if (getControllerIfAble(caller, MediaSession2.COMMAND_CODE_BROWSER) == null) {
                 return;
             }
-            List<MediaItem2> result = session.getCallback().onGetChildren(
+            List<MediaItem2> result = session.getCallback().onGetChildren(session.getInstance(),
                     controller, parentId, page, pageSize, extras);
             if (result != null && result.size() > pageSize) {
                 throw new IllegalArgumentException("onGetChildren() shouldn't return media items "
@@ -738,7 +751,8 @@
             if (getControllerIfAble(caller, MediaSession2.COMMAND_CODE_BROWSER) == null) {
                 return;
             }
-            session.getCallback().onSearch(controller, query, extras);
+            session.getCallback().onSearch(session.getInstance(),
+                    controller, query, extras);
         });
     }
 
@@ -767,7 +781,7 @@
             if (getControllerIfAble(caller, MediaSession2.COMMAND_CODE_BROWSER) == null) {
                 return;
             }
-            List<MediaItem2> result = session.getCallback().onGetSearchResult(
+            List<MediaItem2> result = session.getCallback().onGetSearchResult(session.getInstance(),
                     controller, query, page, pageSize, extras);
             if (result != null && result.size() > pageSize) {
                 throw new IllegalArgumentException("onGetSearchResult() shouldn't return media "
@@ -804,7 +818,8 @@
             if (getControllerIfAble(caller, MediaSession2.COMMAND_CODE_BROWSER) == null) {
                 return;
             }
-            session.getCallback().onSubscribe(controller, parentId, option);
+            session.getCallback().onSubscribe(session.getInstance(),
+                    controller, parentId, option);
             synchronized (mLock) {
                 Set<String> subscription = mSubscriptions.get(controller);
                 if (subscription == null) {
@@ -828,7 +843,7 @@
             if (getControllerIfAble(caller, MediaSession2.COMMAND_CODE_BROWSER) == null) {
                 return;
             }
-            session.getCallback().onUnsubscribe(controller, parentId);
+            session.getCallback().onUnsubscribe(session.getInstance(), controller, parentId);
             synchronized (mLock) {
                 mSubscriptions.remove(controller);
             }
diff --git a/packages/MediaComponents/test/src/android/media/MediaBrowser2Test.java b/packages/MediaComponents/test/src/android/media/MediaBrowser2Test.java
index 27822e6..d1c7717 100644
--- a/packages/MediaComponents/test/src/android/media/MediaBrowser2Test.java
+++ b/packages/MediaComponents/test/src/android/media/MediaBrowser2Test.java
@@ -462,45 +462,48 @@
 
         @CallSuper
         @Override
-        public void onConnected(CommandGroup commands) {
+        public void onConnected(MediaController2 controller, CommandGroup commands) {
             connectLatch.countDown();
         }
 
         @CallSuper
         @Override
-        public void onDisconnected() {
+        public void onDisconnected(MediaController2 controller) {
             disconnectLatch.countDown();
         }
 
         @Override
-        public void onPlaybackStateChanged(PlaybackState2 state) {
+        public void onPlaybackStateChanged(MediaController2 controller, PlaybackState2 state) {
             mCallbackProxy.onPlaybackStateChanged(state);
         }
 
         @Override
-        public void onPlaylistParamsChanged(PlaylistParams params) {
+        public void onPlaylistParamsChanged(MediaController2 controller, PlaylistParams params) {
             mCallbackProxy.onPlaylistParamsChanged(params);
         }
 
         @Override
-        public void onPlaybackInfoChanged(MediaController2.PlaybackInfo info) {
+        public void onPlaybackInfoChanged(MediaController2 controller,
+                MediaController2.PlaybackInfo info) {
             mCallbackProxy.onPlaybackInfoChanged(info);
         }
 
         @Override
-        public void onCustomCommand(Command command, Bundle args, ResultReceiver receiver) {
+        public void onCustomCommand(MediaController2 controller, Command command, Bundle args,
+                ResultReceiver receiver) {
             mCallbackProxy.onCustomCommand(command, args, receiver);
         }
 
 
         @Override
-        public void onCustomLayoutChanged(List<CommandButton> layout) {
+        public void onCustomLayoutChanged(MediaController2 controller, List<CommandButton> layout) {
             mCallbackProxy.onCustomLayoutChanged(layout);
         }
 
         @Override
-        public void onGetLibraryRootDone(Bundle rootHints, String rootMediaId, Bundle rootExtra) {
-            super.onGetLibraryRootDone(rootHints, rootMediaId, rootExtra);
+        public void onGetLibraryRootDone(MediaBrowser2 browser, Bundle rootHints,
+                String rootMediaId, Bundle rootExtra) {
+            super.onGetLibraryRootDone(browser, rootHints, rootMediaId, rootExtra);
             if (mCallbackProxy instanceof TestBrowserCallbackInterface) {
                 ((TestBrowserCallbackInterface) mCallbackProxy)
                         .onGetLibraryRootDone(rootHints, rootMediaId, rootExtra);
@@ -508,17 +511,17 @@
         }
 
         @Override
-        public void onGetItemDone(String mediaId, MediaItem2 result) {
-            super.onGetItemDone(mediaId, result);
+        public void onGetItemDone(MediaBrowser2 browser, String mediaId, MediaItem2 result) {
+            super.onGetItemDone(browser, mediaId, result);
             if (mCallbackProxy instanceof TestBrowserCallbackInterface) {
                 ((TestBrowserCallbackInterface) mCallbackProxy).onGetItemDone(mediaId, result);
             }
         }
 
         @Override
-        public void onGetChildrenDone(String parentId, int page, int pageSize,
-                List<MediaItem2> result, Bundle extras) {
-            super.onGetChildrenDone(parentId, page, pageSize, result, extras);
+        public void onGetChildrenDone(MediaBrowser2 browser, String parentId, int page,
+                int pageSize, List<MediaItem2> result, Bundle extras) {
+            super.onGetChildrenDone(browser, parentId, page, pageSize, result, extras);
             if (mCallbackProxy instanceof TestBrowserCallbackInterface) {
                 ((TestBrowserCallbackInterface) mCallbackProxy)
                         .onGetChildrenDone(parentId, page, pageSize, result, extras);
@@ -526,8 +529,9 @@
         }
 
         @Override
-        public void onSearchResultChanged(String query, int itemCount, Bundle extras) {
-            super.onSearchResultChanged(query, itemCount, extras);
+        public void onSearchResultChanged(MediaBrowser2 browser, String query, int itemCount,
+                Bundle extras) {
+            super.onSearchResultChanged(browser, query, itemCount, extras);
             if (mCallbackProxy instanceof TestBrowserCallbackInterface) {
                 ((TestBrowserCallbackInterface) mCallbackProxy)
                         .onSearchResultChanged(query, itemCount, extras);
@@ -535,9 +539,9 @@
         }
 
         @Override
-        public void onGetSearchResultDone(String query, int page, int pageSize,
-                List<MediaItem2> result, Bundle extras) {
-            super.onGetSearchResultDone(query, page, pageSize, result, extras);
+        public void onGetSearchResultDone(MediaBrowser2 browser, String query, int page,
+                int pageSize, List<MediaItem2> result, Bundle extras) {
+            super.onGetSearchResultDone(browser, query, page, pageSize, result, extras);
             if (mCallbackProxy instanceof TestBrowserCallbackInterface) {
                 ((TestBrowserCallbackInterface) mCallbackProxy)
                         .onGetSearchResultDone(query, page, pageSize, result, extras);
@@ -545,8 +549,9 @@
         }
 
         @Override
-        public void onChildrenChanged(String parentId, int itemCount, Bundle extras) {
-            super.onChildrenChanged(parentId, itemCount, extras);
+        public void onChildrenChanged(MediaBrowser2 browser, String parentId, int itemCount,
+                Bundle extras) {
+            super.onChildrenChanged(browser, parentId, itemCount, extras);
             if (mCallbackProxy instanceof TestBrowserCallbackInterface) {
                 ((TestBrowserCallbackInterface) mCallbackProxy)
                         .onChildrenChanged(parentId, itemCount, extras);
diff --git a/packages/MediaComponents/test/src/android/media/MediaController2Test.java b/packages/MediaComponents/test/src/android/media/MediaController2Test.java
index 908952c..e6ad098 100644
--- a/packages/MediaComponents/test/src/android/media/MediaController2Test.java
+++ b/packages/MediaComponents/test/src/android/media/MediaController2Test.java
@@ -19,7 +19,6 @@
 import android.app.PendingIntent;
 import android.content.Context;
 import android.content.Intent;
-import android.media.MediaPlayerBase.PlayerEventCallback;
 import android.media.MediaSession2.Command;
 import android.media.MediaSession2.CommandGroup;
 import android.media.MediaSession2.ControllerInfo;
@@ -322,9 +321,9 @@
         final CountDownLatch latch = new CountDownLatch(1);
         final SessionCallback callback = new SessionCallback(mContext) {
             @Override
-            public void onCustomCommand(ControllerInfo controller, Command customCommand,
-                    Bundle args, ResultReceiver cb) {
-                super.onCustomCommand(controller, customCommand, args, cb);
+            public void onCustomCommand(MediaSession2 session, ControllerInfo controller,
+                    Command customCommand, Bundle args, ResultReceiver cb) {
+                super.onCustomCommand(session, controller, customCommand, args, cb);
                 assertEquals(mContext.getPackageName(), controller.getPackageName());
                 assertEquals(testCommand, customCommand);
                 assertTrue(TestUtils.equals(testArgs, args));
@@ -352,7 +351,8 @@
     public void testControllerCallback_sessionRejects() throws InterruptedException {
         final MediaSession2.SessionCallback sessionCallback = new SessionCallback(mContext) {
             @Override
-            public MediaSession2.CommandGroup onConnect(ControllerInfo controller) {
+            public MediaSession2.CommandGroup onConnect(MediaSession2 session,
+                    ControllerInfo controller) {
                 return null;
             }
         };
@@ -390,7 +390,9 @@
         final CountDownLatch latch = new CountDownLatch(1);
         final SessionCallback callback = new SessionCallback(mContext) {
             @Override
-            public void onPlayFromSearch(ControllerInfo controller, String query, Bundle extras) {
+            public void onPlayFromSearch(MediaSession2 session, ControllerInfo controller,
+                    String query, Bundle extras) {
+                super.onPlayFromSearch(session, controller, query, extras);
                 assertEquals(mContext.getPackageName(), controller.getPackageName());
                 assertEquals(request, query);
                 assertTrue(TestUtils.equals(bundle, extras));
@@ -415,7 +417,8 @@
         final CountDownLatch latch = new CountDownLatch(1);
         final SessionCallback callback = new SessionCallback(mContext) {
             @Override
-            public void onPlayFromUri(ControllerInfo controller, Uri uri, Bundle extras) {
+            public void onPlayFromUri(MediaSession2 session, ControllerInfo controller, Uri uri,
+                    Bundle extras) {
                 assertEquals(mContext.getPackageName(), controller.getPackageName());
                 assertEquals(request, uri);
                 assertTrue(TestUtils.equals(bundle, extras));
@@ -440,9 +443,10 @@
         final CountDownLatch latch = new CountDownLatch(1);
         final SessionCallback callback = new SessionCallback(mContext) {
             @Override
-            public void onPlayFromMediaId(ControllerInfo controller, String id, Bundle extras) {
+            public void onPlayFromMediaId(MediaSession2 session, ControllerInfo controller,
+                    String mediaId, Bundle extras) {
                 assertEquals(mContext.getPackageName(), controller.getPackageName());
-                assertEquals(request, id);
+                assertEquals(request, mediaId);
                 assertTrue(TestUtils.equals(bundle, extras));
                 latch.countDown();
             }
@@ -466,8 +470,8 @@
         final CountDownLatch latch = new CountDownLatch(1);
         final SessionCallback callback = new SessionCallback(mContext) {
             @Override
-            public void onPrepareFromSearch(ControllerInfo controller, String query,
-                    Bundle extras) {
+            public void onPrepareFromSearch(MediaSession2 session, ControllerInfo controller,
+                    String query, Bundle extras) {
                 assertEquals(mContext.getPackageName(), controller.getPackageName());
                 assertEquals(request, query);
                 assertTrue(TestUtils.equals(bundle, extras));
@@ -492,7 +496,8 @@
         final CountDownLatch latch = new CountDownLatch(1);
         final SessionCallback callback = new SessionCallback(mContext) {
             @Override
-            public void onPrepareFromUri(ControllerInfo controller, Uri uri, Bundle extras) {
+            public void onPrepareFromUri(MediaSession2 session, ControllerInfo controller, Uri uri,
+                    Bundle extras) {
                 assertEquals(mContext.getPackageName(), controller.getPackageName());
                 assertEquals(request, uri);
                 assertTrue(TestUtils.equals(bundle, extras));
@@ -517,9 +522,10 @@
         final CountDownLatch latch = new CountDownLatch(1);
         final SessionCallback callback = new SessionCallback(mContext) {
             @Override
-            public void onPrepareFromMediaId(ControllerInfo controller, String id, Bundle extras) {
+            public void onPrepareFromMediaId(MediaSession2 session, ControllerInfo controller,
+                    String mediaId, Bundle extras) {
                 assertEquals(mContext.getPackageName(), controller.getPackageName());
-                assertEquals(request, id);
+                assertEquals(request, mediaId);
                 assertTrue(TestUtils.equals(bundle, extras));
                 latch.countDown();
             }
@@ -544,8 +550,8 @@
         final CountDownLatch latch = new CountDownLatch(1);
         final SessionCallback callback = new SessionCallback(mContext) {
             @Override
-            public void onSetRating(ControllerInfo controller, String mediaIdOut,
-                    Rating2 ratingOut) {
+            public void onSetRating(MediaSession2 session, ControllerInfo controller,
+                    String mediaIdOut, Rating2 ratingOut) {
                 assertEquals(mContext.getPackageName(), controller.getPackageName());
                 assertEquals(mediaId, mediaIdOut);
                 assertEquals(rating, ratingOut);
diff --git a/packages/MediaComponents/test/src/android/media/MediaSession2Test.java b/packages/MediaComponents/test/src/android/media/MediaSession2Test.java
index 16cc07c..8c1a749 100644
--- a/packages/MediaComponents/test/src/android/media/MediaSession2Test.java
+++ b/packages/MediaComponents/test/src/android/media/MediaSession2Test.java
@@ -29,7 +29,6 @@
 
 import android.content.Context;
 import android.media.MediaController2.PlaybackInfo;
-import android.media.MediaPlayerBase.PlayerEventCallback;
 import android.media.MediaSession2.Builder;
 import android.media.MediaSession2.Command;
 import android.media.MediaSession2.CommandButton;
@@ -38,7 +37,6 @@
 import android.media.MediaSession2.PlaylistParams;
 import android.media.MediaSession2.SessionCallback;
 import android.os.Bundle;
-import android.os.Looper;
 import android.os.Process;
 import android.os.ResultReceiver;
 import android.support.annotation.NonNull;
@@ -391,11 +389,12 @@
         final CountDownLatch latch = new CountDownLatch(1);
         final SessionCallback sessionCallback = new SessionCallback(mContext) {
             @Override
-            public CommandGroup onConnect(ControllerInfo controller) {
+            public CommandGroup onConnect(MediaSession2 session,
+                    ControllerInfo controller) {
                 if (mContext.getPackageName().equals(controller.getPackageName())) {
                     mSession.setCustomLayout(controller, buttons);
                 }
-                return super.onConnect(controller);
+                return super.onConnect(session, controller);
             }
         };
 
@@ -474,7 +473,8 @@
         }
 
         @Override
-        public MediaSession2.CommandGroup onConnect(ControllerInfo controllerInfo) {
+        public MediaSession2.CommandGroup onConnect(MediaSession2 session,
+                ControllerInfo controllerInfo) {
             if (Process.myUid() != controllerInfo.getUid()) {
                 return null;
             }
@@ -494,7 +494,7 @@
         }
 
         @Override
-        public boolean onCommandRequest(ControllerInfo controllerInfo,
+        public boolean onCommandRequest(MediaSession2 session, ControllerInfo controllerInfo,
                 MediaSession2.Command command) {
             assertEquals(mContext.getPackageName(), controllerInfo.getPackageName());
             assertEquals(Process.myUid(), controllerInfo.getUid());
diff --git a/packages/MediaComponents/test/src/android/media/MediaSession2TestBase.java b/packages/MediaComponents/test/src/android/media/MediaSession2TestBase.java
index c30b9a6..b32400f 100644
--- a/packages/MediaComponents/test/src/android/media/MediaSession2TestBase.java
+++ b/packages/MediaComponents/test/src/android/media/MediaSession2TestBase.java
@@ -187,23 +187,24 @@
 
         @CallSuper
         @Override
-        public void onConnected(CommandGroup commands) {
+        public void onConnected(MediaController2 controller, CommandGroup commands) {
             connectLatch.countDown();
         }
 
         @CallSuper
         @Override
-        public void onDisconnected() {
+        public void onDisconnected(MediaController2 controller) {
             disconnectLatch.countDown();
         }
 
         @Override
-        public void onPlaybackStateChanged(PlaybackState2 state) {
+        public void onPlaybackStateChanged(MediaController2 controller, PlaybackState2 state) {
             mCallbackProxy.onPlaybackStateChanged(state);
         }
 
         @Override
-        public void onCustomCommand(Command command, Bundle args, ResultReceiver receiver) {
+        public void onCustomCommand(MediaController2 controller, Command command, Bundle args,
+                ResultReceiver receiver) {
             mCallbackProxy.onCustomCommand(command, args, receiver);
         }
 
@@ -226,22 +227,24 @@
         }
 
         @Override
-        public void onPlaylistChanged(List<MediaItem2> params) {
+        public void onPlaylistChanged(MediaController2 controller, List<MediaItem2> params) {
             mCallbackProxy.onPlaylistChanged(params);
         }
 
         @Override
-        public void onPlaylistParamsChanged(MediaSession2.PlaylistParams params) {
+        public void onPlaylistParamsChanged(MediaController2 controller,
+                MediaSession2.PlaylistParams params) {
             mCallbackProxy.onPlaylistParamsChanged(params);
         }
 
         @Override
-        public void onPlaybackInfoChanged(MediaController2.PlaybackInfo info) {
+        public void onPlaybackInfoChanged(MediaController2 controller,
+                MediaController2.PlaybackInfo info) {
             mCallbackProxy.onPlaybackInfoChanged(info);
         }
 
         @Override
-        public void onCustomLayoutChanged(List<CommandButton> layout) {
+        public void onCustomLayoutChanged(MediaController2 controller, List<CommandButton> layout) {
             mCallbackProxy.onCustomLayoutChanged(layout);
         }
     }
diff --git a/packages/MediaComponents/test/src/android/media/MediaSession2_PermissionTest.java b/packages/MediaComponents/test/src/android/media/MediaSession2_PermissionTest.java
index d1ff9fb..d89cecd 100644
--- a/packages/MediaComponents/test/src/android/media/MediaSession2_PermissionTest.java
+++ b/packages/MediaComponents/test/src/android/media/MediaSession2_PermissionTest.java
@@ -56,7 +56,11 @@
     private MediaSession2 mSession;
     private MediaSession2.SessionCallback mCallback;
 
-    private static ControllerInfo matchesSelf() {
+    private MediaSession2 matchesSession() {
+        return argThat((session) -> session == mSession);
+    }
+
+    private static ControllerInfo matchesCaller() {
         return argThat((controllerInfo) -> controllerInfo.getUid() == Process.myUid());
     }
 
@@ -88,7 +92,7 @@
             commands = new CommandGroup(mContext);
         }
         mCallback = mock(SessionCallback.class);
-        when(mCallback.onConnect(any())).thenReturn(commands);
+        when(mCallback.onConnect(any(), any())).thenReturn(commands);
         if (mSession != null) {
             mSession.close();
         }
@@ -114,36 +118,36 @@
     public void testPlay() throws InterruptedException {
         createSessionWithAllowedActions(createCommandGroupWith(COMMAND_CODE_PLAYBACK_PLAY));
         createController(mSession.getToken()).play();
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
-                matches(COMMAND_CODE_PLAYBACK_PLAY));
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(
+                matchesSession(), matchesCaller(), matches(COMMAND_CODE_PLAYBACK_PLAY));
 
         createSessionWithAllowedActions(createCommandGroupWithout(COMMAND_CODE_PLAYBACK_PLAY));
         createController(mSession.getToken()).play();
-        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any(), any());
     }
 
     @Test
     public void testPause() throws InterruptedException {
         createSessionWithAllowedActions(createCommandGroupWith(COMMAND_CODE_PLAYBACK_PAUSE));
         createController(mSession.getToken()).pause();
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
-                matches(COMMAND_CODE_PLAYBACK_PAUSE));
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(
+                matchesSession(), matchesCaller(), matches(COMMAND_CODE_PLAYBACK_PAUSE));
 
         createSessionWithAllowedActions(createCommandGroupWithout(COMMAND_CODE_PLAYBACK_PAUSE));
         createController(mSession.getToken()).pause();
-        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any(), any());
     }
 
     @Test
     public void testStop() throws InterruptedException {
         createSessionWithAllowedActions(createCommandGroupWith(COMMAND_CODE_PLAYBACK_STOP));
         createController(mSession.getToken()).stop();
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
-                matches(COMMAND_CODE_PLAYBACK_STOP));
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(
+                matchesSession(), matchesCaller(), matches(COMMAND_CODE_PLAYBACK_STOP));
 
         createSessionWithAllowedActions(createCommandGroupWithout(COMMAND_CODE_PLAYBACK_STOP));
         createController(mSession.getToken()).stop();
-        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any(), any());
     }
 
     @Test
@@ -151,13 +155,13 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PLAYBACK_SKIP_NEXT_ITEM));
         createController(mSession.getToken()).skipToNext();
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
-                matches(COMMAND_CODE_PLAYBACK_SKIP_NEXT_ITEM));
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(
+                matchesSession(), matchesCaller(), matches(COMMAND_CODE_PLAYBACK_SKIP_NEXT_ITEM));
 
         createSessionWithAllowedActions(
                 createCommandGroupWithout(COMMAND_CODE_PLAYBACK_SKIP_NEXT_ITEM));
         createController(mSession.getToken()).skipToNext();
-        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any(), any());
     }
 
     @Test
@@ -165,13 +169,13 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PLAYBACK_SKIP_PREV_ITEM));
         createController(mSession.getToken()).skipToPrevious();
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
-                matches(COMMAND_CODE_PLAYBACK_SKIP_PREV_ITEM));
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(
+                matchesSession(), matchesCaller(), matches(COMMAND_CODE_PLAYBACK_SKIP_PREV_ITEM));
 
         createSessionWithAllowedActions(
                 createCommandGroupWithout(COMMAND_CODE_PLAYBACK_SKIP_PREV_ITEM));
         createController(mSession.getToken()).skipToPrevious();
-        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any(), any());
     }
 
     @Test
@@ -179,13 +183,13 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PLAYBACK_FAST_FORWARD));
         createController(mSession.getToken()).fastForward();
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
-                matches(COMMAND_CODE_PLAYBACK_FAST_FORWARD));
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(
+                matchesSession(), matchesCaller(), matches(COMMAND_CODE_PLAYBACK_FAST_FORWARD));
 
         createSessionWithAllowedActions(
                 createCommandGroupWithout(COMMAND_CODE_PLAYBACK_FAST_FORWARD));
         createController(mSession.getToken()).fastForward();
-        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any(), any());
     }
 
     @Test
@@ -193,12 +197,12 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PLAYBACK_REWIND));
         createController(mSession.getToken()).rewind();
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
-                matches(COMMAND_CODE_PLAYBACK_REWIND));
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(
+                matchesSession(), matchesCaller(), matches(COMMAND_CODE_PLAYBACK_REWIND));
 
         createSessionWithAllowedActions(createCommandGroupWithout(COMMAND_CODE_PLAYBACK_REWIND));
         createController(mSession.getToken()).rewind();
-        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any(), any());
     }
 
     @Test
@@ -207,12 +211,12 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PLAYBACK_SEEK_TO));
         createController(mSession.getToken()).seekTo(position);
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
-                matches(COMMAND_CODE_PLAYBACK_SEEK_TO));
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(
+                matchesSession(), matchesCaller(), matches(COMMAND_CODE_PLAYBACK_SEEK_TO));
 
         createSessionWithAllowedActions(createCommandGroupWithout(COMMAND_CODE_PLAYBACK_SEEK_TO));
         createController(mSession.getToken()).seekTo(position);
-        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any(), any());
     }
 
     // TODO(jaewan): Uncomment when we implement skipToPlaylistItem()
@@ -227,7 +231,7 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PLAYBACK_SET_CURRENT_PLAYLIST_ITEM));
         createController(mSession.getToken()).skipToPlaylistItem(item);
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSession(), matchesCaller(),
                 matches(COMMAND_CODE_PLAYBACK_SET_CURRENT_PLAYLIST_ITEM));
 
         createSessionWithAllowedActions(
@@ -244,25 +248,26 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PLAYBACK_SET_PLAYLIST_PARAMS));
         createController(mSession.getToken()).setPlaylistParams(param);
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(
+                matchesSession(), matchesCaller(),
                 matches(COMMAND_CODE_PLAYBACK_SET_PLAYLIST_PARAMS));
 
         createSessionWithAllowedActions(
                 createCommandGroupWithout(COMMAND_CODE_PLAYBACK_SET_PLAYLIST_PARAMS));
         createController(mSession.getToken()).setPlaylistParams(param);
-        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any(), any());
     }
 
     @Test
     public void testSetVolume() throws InterruptedException {
         createSessionWithAllowedActions(createCommandGroupWith(COMMAND_CODE_SET_VOLUME));
         createController(mSession.getToken()).setVolumeTo(0, 0);
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(matchesSelf(),
-                matches(COMMAND_CODE_SET_VOLUME));
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onCommandRequest(
+                matchesSession(), matchesCaller(), matches(COMMAND_CODE_SET_VOLUME));
 
         createSessionWithAllowedActions(createCommandGroupWithout(COMMAND_CODE_SET_VOLUME));
         createController(mSession.getToken()).setVolumeTo(0, 0);
-        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onCommandRequest(any(), any(), any());
     }
 
     @Test
@@ -271,13 +276,14 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PLAY_FROM_MEDIA_ID));
         createController(mSession.getToken()).playFromMediaId(mediaId, null);
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPlayFromMediaId(matchesSelf(),
-                eq(mediaId), isNull());
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPlayFromMediaId(
+                matchesSession(), matchesCaller(), eq(mediaId), isNull());
 
         createSessionWithAllowedActions(
                 createCommandGroupWithout(COMMAND_CODE_PLAY_FROM_MEDIA_ID));
         createController(mSession.getToken()).playFromMediaId(mediaId, null);
-        verify(mCallback, after(WAIT_TIME_MS).never()).onPlayFromMediaId(any(), any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onPlayFromMediaId(
+                any(), any(), any(), any());
     }
 
     @Test
@@ -286,13 +292,13 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PLAY_FROM_URI));
         createController(mSession.getToken()).playFromUri(uri, null);
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPlayFromUri(matchesSelf(),
-                eq(uri), isNull());
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPlayFromUri(
+                matchesSession(), matchesCaller(), eq(uri), isNull());
 
         createSessionWithAllowedActions(
                 createCommandGroupWithout(COMMAND_CODE_PLAY_FROM_URI));
         createController(mSession.getToken()).playFromUri(uri, null);
-        verify(mCallback, after(WAIT_TIME_MS).never()).onPlayFromUri(any(), any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onPlayFromUri(any(), any(), any(), any());
     }
 
     @Test
@@ -301,13 +307,13 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PLAY_FROM_SEARCH));
         createController(mSession.getToken()).playFromSearch(query, null);
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPlayFromSearch(matchesSelf(),
-                eq(query), isNull());
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPlayFromSearch(
+                matchesSession(), matchesCaller(), eq(query), isNull());
 
         createSessionWithAllowedActions(
                 createCommandGroupWithout(COMMAND_CODE_PLAY_FROM_SEARCH));
         createController(mSession.getToken()).playFromSearch(query, null);
-        verify(mCallback, after(WAIT_TIME_MS).never()).onPlayFromSearch(any(), any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onPlayFromSearch(any(), any(), any(), any());
     }
 
     @Test
@@ -316,13 +322,14 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PREPARE_FROM_MEDIA_ID));
         createController(mSession.getToken()).prepareFromMediaId(mediaId, null);
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPrepareFromMediaId(matchesSelf(),
-                eq(mediaId), isNull());
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPrepareFromMediaId(
+                matchesSession(), matchesCaller(), eq(mediaId), isNull());
 
         createSessionWithAllowedActions(
                 createCommandGroupWithout(COMMAND_CODE_PREPARE_FROM_MEDIA_ID));
         createController(mSession.getToken()).prepareFromMediaId(mediaId, null);
-        verify(mCallback, after(WAIT_TIME_MS).never()).onPrepareFromMediaId(any(), any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onPrepareFromMediaId(
+                any(), any(), any(), any());
     }
 
     @Test
@@ -331,13 +338,13 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PREPARE_FROM_URI));
         createController(mSession.getToken()).prepareFromUri(uri, null);
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPrepareFromUri(matchesSelf(),
-                eq(uri), isNull());
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPrepareFromUri(
+                matchesSession(), matchesCaller(), eq(uri), isNull());
 
         createSessionWithAllowedActions(
                 createCommandGroupWithout(COMMAND_CODE_PREPARE_FROM_URI));
         createController(mSession.getToken()).prepareFromUri(uri, null);
-        verify(mCallback, after(WAIT_TIME_MS).never()).onPrepareFromUri(any(), any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onPrepareFromUri(any(), any(), any(), any());
     }
 
     @Test
@@ -346,12 +353,13 @@
         createSessionWithAllowedActions(
                 createCommandGroupWith(COMMAND_CODE_PREPARE_FROM_SEARCH));
         createController(mSession.getToken()).prepareFromSearch(query, null);
-        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPrepareFromSearch(matchesSelf(),
-                eq(query), isNull());
+        verify(mCallback, timeout(TIMEOUT_MS).atLeastOnce()).onPrepareFromSearch(
+                matchesSession(), matchesCaller(), eq(query), isNull());
 
         createSessionWithAllowedActions(
                 createCommandGroupWithout(COMMAND_CODE_PREPARE_FROM_SEARCH));
         createController(mSession.getToken()).prepareFromSearch(query, null);
-        verify(mCallback, after(WAIT_TIME_MS).never()).onPrepareFromSearch(any(), any(), any());
+        verify(mCallback, after(WAIT_TIME_MS).never()).onPrepareFromSearch(
+                any(), any(), any(), any());
     }
 }
diff --git a/packages/MediaComponents/test/src/android/media/MediaSessionManager_MediaSession2.java b/packages/MediaComponents/test/src/android/media/MediaSessionManager_MediaSession2.java
index 4f344d1..17b200f 100644
--- a/packages/MediaComponents/test/src/android/media/MediaSessionManager_MediaSession2.java
+++ b/packages/MediaComponents/test/src/android/media/MediaSessionManager_MediaSession2.java
@@ -114,7 +114,8 @@
             mSession = new MediaSession2.Builder(mContext).setPlayer(new MockPlayer(0))
                     .setId(TAG).setSessionCallback(sHandlerExecutor, new SessionCallback(mContext) {
                         @Override
-                        public MediaSession2.CommandGroup onConnect(ControllerInfo controller) {
+                        public MediaSession2.CommandGroup onConnect(
+                                MediaSession2 session, ControllerInfo controller) {
                             // Reject all connection request.
                             return null;
                         }
diff --git a/packages/MediaComponents/test/src/android/media/MockMediaLibraryService2.java b/packages/MediaComponents/test/src/android/media/MockMediaLibraryService2.java
index c18d025..fb02f7a 100644
--- a/packages/MediaComponents/test/src/android/media/MockMediaLibraryService2.java
+++ b/packages/MediaComponents/test/src/android/media/MockMediaLibraryService2.java
@@ -23,7 +23,6 @@
 import android.content.Context;
 import android.media.MediaSession2.CommandGroup;
 import android.media.MediaSession2.ControllerInfo;
-import android.media.MediaLibraryService2.MediaLibrarySession;
 import android.media.MediaLibraryService2.MediaLibrarySession.MediaLibrarySessionCallback;
 import android.media.TestServiceRegistry.SessionCallbackProxy;
 import android.media.TestUtils.SyncHandler;
@@ -145,17 +144,20 @@
         }
 
         @Override
-        public CommandGroup onConnect(ControllerInfo controller) {
+        public CommandGroup onConnect(MediaSession2 session,
+                ControllerInfo controller) {
             return mCallbackProxy.onConnect(controller);
         }
 
         @Override
-        public LibraryRoot onGetLibraryRoot(ControllerInfo controller, Bundle rootHints) {
+        public LibraryRoot onGetLibraryRoot(MediaLibrarySession session, ControllerInfo controller,
+                Bundle rootHints) {
             return new LibraryRoot(MockMediaLibraryService2.this, ROOT_ID, EXTRAS);
         }
 
         @Override
-        public MediaItem2 onGetItem(ControllerInfo controller, String mediaId) {
+        public MediaItem2 onGetItem(MediaLibrarySession session, ControllerInfo controller,
+                String mediaId) {
             if (MEDIA_ID_GET_ITEM.equals(mediaId)) {
                 return createMediaItem(mediaId);
             } else {
@@ -164,8 +166,8 @@
         }
 
         @Override
-        public List<MediaItem2> onGetChildren(ControllerInfo controller, String parentId, int page,
-                int pageSize, Bundle extras) {
+        public List<MediaItem2> onGetChildren(MediaLibrarySession session,
+                ControllerInfo controller, String parentId, int page, int pageSize, Bundle extras) {
             if (PARENT_ID.equals(parentId)) {
                 return getPaginatedResult(GET_CHILDREN_RESULT, page, pageSize);
             } else if (PARENT_ID_ERROR.equals(parentId)) {
@@ -176,7 +178,8 @@
         }
 
         @Override
-        public void onSearch(ControllerInfo controllerInfo, String query, Bundle extras) {
+        public void onSearch(MediaLibrarySession session, ControllerInfo controllerInfo,
+                String query, Bundle extras) {
             if (SEARCH_QUERY.equals(query)) {
                 mSession.notifySearchResultChanged(controllerInfo, query, SEARCH_RESULT_COUNT,
                         extras);
@@ -197,8 +200,9 @@
         }
 
         @Override
-        public List<MediaItem2> onGetSearchResult(ControllerInfo controllerInfo,
-                String query, int page, int pageSize, Bundle extras) {
+        public List<MediaItem2> onGetSearchResult(MediaLibrarySession session,
+                ControllerInfo controllerInfo, String query, int page, int pageSize,
+                Bundle extras) {
             if (SEARCH_QUERY.equals(query)) {
                 return getPaginatedResult(SEARCH_RESULT, page, pageSize);
             } else {
@@ -207,12 +211,14 @@
         }
 
         @Override
-        public void onSubscribe(ControllerInfo controller, String parentId, Bundle extras) {
+        public void onSubscribe(MediaLibrarySession session, ControllerInfo controller,
+                String parentId, Bundle extras) {
             mCallbackProxy.onSubscribe(controller, parentId, extras);
         }
 
         @Override
-        public void onUnsubscribe(ControllerInfo controller, String parentId) {
+        public void onUnsubscribe(MediaLibrarySession session, ControllerInfo controller,
+                String parentId) {
             mCallbackProxy.onUnsubscribe(controller, parentId);
         }
     }
diff --git a/packages/MediaComponents/test/src/android/media/MockMediaSessionService2.java b/packages/MediaComponents/test/src/android/media/MockMediaSessionService2.java
index 1c6534d..ce7ce8b 100644
--- a/packages/MediaComponents/test/src/android/media/MockMediaSessionService2.java
+++ b/packages/MediaComponents/test/src/android/media/MockMediaSessionService2.java
@@ -103,7 +103,8 @@
         }
 
         @Override
-        public CommandGroup onConnect(ControllerInfo controller) {
+        public CommandGroup onConnect(MediaSession2 session,
+                ControllerInfo controller) {
             return mCallbackProxy.onConnect(controller);
         }
     }