Transcoder: Improve AV transcoding speed by enforcing sequential sample access.

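MediaSampleReader was bottlenecking the transcoding pipeline due to
non-sequential sample access. This commit adds an option to the sample
reader to enforce sequential sample access by blocking reads on a track
until the underlying extractor advances to that track.

Tracks are no longer selected implicitly: callers must call
selectTrack() for each track (and getEstimatedBitrateForTrack(), if
needed) before reading any samples. readSampleDataForTrack() now also
advances the track after a successful read.

Follow-up: b/165374867 Make MediaSampleWriter robust against buffering track transcoders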

Fixes: 160268606
Test: Transcoder unit tests, and benchmark tests.
Change-Id: Id2a363d06df927ea3e547462c52803594e0511e1
diff --git a/media/libmediatranscoding/transcoder/MediaSampleReaderNDK.cpp b/media/libmediatranscoding/transcoder/MediaSampleReaderNDK.cpp
index 6a00a10..53d567e 100644
--- a/media/libmediatranscoding/transcoder/MediaSampleReaderNDK.cpp
+++ b/media/libmediatranscoding/transcoder/MediaSampleReaderNDK.cpp
@@ -22,7 +22,6 @@
 
 #include <algorithm>
 #include <cmath>
-#include <vector>
 
 namespace android {
 
@@ -47,12 +46,6 @@
     }
 
     auto sampleReader = std::shared_ptr<MediaSampleReaderNDK>(new MediaSampleReaderNDK(extractor));
-    status = sampleReader->init();
-    if (status != AMEDIA_OK) {
-        LOG(ERROR) << "MediaSampleReaderNDK::init returned error: " << status;
-        return nullptr;
-    }
-
     return sampleReader;
 }
 
@@ -60,39 +53,42 @@
       : mExtractor(extractor), mTrackCount(AMediaExtractor_getTrackCount(mExtractor)) {
     if (mTrackCount > 0) {
         mTrackCursors.resize(mTrackCount);
-        mTrackCursors.resize(mTrackCount);
     }
 }
 
-media_status_t MediaSampleReaderNDK::init() {
-    for (size_t trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
-        media_status_t status = AMediaExtractor_selectTrack(mExtractor, trackIndex);
-        if (status != AMEDIA_OK) {
-            LOG(ERROR) << "AMediaExtractor_selectTrack returned error: " << status;
-            return status;
-        }
-    }
-
-    mExtractorTrackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
-    if (mExtractorTrackIndex >= 0) {
-        mTrackCursors[mExtractorTrackIndex].current.set(mExtractorSampleIndex,
-                                                        AMediaExtractor_getSampleTime(mExtractor));
-    } else if (mTrackCount > 0) {
-        // The extractor track index is only allowed to be invalid if there are no tracks.
-        LOG(ERROR) << "Track index " << mExtractorTrackIndex << " is invalid for track count "
-                   << mTrackCount;
-        return AMEDIA_ERROR_MALFORMED;
-    }
-
-    return AMEDIA_OK;
-}
-
 MediaSampleReaderNDK::~MediaSampleReaderNDK() {
     if (mExtractor != nullptr) {
         AMediaExtractor_delete(mExtractor);
     }
 }
 
+void MediaSampleReaderNDK::advanceTrack_l(int trackIndex) {
+    if (!mEnforceSequentialAccess) {
+        // Note: Positioning the extractor before advancing the track is needed for two reasons:
+        // 1. To enable multiple advances without explicitly letting the extractor catch up.
+        // 2. To prevent the extractor from being farther than "next".
+        (void)moveToTrack_l(trackIndex);
+    }
+
+    SampleCursor& cursor = mTrackCursors[trackIndex];
+    cursor.previous = cursor.current;
+    cursor.current = cursor.next;
+    cursor.next.reset();
+
+    // Sequential mode: move the extractor to the next sample that is some track's "current".
+    if (mEnforceSequentialAccess && trackIndex == mExtractorTrackIndex) {
+        while (advanceExtractor_l()) {
+            const SampleCursor& owner = mTrackCursors[mExtractorTrackIndex];
+            if (owner.current.isSet && owner.current.index == mExtractorSampleIndex) {
+                if (mExtractorTrackIndex != trackIndex) {
+                    mTrackSignals[mExtractorTrackIndex].notify_all();
+                }
+                break;
+            }
+        }
+    }
+}
+
 bool MediaSampleReaderNDK::advanceExtractor_l() {
     // Reset the "next" sample time whenever the extractor advances past a sample that is current,
     // to ensure that "next" is appropriately updated when the extractor advances over the next
@@ -103,6 +99,10 @@
     }
 
     if (!AMediaExtractor_advance(mExtractor)) {
+        mEosReached = true;
+        for (auto it = mTrackSignals.begin(); it != mTrackSignals.end(); ++it) {
+            it->second.notify_all();
+        }
         return false;
     }
 
@@ -117,6 +117,7 @@
             cursor.next.set(mExtractorSampleIndex, AMediaExtractor_getSampleTime(mExtractor));
         }
     }
+
     return true;
 }
 
@@ -150,38 +151,15 @@
     return AMEDIA_OK;
 }
 
-void MediaSampleReaderNDK::advanceTrack(int trackIndex) {
-    std::scoped_lock lock(mExtractorMutex);
-
-    if (trackIndex < 0 || trackIndex >= mTrackCount) {
-        LOG(ERROR) << "Invalid trackIndex " << trackIndex << " for trackCount " << mTrackCount;
-        return;
-    }
-
-    // Note: Positioning the extractor before advancing the track is needed for two reasons:
-    // 1. To enable multiple advances without explicitly letting the extractor catch up.
-    // 2. To prevent the extractor from being farther than "next".
-    (void)positionExtractorForTrack_l(trackIndex);
-
-    SampleCursor& cursor = mTrackCursors[trackIndex];
-    cursor.previous = cursor.current;
-    cursor.current = cursor.next;
-    cursor.next.reset();
-}
-
-media_status_t MediaSampleReaderNDK::positionExtractorForTrack_l(int trackIndex) {
-    media_status_t status = AMEDIA_OK;
-    const SampleCursor& cursor = mTrackCursors[trackIndex];
-
-    // Seek backwards if the extractor is ahead of the current time.
-    if (cursor.current.isSet && mExtractorSampleIndex > cursor.current.index) {
-        status = seekExtractorBackwards_l(cursor.current.timeStampUs, trackIndex,
-                                          cursor.current.index);
+media_status_t MediaSampleReaderNDK::moveToSample_l(SamplePosition& pos, int trackIndex) {
+    // Seek backwards if the extractor is ahead of the sample.
+    if (pos.isSet && mExtractorSampleIndex > pos.index) {
+        media_status_t status = seekExtractorBackwards_l(pos.timeStampUs, trackIndex, pos.index);
         if (status != AMEDIA_OK) return status;
     }
 
-    // Advance until extractor points to the current sample.
-    while (!(cursor.current.isSet && cursor.current.index == mExtractorSampleIndex)) {
+    // Advance until extractor points to the sample.
+    while (!(pos.isSet && pos.index == mExtractorSampleIndex)) {
         if (!advanceExtractor_l()) {
             return AMEDIA_ERROR_END_OF_STREAM;
         }
@@ -190,28 +168,129 @@
     return AMEDIA_OK;
 }
 
-media_status_t MediaSampleReaderNDK::getEstimatedBitrateForTrack(int trackIndex, int32_t* bitrate) {
+media_status_t MediaSampleReaderNDK::moveToTrack_l(int trackIndex) {
+    return moveToSample_l(mTrackCursors[trackIndex].current, trackIndex);  // Track's "current".
+}
+
+media_status_t MediaSampleReaderNDK::waitForTrack_l(int trackIndex,
+                                                    std::unique_lock<std::mutex>& lockHeld) {
+    // Sleep (lockHeld is released while waiting) until the extractor sits on this track, the
+    // stream has ended, or sequential access is no longer being enforced.
+    auto& trackSignal = mTrackSignals[trackIndex];
+    trackSignal.wait(lockHeld, [this, trackIndex] {
+        return trackIndex == mExtractorTrackIndex || mEosReached || !mEnforceSequentialAccess;
+    });
+
+    return mEosReached ? AMEDIA_ERROR_END_OF_STREAM : AMEDIA_OK;
+}
+
+media_status_t MediaSampleReaderNDK::primeExtractorForTrack_l(
+        int trackIndex, std::unique_lock<std::mutex>& lockHeld) {
+    // The first sample request on any track latches the extractor's first sample; from then on
+    // selectTrack() and getEstimatedBitrateForTrack() are rejected.
+    if (mExtractorTrackIndex < 0) {
+        const int firstTrackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
+        mExtractorTrackIndex = firstTrackIndex;
+        if (firstTrackIndex < 0) return AMEDIA_ERROR_END_OF_STREAM;
+
+        const int64_t firstSampleTimeUs = AMediaExtractor_getSampleTime(mExtractor);
+        mTrackCursors[firstTrackIndex].current.set(mExtractorSampleIndex, firstSampleTimeUs);
+    }
+
+    // Sequential mode waits for the extractor to reach the track; otherwise seek/advance to it.
+    return mEnforceSequentialAccess ? waitForTrack_l(trackIndex, lockHeld)
+                                    : moveToTrack_l(trackIndex);
+}
+
+media_status_t MediaSampleReaderNDK::selectTrack(int trackIndex) {
     std::scoped_lock lock(mExtractorMutex);
-    media_status_t status = AMEDIA_OK;
 
     if (trackIndex < 0 || trackIndex >= mTrackCount) {
         LOG(ERROR) << "Invalid trackIndex " << trackIndex << " for trackCount " << mTrackCount;
         return AMEDIA_ERROR_INVALID_PARAMETER;
+    } else if (mTrackSignals.find(trackIndex) != mTrackSignals.end()) {
+        LOG(ERROR) << "TrackIndex " << trackIndex << " already selected";
+        return AMEDIA_ERROR_INVALID_PARAMETER;
+    } else if (mExtractorTrackIndex >= 0) {
+        LOG(ERROR) << "Tracks must be selected before sample reading begins.";
+        return AMEDIA_ERROR_UNSUPPORTED;
+    }
+
+    media_status_t status = AMediaExtractor_selectTrack(mExtractor, trackIndex);
+    if (status != AMEDIA_OK) {
+        LOG(ERROR) << "AMediaExtractor_selectTrack returned error: " << status;
+        return status;
+    }
+
+    mTrackSignals.emplace(std::piecewise_construct, std::forward_as_tuple(trackIndex),
+                          std::forward_as_tuple());
+    return AMEDIA_OK;
+}
+
+media_status_t MediaSampleReaderNDK::setEnforceSequentialAccess(bool enforce) {
+    std::scoped_lock lock(mExtractorMutex);
+
+    if (mEnforceSequentialAccess && !enforce) {
+        // If switching from enforcing to not enforcing sequential access there may be threads
+        // waiting that need to be woken up.
+        for (auto it = mTrackSignals.begin(); it != mTrackSignals.end(); ++it) {
+            it->second.notify_all();
+        }
+    } else if (!mEnforceSequentialAccess && enforce && mExtractorTrackIndex >= 0) {
+        // If switching from not enforcing to enforcing sequential access the extractor needs to be
+        // positioned for the track farthest behind so that it won't get stuck waiting.
+        struct {
+            SamplePosition* pos = nullptr;
+            int trackIndex = -1;
+        } earliestSample;
+
+        for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
+            SamplePosition& lastKnownTrackPosition = mTrackCursors[trackIndex].current.isSet
+                                                             ? mTrackCursors[trackIndex].current
+                                                             : mTrackCursors[trackIndex].previous;
+            if (lastKnownTrackPosition.isSet &&
+                (earliestSample.pos == nullptr ||
+                 earliestSample.pos->index > lastKnownTrackPosition.index)) {
+                earliestSample.pos = &lastKnownTrackPosition;
+                earliestSample.trackIndex = trackIndex;
+            }
+        }
+
+        if (earliestSample.pos == nullptr) {
+            LOG(ERROR) << "No known sample position found";
+            return AMEDIA_ERROR_UNKNOWN;
+        }
+
+        media_status_t status = moveToSample_l(*earliestSample.pos, earliestSample.trackIndex);
+        if (status != AMEDIA_OK) return status;
+        // An EOS flagged earlier can be stale if moveToSample_l() seeked backwards; re-derive it
+        // from the extractor so waitForTrack_l() doesn't end tracks that still have samples.
+        mEosReached = AMediaExtractor_getSampleTrackIndex(mExtractor) < 0;
+        while (!(mTrackCursors[mExtractorTrackIndex].current.isSet &&
+                 mTrackCursors[mExtractorTrackIndex].current.index == mExtractorSampleIndex)) {
+            if (!advanceExtractor_l()) {
+                return AMEDIA_ERROR_END_OF_STREAM;
+            }
+        }
+    }
+
+    mEnforceSequentialAccess = enforce;
+    return AMEDIA_OK;
+}
+
+media_status_t MediaSampleReaderNDK::getEstimatedBitrateForTrack(int trackIndex, int32_t* bitrate) {
+    std::scoped_lock lock(mExtractorMutex);
+    media_status_t status = AMEDIA_OK;
+
+    if (mTrackSignals.find(trackIndex) == mTrackSignals.end()) {
+        LOG(ERROR) << "Track is not selected.";
+        return AMEDIA_ERROR_INVALID_PARAMETER;
     } else if (bitrate == nullptr) {
         LOG(ERROR) << "bitrate pointer is NULL.";
         return AMEDIA_ERROR_INVALID_PARAMETER;
-    }
-
-    // Rewind the extractor and sample from the beginning of the file.
-    if (mExtractorSampleIndex > 0) {
-        status = AMediaExtractor_seekTo(mExtractor, 0, AMEDIAEXTRACTOR_SEEK_PREVIOUS_SYNC);
-        if (status != AMEDIA_OK) {
-            LOG(ERROR) << "Unable to reset extractor: " << status;
-            return status;
-        }
-
-        mExtractorTrackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
-        mExtractorSampleIndex = 0;
+    } else if (mExtractorTrackIndex >= 0) {
+        LOG(ERROR) << "getEstimatedBitrateForTrack must be called before sample reading begins.";
+        return AMEDIA_ERROR_UNSUPPORTED;
     }
 
     // Sample the track.
@@ -222,7 +301,7 @@
     int64_t lastSampleTimeUs = 0;
 
     do {
-        if (mExtractorTrackIndex == trackIndex) {
+        if (AMediaExtractor_getSampleTrackIndex(mExtractor) == trackIndex) {
             lastSampleTimeUs = AMediaExtractor_getSampleTime(mExtractor);
             if (totalSampleSize == 0) {
                 firstSampleTimeUs = lastSampleTimeUs;
@@ -231,7 +310,15 @@
             lastSampleSize = AMediaExtractor_getSampleSize(mExtractor);
             totalSampleSize += lastSampleSize;
         }
-    } while ((lastSampleTimeUs - firstSampleTimeUs) < kSamplingDurationUs && advanceExtractor_l());
+    } while ((lastSampleTimeUs - firstSampleTimeUs) < kSamplingDurationUs &&
+             AMediaExtractor_advance(mExtractor));
+
+    // Reset the extractor to the beginning.
+    status = AMediaExtractor_seekTo(mExtractor, 0, AMEDIAEXTRACTOR_SEEK_PREVIOUS_SYNC);
+    if (status != AMEDIA_OK) {
+        LOG(ERROR) << "Unable to reset extractor: " << status;
+        return status;
+    }
 
     int64_t durationUs = 0;
     const int64_t sampledDurationUs = lastSampleTimeUs - firstSampleTimeUs;
@@ -263,17 +350,17 @@
 }
 
 media_status_t MediaSampleReaderNDK::getSampleInfoForTrack(int trackIndex, MediaSampleInfo* info) {
-    std::scoped_lock lock(mExtractorMutex);
+    std::unique_lock<std::mutex> lock(mExtractorMutex);
 
-    if (trackIndex < 0 || trackIndex >= mTrackCount) {
-        LOG(ERROR) << "Invalid trackIndex " << trackIndex << " for trackCount " << mTrackCount;
+    if (mTrackSignals.find(trackIndex) == mTrackSignals.end()) {
+        LOG(ERROR) << "Track not selected.";
         return AMEDIA_ERROR_INVALID_PARAMETER;
     } else if (info == nullptr) {
         LOG(ERROR) << "MediaSampleInfo pointer is NULL.";
         return AMEDIA_ERROR_INVALID_PARAMETER;
     }
 
-    media_status_t status = positionExtractorForTrack_l(trackIndex);
+    media_status_t status = primeExtractorForTrack_l(trackIndex, lock);
     if (status == AMEDIA_OK) {
         info->presentationTimeUs = AMediaExtractor_getSampleTime(mExtractor);
         info->flags = AMediaExtractor_getSampleFlags(mExtractor);
@@ -283,24 +370,25 @@
         info->flags = SAMPLE_FLAG_END_OF_STREAM;
         info->size = 0;
     }
-
     return status;
 }
 
 media_status_t MediaSampleReaderNDK::readSampleDataForTrack(int trackIndex, uint8_t* buffer,
                                                             size_t bufferSize) {
-    std::scoped_lock lock(mExtractorMutex);
+    std::unique_lock<std::mutex> lock(mExtractorMutex);
 
-    if (trackIndex < 0 || trackIndex >= mTrackCount) {
-        LOG(ERROR) << "Invalid trackIndex " << trackIndex << " for trackCount " << mTrackCount;
+    if (mTrackSignals.find(trackIndex) == mTrackSignals.end()) {
+        LOG(ERROR) << "Track not selected.";
         return AMEDIA_ERROR_INVALID_PARAMETER;
     } else if (buffer == nullptr) {
         LOG(ERROR) << "buffer pointer is NULL";
         return AMEDIA_ERROR_INVALID_PARAMETER;
     }
 
-    media_status_t status = positionExtractorForTrack_l(trackIndex);
-    if (status != AMEDIA_OK) return status;
+    media_status_t status = primeExtractorForTrack_l(trackIndex, lock);
+    if (status != AMEDIA_OK) {
+        return status;
+    }
 
     ssize_t sampleSize = AMediaExtractor_getSampleSize(mExtractor);
     if (bufferSize < sampleSize) {
@@ -314,9 +402,21 @@
         return AMEDIA_ERROR_INVALID_PARAMETER;
     }
 
+    advanceTrack_l(trackIndex);
+
     return AMEDIA_OK;
 }
 
+void MediaSampleReaderNDK::advanceTrack(int trackIndex) {
+    std::scoped_lock lock(mExtractorMutex);
+    // Only tracks registered through selectTrack() may be advanced.
+    if (mTrackSignals.count(trackIndex) == 0) {
+        LOG(ERROR) << "Trying to advance a track that is not selected (#" << trackIndex << ")";
+        return;
+    }
+    advanceTrack_l(trackIndex);
+}
+
 AMediaFormat* MediaSampleReaderNDK::getFileFormat() {
     return AMediaExtractor_getFileFormat(mExtractor);
 }