Only propagate the work source when it is explicitly set.

- "Explicitly set" means calling setCallingWorkSourceUid,
  restoreCallingWorkSource or clearCallingWorkSource.
- We clear the work source when handling BR_TRANSACTION to make sure the
  work source is not propagated even when AIDL is not used (i.e. when
  Parcel#enforceInterface is not called).
- We also restore the original work source after the transaction to
  handle nested calls.

Test: atest binderLibTest BinderWorkSourceTest
Change-Id: Ia3ade2d2c5ebde5b1de2573943969fcb0d02d441
diff --git a/libs/binder/IPCThreadState.cpp b/libs/binder/IPCThreadState.cpp
index f0c21f5..8df83f1 100644
--- a/libs/binder/IPCThreadState.cpp
+++ b/libs/binder/IPCThreadState.cpp
@@ -109,9 +109,7 @@
     "BC_DEAD_BINDER_DONE"
 };
 
-// The work source represents the UID of the process we should attribute the transaction to.
-// We use -1 to specify that the work source was not set using #setWorkSource.
-static const int kUnsetWorkSource = -1;
+static const int64_t kWorkSourcePropagatedBitIndex = 32;  // token bit that stores the propagate flag
 
 static const char* getReturnString(uint32_t cmd)
 {
@@ -389,12 +387,29 @@
 
 int64_t IPCThreadState::setCallingWorkSourceUid(uid_t uid)
 {
-    // Note: we currently only use half of the int64. We return an int64 for extensibility.
-    int64_t token = mWorkSource;
+    int64_t token = setCallingWorkSourceUidWithoutPropagation(uid);  // previous UID + flag
+    mPropagateWorkSource = true;  // an explicit set turns propagation on
+    return token;
+}
+
+int64_t IPCThreadState::setCallingWorkSourceUidWithoutPropagation(uid_t uid)
+{
+    const int64_t propagatedBit = ((int64_t)mPropagateWorkSource) << kWorkSourcePropagatedBitIndex;
+    int64_t token = propagatedBit | mWorkSource;  // NOTE(review): -1 sign-extends over bit 32
+    mWorkSource = uid;  // the propagate flag itself is left unchanged here
+    return token;  // previous UID in the low bits, previous propagate flag at bit 32
+}
 
+void IPCThreadState::clearPropagateWorkSource()
+{
+    mPropagateWorkSource = false;  // only the flag changes; mWorkSource is kept
+}
+
+bool IPCThreadState::shouldPropagateWorkSource() const
+{
+    return mPropagateWorkSource;  // consulted by Parcel::writeInterfaceToken
+}
+
 uid_t IPCThreadState::getCallingWorkSourceUid() const
 {
     return mWorkSource;
@@ -408,7 +423,8 @@
 void IPCThreadState::restoreCallingWorkSource(int64_t token)
 {
     uid_t uid = (int)token;
-    setCallingWorkSourceUid(uid);
+    setCallingWorkSourceUidWithoutPropagation(uid);  // flag is taken from the token instead
+    mPropagateWorkSource = ((token >> kWorkSourcePropagatedBitIndex) & 1) == 1;
 }
 
 void IPCThreadState::setLastTransactionBinderFlags(int32_t flags)
@@ -765,6 +781,7 @@
 IPCThreadState::IPCThreadState()
     : mProcess(ProcessState::self()),
       mWorkSource(kUnsetWorkSource),
+      mPropagateWorkSource(false),
       mStrictModePolicy(0),
       mLastTransactionBinderFlags(0)
 {
@@ -1127,6 +1144,13 @@
             const uid_t origUid = mCallingUid;
             const int32_t origStrictModePolicy = mStrictModePolicy;
             const int32_t origTransactionBinderFlags = mLastTransactionBinderFlags;
+            const int32_t origWorkSource = mWorkSource;
+            const bool origPropagateWorkSet = mPropagateWorkSource;
+            // Calling work source will be set by Parcel#enforceInterface. Parcel#enforceInterface
+            // is only guaranteed to be called for AIDL-generated stubs so we reset the work source
+            // here to never propagate it.
+            clearCallingWorkSource();
+            clearPropagateWorkSource();
 
             mCallingPid = tr.sender_pid;
             mCallingUid = tr.sender_euid;
@@ -1179,6 +1203,8 @@
             mCallingUid = origUid;
             mStrictModePolicy = origStrictModePolicy;
             mLastTransactionBinderFlags = origTransactionBinderFlags;
+            mWorkSource = origWorkSource;
+            mPropagateWorkSource = origPropagateWorkSet;
 
             IF_LOG_TRANSACTIONS() {
                 TextOutput::Bundle _b(alog);
diff --git a/libs/binder/Parcel.cpp b/libs/binder/Parcel.cpp
index 643f428..a0a634a 100644
--- a/libs/binder/Parcel.cpp
+++ b/libs/binder/Parcel.cpp
@@ -599,9 +599,10 @@
 // Write RPC headers.  (previously just the interface token)
 status_t Parcel::writeInterfaceToken(const String16& interface)
 {
-    writeInt32(IPCThreadState::self()->getStrictModePolicy() |
-               STRICT_MODE_PENALTY_GATHER);
-    writeInt32(IPCThreadState::self()->getCallingWorkSourceUid());
+    const IPCThreadState* threadState = IPCThreadState::self();
+    writeInt32(threadState->getStrictModePolicy() | STRICT_MODE_PENALTY_GATHER);
+    writeInt32(threadState->shouldPropagateWorkSource() ?  // else send kUnsetWorkSource
+            threadState->getCallingWorkSourceUid() : IPCThreadState::kUnsetWorkSource);
     // currently the interface identification token is just its name as a string
     return writeString16(interface);
 }
@@ -631,7 +632,7 @@
     }
     // WorkSource.
     int32_t workSource = readInt32();
-    threadState->setCallingWorkSourceUid(workSource);
+    threadState->setCallingWorkSourceUidWithoutPropagation(workSource);
     // Interface descriptor.
     const String16 str(readString16());
     if (str == interface) {
diff --git a/libs/binder/include/binder/IPCThreadState.h b/libs/binder/include/binder/IPCThreadState.h
index de126d5..75b348c 100644
--- a/libs/binder/include/binder/IPCThreadState.h
+++ b/libs/binder/include/binder/IPCThreadState.h
@@ -49,12 +49,16 @@
 
             // See Binder#setCallingWorkSourceUid in Binder.java.
             int64_t             setCallingWorkSourceUid(uid_t uid);
+            // Internal only. Use setCallingWorkSourceUid(uid) instead.
+            int64_t             setCallingWorkSourceUidWithoutPropagation(uid_t uid);
             // See Binder#getCallingWorkSourceUid in Binder.java.
             uid_t               getCallingWorkSourceUid() const;
             // See Binder#clearCallingWorkSource in Binder.java.
             int64_t             clearCallingWorkSource();
             // See Binder#restoreCallingWorkSource in Binder.java.
             void                restoreCallingWorkSource(int64_t token);
+            void                clearPropagateWorkSource();  // turns the propagate flag off
+            bool                shouldPropagateWorkSource() const;  // current propagate flag
 
             void                setLastTransactionBinderFlags(int32_t flags);
             int32_t             getLastTransactionBinderFlags() const;
@@ -127,6 +131,11 @@
             // infer information about thread state.
             bool                isServingCall() const;
 
+            // The work source represents the UID of the process we should attribute the transaction
+            // to. We use -1 to specify that the work source was not set using
+            // #setCallingWorkSourceUid.
+            static const int32_t kUnsetWorkSource = -1;
+
 private:
                                 IPCThreadState();
                                 ~IPCThreadState();
@@ -167,6 +176,8 @@
             // The UID of the process who is responsible for this transaction.
             // This is used for resource attribution.
             int32_t             mWorkSource;
+            // Whether the work source should be propagated.
+            bool                mPropagateWorkSource;
             int32_t             mStrictModePolicy;
             int32_t             mLastTransactionBinderFlags;
             IPCThreadStateBase  *mIPCThreadStateBase;
diff --git a/libs/binder/tests/binderLibTest.cpp b/libs/binder/tests/binderLibTest.cpp
index 22c1bad..cd37d49 100644
--- a/libs/binder/tests/binderLibTest.cpp
+++ b/libs/binder/tests/binderLibTest.cpp
@@ -180,6 +180,7 @@
     public:
         virtual void SetUp() {
             m_server = static_cast<BinderLibTestEnv *>(binder_env)->getServer();
+            IPCThreadState::self()->restoreCallingWorkSource(0);  // uid 0, propagation off
         }
         virtual void TearDown() {
         }
@@ -953,11 +954,28 @@
 {
     status_t ret;
     Parcel data, reply;
+    IPCThreadState::self()->clearCallingWorkSource();
     int64_t previousWorkSource = IPCThreadState::self()->setCallingWorkSourceUid(100);
     data.writeInterfaceToken(binderLibTestServiceName);
     ret = m_server->transact(BINDER_LIB_TEST_GET_WORK_SOURCE_TRANSACTION, data, &reply);
     EXPECT_EQ(100, reply.readInt32());
     EXPECT_EQ(-1, previousWorkSource);
+    EXPECT_EQ(true, IPCThreadState::self()->shouldPropagateWorkSource());
+    EXPECT_EQ(NO_ERROR, ret);
+}
+
+TEST_F(BinderLibTest, WorkSourceSetWithoutPropagation)
+{
+    status_t ret;
+    Parcel data, reply;
+    // The internal setter changes the UID only; the flag stays as SetUp left it (off).
+    IPCThreadState::self()->setCallingWorkSourceUidWithoutPropagation(100);
+    EXPECT_EQ(false, IPCThreadState::self()->shouldPropagateWorkSource());
+    // With the flag off, writeInterfaceToken sends kUnsetWorkSource (-1), not 100.
+    data.writeInterfaceToken(binderLibTestServiceName);
+    ret = m_server->transact(BINDER_LIB_TEST_GET_WORK_SOURCE_TRANSACTION, data, &reply);
+    EXPECT_EQ(-1, reply.readInt32());
+    EXPECT_EQ(false, IPCThreadState::self()->shouldPropagateWorkSource());
     EXPECT_EQ(NO_ERROR, ret);
 }
 
@@ -967,7 +985,8 @@
     Parcel data, reply;
 
     IPCThreadState::self()->setCallingWorkSourceUid(100);
-    int64_t previousWorkSource = IPCThreadState::self()->clearCallingWorkSource();
+    int64_t token = IPCThreadState::self()->clearCallingWorkSource();
+    int32_t previousWorkSource = (int32_t)token;
     data.writeInterfaceToken(binderLibTestServiceName);
     ret = m_server->transact(BINDER_LIB_TEST_GET_WORK_SOURCE_TRANSACTION, data, &reply);
 
@@ -989,9 +1008,58 @@
     ret = m_server->transact(BINDER_LIB_TEST_GET_WORK_SOURCE_TRANSACTION, data, &reply);
 
     EXPECT_EQ(100, reply.readInt32());
+    EXPECT_EQ(true, IPCThreadState::self()->shouldPropagateWorkSource());
     EXPECT_EQ(NO_ERROR, ret);
 }
 
+TEST_F(BinderLibTest, PropagateFlagSet)
+{
+    // An explicit setCallingWorkSourceUid() must turn propagation on, even
+    // when the flag was cleared immediately before. No transaction is needed.
+
+    IPCThreadState::self()->clearPropagateWorkSource();
+    IPCThreadState::self()->setCallingWorkSourceUid(100);
+    EXPECT_EQ(true, IPCThreadState::self()->shouldPropagateWorkSource());
+}
+
+TEST_F(BinderLibTest, PropagateFlagCleared)
+{
+    // clearPropagateWorkSource() must turn propagation off again after an
+    // explicit set turned it on. No transaction is needed.
+
+    IPCThreadState::self()->setCallingWorkSourceUid(100);
+    IPCThreadState::self()->clearPropagateWorkSource();
+    EXPECT_EQ(false, IPCThreadState::self()->shouldPropagateWorkSource());
+}
+
+TEST_F(BinderLibTest, PropagateFlagRestored)
+{
+    // The token is 64 bits wide and keeps the propagate flag above the UID
+    // (bit 32), so it must not be narrowed to int before it is restored.
+
+    int64_t token = IPCThreadState::self()->setCallingWorkSourceUid(100);
+    IPCThreadState::self()->restoreCallingWorkSource(token);
+
+    EXPECT_EQ(false, IPCThreadState::self()->shouldPropagateWorkSource());
+}
+
+TEST_F(BinderLibTest, WorkSourcePropagatedForAllFollowingBinderCalls)
+{
+    // Set once; the second transaction below must still carry this work source.
+    IPCThreadState::self()->setCallingWorkSourceUid(100);
+    Parcel data, reply;
+    data.writeInterfaceToken(binderLibTestServiceName);
+    status_t ret = m_server->transact(BINDER_LIB_TEST_GET_WORK_SOURCE_TRANSACTION, data, &reply);
+    EXPECT_EQ(NO_ERROR, ret);
+
+    Parcel data2, reply2;
+    data2.writeInterfaceToken(binderLibTestServiceName);
+    status_t ret2 = m_server->transact(BINDER_LIB_TEST_GET_WORK_SOURCE_TRANSACTION, data2, &reply2);
+    // The second call is expected to report the UID set before the first call.
+    EXPECT_EQ(100, reply2.readInt32());
+    EXPECT_EQ(NO_ERROR, ret2);
+}
+
 class BinderLibTestService : public BBinder
 {
     public: