RELAND MediaCodec: don't cache buffers until requested

Do not cache input/output buffers until explicitly requested.

Track validity of buffers before getInput/OutputBuffers so that the
dequeued buffers are marked accessible when the array is returned
to the client.

Bug: 189889230
Bug: 236106096
Test: manual
Change-Id: I6892f02853f00c3e355c827ec943fbdf183469f1
diff --git a/media/java/android/media/MediaCodec.java b/media/java/android/media/MediaCodec.java
index 72dd2bd..102d89b 100644
--- a/media/java/android/media/MediaCodec.java
+++ b/media/java/android/media/MediaCodec.java
@@ -42,6 +42,7 @@
 import java.nio.ReadOnlyBufferException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.BitSet;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -1794,7 +1795,7 @@
                     synchronized(mBufferLock) {
                         switch (mBufferMode) {
                             case BUFFER_MODE_LEGACY:
-                                validateInputByteBuffer(mCachedInputBuffers, index);
+                                validateInputByteBufferLocked(mCachedInputBuffers, index);
                                 break;
                             case BUFFER_MODE_BLOCK:
                                 while (mQueueRequests.size() <= index) {
@@ -1823,7 +1824,7 @@
                     synchronized(mBufferLock) {
                         switch (mBufferMode) {
                             case BUFFER_MODE_LEGACY:
-                                validateOutputByteBuffer(mCachedOutputBuffers, index, info);
+                                validateOutputByteBufferLocked(mCachedOutputBuffers, index, info);
                                 break;
                             case BUFFER_MODE_BLOCK:
                                 while (mOutputFrames.size() <= index) {
@@ -2282,10 +2283,6 @@
      */
     public final void start() {
         native_start();
-        synchronized(mBufferLock) {
-            cacheBuffers(true /* input */);
-            cacheBuffers(false /* input */);
-        }
     }
     private native final void native_start();
 
@@ -2342,8 +2339,10 @@
      */
     public final void flush() {
         synchronized(mBufferLock) {
-            invalidateByteBuffers(mCachedInputBuffers);
-            invalidateByteBuffers(mCachedOutputBuffers);
+            invalidateByteBuffersLocked(mCachedInputBuffers);
+            invalidateByteBuffersLocked(mCachedOutputBuffers);
+            mValidInputIndices.clear();
+            mValidOutputIndices.clear();
             mDequeuedInputBuffers.clear();
             mDequeuedOutputBuffers.clear();
         }
@@ -2627,14 +2626,14 @@
                         + "is not compatible with CONFIGURE_FLAG_USE_BLOCK_MODEL. "
                         + "Please use getQueueRequest() to queue buffers");
             }
-            invalidateByteBuffer(mCachedInputBuffers, index);
+            invalidateByteBufferLocked(mCachedInputBuffers, index, true /* input */);
             mDequeuedInputBuffers.remove(index);
         }
         try {
             native_queueInputBuffer(
                     index, offset, size, presentationTimeUs, flags);
         } catch (CryptoException | IllegalStateException e) {
-            revalidateByteBuffer(mCachedInputBuffers, index);
+            revalidateByteBuffer(mCachedInputBuffers, index, true /* input */);
             throw e;
         }
     }
@@ -2897,14 +2896,14 @@
                         + "is not compatible with CONFIGURE_FLAG_USE_BLOCK_MODEL. "
                         + "Please use getQueueRequest() to queue buffers");
             }
-            invalidateByteBuffer(mCachedInputBuffers, index);
+            invalidateByteBufferLocked(mCachedInputBuffers, index, true /* input */);
             mDequeuedInputBuffers.remove(index);
         }
         try {
             native_queueSecureInputBuffer(
                     index, offset, info, presentationTimeUs, flags);
         } catch (CryptoException | IllegalStateException e) {
-            revalidateByteBuffer(mCachedInputBuffers, index);
+            revalidateByteBuffer(mCachedInputBuffers, index, true /* input */);
             throw e;
         }
     }
@@ -2938,7 +2937,7 @@
         int res = native_dequeueInputBuffer(timeoutUs);
         if (res >= 0) {
             synchronized(mBufferLock) {
-                validateInputByteBuffer(mCachedInputBuffers, res);
+                validateInputByteBufferLocked(mCachedInputBuffers, res);
             }
         }
         return res;
@@ -3535,10 +3534,10 @@
         int res = native_dequeueOutputBuffer(info, timeoutUs);
         synchronized (mBufferLock) {
             if (res == INFO_OUTPUT_BUFFERS_CHANGED) {
-                cacheBuffers(false /* input */);
+                cacheBuffersLocked(false /* input */);
             } else if (res >= 0) {
-                validateOutputByteBuffer(mCachedOutputBuffers, res, info);
-                if (mHasSurface) {
+                validateOutputByteBufferLocked(mCachedOutputBuffers, res, info);
+                if (mHasSurface || mCachedOutputBuffers == null) {
                     mDequeuedOutputInfos.put(res, info.dup());
                 }
             }
@@ -3632,9 +3631,9 @@
         synchronized(mBufferLock) {
             switch (mBufferMode) {
                 case BUFFER_MODE_LEGACY:
-                    invalidateByteBuffer(mCachedOutputBuffers, index);
+                    invalidateByteBufferLocked(mCachedOutputBuffers, index, false /* input */);
                     mDequeuedOutputBuffers.remove(index);
-                    if (mHasSurface) {
+                    if (mHasSurface || mCachedOutputBuffers == null) {
                         info = mDequeuedOutputInfos.remove(index);
                     }
                     break;
@@ -3786,15 +3785,24 @@
 
     private ByteBuffer[] mCachedInputBuffers;
     private ByteBuffer[] mCachedOutputBuffers;
+    private BitSet mValidInputIndices = new BitSet();
+    private BitSet mValidOutputIndices = new BitSet();
+
     private final BufferMap mDequeuedInputBuffers = new BufferMap();
     private final BufferMap mDequeuedOutputBuffers = new BufferMap();
     private final Map<Integer, BufferInfo> mDequeuedOutputInfos =
         new HashMap<Integer, BufferInfo>();
     final private Object mBufferLock;
 
-    private final void invalidateByteBuffer(
-            @Nullable ByteBuffer[] buffers, int index) {
-        if (buffers != null && index >= 0 && index < buffers.length) {
+    private void invalidateByteBufferLocked(
+            @Nullable ByteBuffer[] buffers, int index, boolean input) {
+        if (buffers == null) {
+            if (index < 0) {
+                throw new IllegalStateException("index is negative (" + index + ")");
+            }
+            BitSet indices = input ? mValidInputIndices : mValidOutputIndices;
+            indices.clear(index);
+        } else if (index >= 0 && index < buffers.length) {
             ByteBuffer buffer = buffers[index];
             if (buffer != null) {
                 buffer.setAccessible(false);
@@ -3802,9 +3810,14 @@
         }
     }
 
-    private final void validateInputByteBuffer(
+    private void validateInputByteBufferLocked(
             @Nullable ByteBuffer[] buffers, int index) {
-        if (buffers != null && index >= 0 && index < buffers.length) {
+        if (buffers == null) {
+            if (index < 0) {
+                throw new IllegalStateException("index is negative (" + index + ")");
+            }
+            mValidInputIndices.set(index);
+        } else if (index >= 0 && index < buffers.length) {
             ByteBuffer buffer = buffers[index];
             if (buffer != null) {
                 buffer.setAccessible(true);
@@ -3813,10 +3826,16 @@
         }
     }
 
-    private final void revalidateByteBuffer(
-            @Nullable ByteBuffer[] buffers, int index) {
+    private void revalidateByteBuffer(
+            @Nullable ByteBuffer[] buffers, int index, boolean input) {
         synchronized(mBufferLock) {
-            if (buffers != null && index >= 0 && index < buffers.length) {
+            if (buffers == null) {
+                if (index < 0) {
+                    throw new IllegalStateException("index is negative (" + index + ")");
+                }
+                BitSet indices = input ? mValidInputIndices : mValidOutputIndices;
+                indices.set(index);
+            } else if (index >= 0 && index < buffers.length) {
                 ByteBuffer buffer = buffers[index];
                 if (buffer != null) {
                     buffer.setAccessible(true);
@@ -3825,9 +3844,14 @@
         }
     }
 
-    private final void validateOutputByteBuffer(
+    private void validateOutputByteBufferLocked(
             @Nullable ByteBuffer[] buffers, int index, @NonNull BufferInfo info) {
-        if (buffers != null && index >= 0 && index < buffers.length) {
+        if (buffers == null) {
+            if (index < 0) {
+                throw new IllegalStateException("index is negative (" + index + ")");
+            }
+            mValidOutputIndices.set(index);
+        } else if (index >= 0 && index < buffers.length) {
             ByteBuffer buffer = buffers[index];
             if (buffer != null) {
                 buffer.setAccessible(true);
@@ -3836,7 +3860,7 @@
         }
     }
 
-    private final void invalidateByteBuffers(@Nullable ByteBuffer[] buffers) {
+    private void invalidateByteBuffersLocked(@Nullable ByteBuffer[] buffers) {
         if (buffers != null) {
             for (ByteBuffer buffer: buffers) {
                 if (buffer != null) {
@@ -3846,27 +3870,29 @@
         }
     }
 
-    private final void freeByteBuffer(@Nullable ByteBuffer buffer) {
+    private void freeByteBufferLocked(@Nullable ByteBuffer buffer) {
         if (buffer != null /* && buffer.isDirect() */) {
             // all of our ByteBuffers are direct
             java.nio.NioUtils.freeDirectBuffer(buffer);
         }
     }
 
-    private final void freeByteBuffers(@Nullable ByteBuffer[] buffers) {
+    private void freeByteBuffersLocked(@Nullable ByteBuffer[] buffers) {
         if (buffers != null) {
             for (ByteBuffer buffer: buffers) {
-                freeByteBuffer(buffer);
+                freeByteBufferLocked(buffer);
             }
         }
     }
 
-    private final void freeAllTrackedBuffers() {
+    private void freeAllTrackedBuffers() {
         synchronized(mBufferLock) {
-            freeByteBuffers(mCachedInputBuffers);
-            freeByteBuffers(mCachedOutputBuffers);
+            freeByteBuffersLocked(mCachedInputBuffers);
+            freeByteBuffersLocked(mCachedOutputBuffers);
             mCachedInputBuffers = null;
             mCachedOutputBuffers = null;
+            mValidInputIndices.clear();
+            mValidOutputIndices.clear();
             mDequeuedInputBuffers.clear();
             mDequeuedOutputBuffers.clear();
             mQueueRequests.clear();
@@ -3874,14 +3900,31 @@
         }
     }
 
-    private final void cacheBuffers(boolean input) {
+    private void cacheBuffersLocked(boolean input) {
         ByteBuffer[] buffers = null;
         try {
             buffers = getBuffers(input);
-            invalidateByteBuffers(buffers);
+            invalidateByteBuffersLocked(buffers);
         } catch (IllegalStateException e) {
             // we don't get buffers in async mode
         }
+        if (buffers != null) {
+            BitSet indices = input ? mValidInputIndices : mValidOutputIndices;
+            for (int i = 0; i < buffers.length; ++i) {
+                ByteBuffer buffer = buffers[i];
+                if (buffer == null || !indices.get(i)) {
+                    continue;
+                }
+                buffer.setAccessible(true);
+                if (!input) {
+                    BufferInfo info = mDequeuedOutputInfos.get(i);
+                    if (info != null) {
+                        buffer.limit(info.offset + info.size).position(info.offset);
+                    }
+                }
+            }
+            indices.clear();
+        }
         if (input) {
             mCachedInputBuffers = buffers;
         } else {
@@ -3917,6 +3960,9 @@
                         + "objects and attach to QueueRequest objects.");
             }
             if (mCachedInputBuffers == null) {
+                cacheBuffersLocked(true /* input */);
+            }
+            if (mCachedInputBuffers == null) {
                 throw new IllegalStateException();
             }
             // FIXME: check codec status
@@ -3955,6 +4001,9 @@
                         + "Please use getOutputFrame to get output frames.");
             }
             if (mCachedOutputBuffers == null) {
+                cacheBuffersLocked(false /* input */);
+            }
+            if (mCachedOutputBuffers == null) {
                 throw new IllegalStateException();
             }
             // FIXME: check codec status
@@ -3992,7 +4041,7 @@
         }
         ByteBuffer newBuffer = getBuffer(true /* input */, index);
         synchronized (mBufferLock) {
-            invalidateByteBuffer(mCachedInputBuffers, index);
+            invalidateByteBufferLocked(mCachedInputBuffers, index, true /* input */);
             mDequeuedInputBuffers.put(index, newBuffer);
         }
         return newBuffer;
@@ -4029,7 +4078,7 @@
         }
         Image newImage = getImage(true /* input */, index);
         synchronized (mBufferLock) {
-            invalidateByteBuffer(mCachedInputBuffers, index);
+            invalidateByteBufferLocked(mCachedInputBuffers, index, true /* input */);
             mDequeuedInputBuffers.put(index, newImage);
         }
         return newImage;
@@ -4065,7 +4114,7 @@
         }
         ByteBuffer newBuffer = getBuffer(false /* input */, index);
         synchronized (mBufferLock) {
-            invalidateByteBuffer(mCachedOutputBuffers, index);
+            invalidateByteBufferLocked(mCachedOutputBuffers, index, false /* input */);
             mDequeuedOutputBuffers.put(index, newBuffer);
         }
         return newBuffer;
@@ -4100,7 +4149,7 @@
         }
         Image newImage = getImage(false /* input */, index);
         synchronized (mBufferLock) {
-            invalidateByteBuffer(mCachedOutputBuffers, index);
+            invalidateByteBufferLocked(mCachedOutputBuffers, index, false /* input */);
             mDequeuedOutputBuffers.put(index, newImage);
         }
         return newImage;