Merge "Fix the potential memory leak issue caused by setExtension." into main
diff --git a/core/java/android/os/Binder.java b/core/java/android/os/Binder.java
index ed75491..86dc20c 100644
--- a/core/java/android/os/Binder.java
+++ b/core/java/android/os/Binder.java
@@ -149,6 +149,11 @@
     private static volatile boolean sStackTrackingEnabled = false;
 
     /**
+     * The extension binder most recently passed to {@link #setExtension}, or null if none.
+     */
+    private IBinder mExtension = null;
+
+    /**
      * Enable Binder IPC stack tracking. If enabled, every binder transaction will be logged to
      * {@link TransactionTracker}.
      *
@@ -1234,7 +1239,9 @@
 
     /** @hide */
     @Override
-    public final native @Nullable IBinder getExtension();
+    public final @Nullable IBinder getExtension() {
+        return mExtension;
+    }
 
     /**
      * Set the binder extension.
@@ -1242,7 +1249,12 @@
      *
      * @hide
      */
-    public final native void setExtension(@Nullable IBinder extension);
+    public final void setExtension(@Nullable IBinder extension) {
+        mExtension = extension; // Backs getExtension(); the native holder no longer stores it.
+        setExtensionNative(extension);
+    }
+
+    private final native void setExtensionNative(@Nullable IBinder extension);
 
     /**
      * Default implementation rewinds the parcels and calls onTransact. On
diff --git a/core/jni/android_util_Binder.cpp b/core/jni/android_util_Binder.cpp
index 8003bb7..e85b33e 100644
--- a/core/jni/android_util_Binder.cpp
+++ b/core/jni/android_util_Binder.cpp
@@ -74,6 +74,7 @@
     jmethodID mExecTransact;
     jmethodID mGetInterfaceDescriptor;
     jmethodID mTransactionCallback;
+    jmethodID mGetExtension;
 
     // Object state.
     jfieldID mObject;
@@ -489,8 +490,12 @@
             if (mVintf) {
                 ::android::internal::Stability::markVintf(b.get());
             }
-            if (mExtension != nullptr) {
-                b.get()->setExtension(mExtension);
+            if (mSetExtensionCalled) {
+                jobject javaIBinderObject = env->CallObjectMethod(obj, gBinderOffsets.mGetExtension);
+                sp<IBinder> extensionFromJava = ibinderForJavaObject(env, javaIBinderObject);
+                if (extensionFromJava != nullptr) {
+                    b.get()->setExtension(extensionFromJava);
+                }
             }
             mBinder = b;
             ALOGV("Creating JavaBinder %p (refs %p) for Object %p, weakCount=%" PRId32 "\n",
@@ -516,21 +521,12 @@
         mVintf = false;
     }
 
-    sp<IBinder> getExtension() {
-        AutoMutex _l(mLock);
-        sp<JavaBBinder> b = mBinder.promote();
-        if (b != nullptr) {
-            return b.get()->getExtension();
-        }
-        return mExtension;
-    }
-
     void setExtension(const sp<IBinder>& extension) {
         AutoMutex _l(mLock);
-        mExtension = extension;
+        mSetExtensionCalled = true;  // Flag only; the binder is read back from Java.
         sp<JavaBBinder> b = mBinder.promote();
         if (b != nullptr) {
-            b.get()->setExtension(mExtension);
+            b.get()->setExtension(extension);
         }
     }
 
@@ -542,8 +538,7 @@
     // is too much binder state here, we can think about making JavaBBinder an
     // sp here (avoid recreating it)
     bool            mVintf = false;
-
-    sp<IBinder>     mExtension;
+    bool            mSetExtensionCalled = false;
 };
 
 // ----------------------------------------------------------------------------
@@ -1249,10 +1244,6 @@
     return IPCThreadState::self()->blockUntilThreadAvailable();
 }
 
-static jobject android_os_Binder_getExtension(JNIEnv* env, jobject obj) {
-    JavaBBinderHolder* jbh = (JavaBBinderHolder*) env->GetLongField(obj, gBinderOffsets.mObject);
-    return javaObjectForIBinder(env, jbh->getExtension());
-}
 
 static void android_os_Binder_setExtension(JNIEnv* env, jobject obj, jobject extensionObject) {
     JavaBBinderHolder* jbh = (JavaBBinderHolder*) env->GetLongField(obj, gBinderOffsets.mObject);
@@ -1295,8 +1286,7 @@
     { "getNativeBBinderHolder", "()J", (void*)android_os_Binder_getNativeBBinderHolder },
     { "getNativeFinalizer", "()J", (void*)android_os_Binder_getNativeFinalizer },
     { "blockUntilThreadAvailable", "()V", (void*)android_os_Binder_blockUntilThreadAvailable },
-    { "getExtension", "()Landroid/os/IBinder;", (void*)android_os_Binder_getExtension },
-    { "setExtension", "(Landroid/os/IBinder;)V", (void*)android_os_Binder_setExtension },
+    { "setExtensionNative", "(Landroid/os/IBinder;)V", (void*)android_os_Binder_setExtension },
 };
 // clang-format on
 
@@ -1313,6 +1303,8 @@
     gBinderOffsets.mTransactionCallback =
             GetStaticMethodIDOrDie(env, clazz, "transactionCallback", "(IIII)V");
     gBinderOffsets.mObject = GetFieldIDOrDie(env, clazz, "mObject", "J");
+    gBinderOffsets.mGetExtension = GetMethodIDOrDie(env, clazz, "getExtension",
+                                                        "()Landroid/os/IBinder;");
 
     return RegisterMethodsOrDie(
         env, kBinderPathName,