Merge "Make Readhead more usable in multithread."
diff --git a/staticlibs/hostdevice/com/android/net/module/util/TrackRecord.kt b/staticlibs/hostdevice/com/android/net/module/util/TrackRecord.kt
index b647d99..efd77d1 100644
--- a/staticlibs/hostdevice/com/android/net/module/util/TrackRecord.kt
+++ b/staticlibs/hostdevice/com/android/net/module/util/TrackRecord.kt
@@ -19,6 +19,7 @@
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.locks.Condition
 import java.util.concurrent.locks.ReentrantLock
+import java.util.concurrent.locks.StampedLock
 import kotlin.concurrent.withLock
 
 /**
@@ -137,12 +138,6 @@
      * instance can also be used concurrently. ReadHead maintains the current index that is
      * the next to be read, and calls this the "mark".
      *
-     * A ReadHead delegates all TrackRecord methods to its associated ArrayTrackRecord, and
-     * inherits its thread-safe properties. However, the additional methods that ReadHead
-     * offers on top of TrackRecord do not share these properties and can only be used by
-     * the thread that created the ReadHead. This is because by construction it does not
-     * make sense to use a ReadHead on multiple threads concurrently (see below for details).
-     *
      * In a ReadHead, {@link poll(Long, (E) -> Boolean)} works similarly to a LinkedBlockingQueue.
      * It can be called repeatedly and will return the elements as they arrive.
      *
@@ -162,21 +157,43 @@
      * The point is that the caller does not have to track the mark like it would have to if
      * it was using ArrayTrackRecord directly.
      *
-     * Note that if multiple threads were using poll() concurrently on the same ReadHead, what
-     * happens to the mark and the return values could be well defined, but it could not
-     * be useful because there is no way to provide either a guarantee not to skip objects nor
-     * a guarantee about the mark position at the exit of poll(). This is even more true in the
-     * presence of a predicate to filter returned elements, because one thread might be
-     * filtering out the events the other is interested in.
-     * Instead, this use case is supported by creating multiple ReadHeads on the same instance
-     * of ArrayTrackRecord. Each ReadHead is then guaranteed to see all events always and
-     * guarantees are made on the value of the mark upon return. {@see poll(Long, (E) -> Boolean)}
-     * for details. Be careful to create each ReadHead on the thread it is meant to be used on.
+     * Thread safety :
+     * A ReadHead delegates all TrackRecord methods to its associated ArrayTrackRecord, and
+     * inherits its thread-safe properties for all the TrackRecord methods.
      *
-     * Users of a ReadHead can ask for the current position of the mark at any time. This mark
-     * can be used later to replay the history of events either on this ReadHead, on the associated
-     * ArrayTrackRecord or on another ReadHead associated with the same ArrayTrackRecord. It
-     * might look like this in the reader thread :
+     * poll() operates under its own set of rules that only allow execution on multiple threads
+     * within constrained boundaries, and never concurrently or pseudo-concurrently. This is
+     * because concurrent calls to poll() fundamentally do not make sense. poll() will move
+     * the mark according to what events remained to be read by this read head, and therefore
+     * if multiple threads were calling poll() concurrently on the same ReadHead, what
+     * happens to the mark and the return values could not be useful because there is no way to
+     * provide either a guarantee not to skip objects nor a guarantee about the mark position at
+     * the exit of poll(). This is even more true in the presence of a predicate to filter
+     * returned elements, because one thread might be filtering out the events the other is
+     * interested in. For this reason, this class will fail-fast if any concurrent access is
+     * detected with ConcurrentModificationException.
+     * It is possible to use poll() on different threads as long as the following can be
+     * guaranteed : one thread must call poll() for the last time, then execute a write barrier,
+     * then the other thread must execute a read barrier before calling poll() for the first time.
+     * In particular, this allows calling poll() in @Before and @After methods in JUnit unit
+     * tests, because JUnit will enforce those barriers by creating the testing thread after
+     * executing @Before and joining the thread after executing @After.
+     *
+     * peek() can be used by multiple threads concurrently, but only if no thread is calling
+     * poll() outside of the boundaries above. For simplicity, it can be considered that peek()
+     * is safe to call only when poll() is safe to call.
+     *
+     * Polling concurrently from the same ArrayTrackRecord is supported by creating multiple
+     * ReadHeads on the same instance of ArrayTrackRecord (or of course by using ArrayTrackRecord
+     * directly). Each ReadHead is then guaranteed to see all events always and
+     * guarantees are made on the value of the mark upon return. {@see poll(Long, (E) -> Boolean)}
+     * for details. Be careful to create each ReadHead on the thread it is meant to be used on, or
+     * to have a clear synchronization point between creation and use.
+     *
+     * Users of a ReadHead can ask for the current position of the mark at any time, on a thread
+     * where it's safe to call peek(). This mark can be used later to replay the history of events
+     * either on this ReadHead, on the associated ArrayTrackRecord or on another ReadHead
+     * associated with the same ArrayTrackRecord. It might look like this in the reader thread :
      *
      * val markAtStart = record.mark
      * // Start processing interesting events
@@ -190,22 +207,39 @@
      * val specialElement = record.poll(timeout, markAtStart) { it.isSpecial() }
      */
     inner class ReadHead : TrackRecord<E> by this@ArrayTrackRecord {
-        private val owningThread = Thread.currentThread()
+        // This lock only controls access to the readHead member below. The ArrayTrackRecord
+        // object has its own synchronization following different (and more usual) semantics.
+        // See the comment on the ReadHead class for details.
+        private val slock = StampedLock()
         private var readHead = 0
 
         /**
          * @return the current value of the mark.
          */
         var mark
-            get() = readHead.also { checkThread() }
+            get() = checkThread { readHead }
             set(v: Int) = rewind(v)
         fun rewind(v: Int) {
-            checkThread()
+            val stamp = slock.tryWriteLock()
+            if (0L == stamp) concurrentAccessDetected()
             readHead = v
+            slock.unlockWrite(stamp)
         }
 
-        private fun checkThread() = check(Thread.currentThread() == owningThread) {
-            "Must be called by the thread that created this object"
+        private fun <T> checkThread(r: (Long) -> T): T {
+            // tryOptimisticRead is a read barrier; it guarantees that writes from other threads
+            // are visible after it.
+            val stamp = slock.tryOptimisticRead()
+            val result = r(stamp)
+            // validate also performs a read barrier, guaranteeing that if validate returns true,
+            // then any change either happens-before tryOptimisticRead, or happens-after validate.
+            if (!slock.validate(stamp)) concurrentAccessDetected()
+            return result
+        }
+
+        private fun concurrentAccessDetected(): Nothing {
+            throw ConcurrentModificationException(
+                    "ReadHeads can't be used concurrently. Check your threading model.")
         }
 
         /**
@@ -225,21 +259,27 @@
          * @return an element matching the predicate, or null if timeout.
          */
         fun poll(timeoutMs: Long, predicate: (E) -> Boolean = { true }): E? {
-            checkThread()
-            lock.withLock {
-                val index = pollForIndexReadLocked(timeoutMs, readHead, predicate)
-                readHead = if (index < 0) size else index + 1
-                return getOrNull(index)
+            val stamp = slock.tryWriteLock()
+            if (0L == stamp) concurrentAccessDetected()
+            try {
+                lock.withLock {
+                    val index = pollForIndexReadLocked(timeoutMs, readHead, predicate)
+                    readHead = if (index < 0) size else index + 1
+                    return getOrNull(index)
+                }
+            } finally {
+                slock.unlockWrite(stamp)
             }
         }
 
         /**
          * Returns the first element after the mark or null. This never blocks.
          *
-         * This method can only be used by the thread that created this ManagedRecordingQueue.
-         * If used on another thread, this throws IllegalStateException.
+         * This method is subject to threading restrictions. It can be used concurrently on
+         * multiple threads but not if any other thread might be executing poll() at the same
+         * time. See the class comment for details.
          */
-        fun peek(): E? = getOrNull(readHead).also { checkThread() }
+        fun peek(): E? = checkThread { getOrNull(readHead) }
     }
 }