Merge "Fix possible race conditions during channel unregistration." into gingerbread
diff --git a/include/ui/InputDispatcher.h b/include/ui/InputDispatcher.h
index d3495fe..2505cb0 100644
--- a/include/ui/InputDispatcher.h
+++ b/include/ui/InputDispatcher.h
@@ -554,6 +554,8 @@
     // All registered connections mapped by receive pipe file descriptor.
     KeyedVector<int, sp<Connection> > mConnectionsByReceiveFd;
 
+    ssize_t getConnectionIndex(const sp<InputChannel>& inputChannel);
+
     // Active connections are connections that have a non-empty outbound queue.
     // We don't use a ref-counted pointer here because we explicitly abort connections
     // during unregistration which causes the connection's outbound queue to be cleared
diff --git a/libs/ui/InputDispatcher.cpp b/libs/ui/InputDispatcher.cpp
index b53f140..13030b5 100644
--- a/libs/ui/InputDispatcher.cpp
+++ b/libs/ui/InputDispatcher.cpp
@@ -433,8 +433,7 @@
     for (size_t i = 0; i < mCurrentInputTargets.size(); i++) {
         const InputTarget& inputTarget = mCurrentInputTargets.itemAt(i);
 
-        ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(
-                inputTarget.inputChannel->getReceivePipeFd());
+        ssize_t connectionIndex = getConnectionIndex(inputTarget.inputChannel);
         if (connectionIndex >= 0) {
             sp<Connection> connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
             prepareDispatchCycleLocked(currentTime, connection, eventEntry, & inputTarget,
@@ -1367,12 +1366,10 @@
     LOGD("channel '%s' ~ registerInputChannel", inputChannel->getName().string());
 #endif
 
-    int receiveFd;
     { // acquire lock
         AutoMutex _l(mLock);
 
-        receiveFd = inputChannel->getReceivePipeFd();
-        if (mConnectionsByReceiveFd.indexOfKey(receiveFd) >= 0) {
+        if (getConnectionIndex(inputChannel) >= 0) {
             LOGW("Attempted to register already registered input channel '%s'",
                     inputChannel->getName().string());
             return BAD_VALUE;
@@ -1386,12 +1383,13 @@
             return status;
         }
 
+        int32_t receiveFd = inputChannel->getReceivePipeFd();
         mConnectionsByReceiveFd.add(receiveFd, connection);
 
+        mPollLoop->setCallback(receiveFd, POLLIN, handleReceiveCallback, this);
+
         runCommandsLockedInterruptible();
     } // release lock
-
-    mPollLoop->setCallback(receiveFd, POLLIN, handleReceiveCallback, this);
     return OK;
 }
 
@@ -1400,12 +1398,10 @@
     LOGD("channel '%s' ~ unregisterInputChannel", inputChannel->getName().string());
 #endif
 
-    int32_t receiveFd;
     { // acquire lock
         AutoMutex _l(mLock);
 
-        receiveFd = inputChannel->getReceivePipeFd();
-        ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(receiveFd);
+        ssize_t connectionIndex = getConnectionIndex(inputChannel);
         if (connectionIndex < 0) {
             LOGW("Attempted to unregister already unregistered input channel '%s'",
                     inputChannel->getName().string());
@@ -1417,20 +1413,32 @@
 
         connection->status = Connection::STATUS_ZOMBIE;
 
+        mPollLoop->removeCallback(inputChannel->getReceivePipeFd());
+
         nsecs_t currentTime = now();
         abortDispatchCycleLocked(currentTime, connection, true /*broken*/);
 
         runCommandsLockedInterruptible();
     } // release lock
 
-    mPollLoop->removeCallback(receiveFd);
-
     // Wake the poll loop because removing the connection may have changed the current
     // synchronization state.
     mPollLoop->wake();
     return OK;
 }
 
+ssize_t InputDispatcher::getConnectionIndex(const sp<InputChannel>& inputChannel) { // reads mConnectionsByReceiveFd; visible callers hold mLock
+    ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(inputChannel->getReceivePipeFd());
+    if (connectionIndex >= 0) {
+        const sp<Connection>& connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
+        if (connection->inputChannel.get() == inputChannel.get()) {
+            return connectionIndex;
+        }
+    }
+    // fd unknown, or it maps to a connection for a different InputChannel object.
+    return -1;
+}
+
 void InputDispatcher::activateConnectionLocked(Connection* connection) {
     for (size_t i = 0; i < mActiveConnections.size(); i++) {
         if (mActiveConnections.itemAt(i) == connection) {