Transcoder: Fix error with short clips

- Fixes a bug in the sample reader where an error was incorrectly
reported if one track reached end of stream (EOS) right before
switching to sequential access. The reader now clears its EOS flag
after successfully re-seeking the extractor.
- Adds more sample reader tests for different combinations of
sample access patterns and access modes.

Bug: 153453392
Fixes: 173643110
Test: Unit test (MediaSampleReaderNDKTests)
Change-Id: I3b683c5d8eb18a5b57d419ce113e08b40363ba9e
diff --git a/media/libmediatranscoding/transcoder/MediaSampleReaderNDK.cpp b/media/libmediatranscoding/transcoder/MediaSampleReaderNDK.cpp
index d2f6c40..92ba818 100644
--- a/media/libmediatranscoding/transcoder/MediaSampleReaderNDK.cpp
+++ b/media/libmediatranscoding/transcoder/MediaSampleReaderNDK.cpp
@@ -99,6 +99,7 @@
     }
 
     if (!AMediaExtractor_advance(mExtractor)) {
+        LOG(DEBUG) << "  EOS in advanceExtractor_l";
         mEosReached = true;
         for (auto it = mTrackSignals.begin(); it != mTrackSignals.end(); ++it) {
             it->second.notify_all();
@@ -137,6 +138,8 @@
         LOG(ERROR) << "Unable to seek to " << seekToTimeUs << ", target " << targetTimeUs;
         return status;
     }
+
+    mEosReached = false;
     mExtractorTrackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
     int64_t sampleTimeUs = AMediaExtractor_getSampleTime(mExtractor);
 
@@ -233,6 +236,8 @@
 }
 
 media_status_t MediaSampleReaderNDK::setEnforceSequentialAccess(bool enforce) {
+    LOG(DEBUG) << "setEnforceSequentialAccess( " << enforce << " )";
+
     std::scoped_lock lock(mExtractorMutex);
 
     if (mEnforceSequentialAccess && !enforce) {
@@ -374,7 +379,11 @@
         info->presentationTimeUs = 0;
         info->flags = SAMPLE_FLAG_END_OF_STREAM;
         info->size = 0;
+        LOG(DEBUG) << "  getSampleInfoForTrack #" << trackIndex << ": End Of Stream";
+    } else {
+        LOG(ERROR) << "  getSampleInfoForTrack #" << trackIndex << ": Error " << status;
     }
+
     return status;
 }
 
diff --git a/media/libmediatranscoding/transcoder/tests/Android.bp b/media/libmediatranscoding/transcoder/tests/Android.bp
index 7ae6261..8ad583f 100644
--- a/media/libmediatranscoding/transcoder/tests/Android.bp
+++ b/media/libmediatranscoding/transcoder/tests/Android.bp
@@ -15,6 +15,8 @@
 
     shared_libs: [
         "libbase",
+        "libbinder_ndk",
+        "libcrypto",
         "libcutils",
         "libmediandk",
         "libmediatranscoder_asan",
@@ -59,7 +61,6 @@
     name: "MediaTrackTranscoderTests",
     defaults: ["testdefaults"],
     srcs: ["MediaTrackTranscoderTests.cpp"],
-    shared_libs: ["libbinder_ndk"],
 }
 
 // VideoTrackTranscoder unit test
@@ -74,7 +75,6 @@
     name: "PassthroughTrackTranscoderTests",
     defaults: ["testdefaults"],
     srcs: ["PassthroughTrackTranscoderTests.cpp"],
-    shared_libs: ["libcrypto"],
 }
 
 // MediaSampleWriter unit test
@@ -89,5 +89,4 @@
     name: "MediaTranscoderTests",
     defaults: ["testdefaults"],
     srcs: ["MediaTranscoderTests.cpp"],
-    shared_libs: ["libbinder_ndk"],
 }
diff --git a/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp b/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp
index 9c9c8b5..11af0b1 100644
--- a/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp
@@ -25,39 +25,166 @@
 #include <fcntl.h>
 #include <gtest/gtest.h>
 #include <media/MediaSampleReaderNDK.h>
+#include <openssl/md5.h>
 #include <utils/Timers.h>
 
 #include <cmath>
 #include <mutex>
 #include <thread>
 
-// TODO(b/153453392): Test more asset types and validate sample data from readSampleDataForTrack.
-// TODO(b/153453392): Test for sequential and parallel (single thread and multi thread) access.
-// TODO(b/153453392): Test for switching between sequential and parallel access in different points
-//  of time.
+// TODO(b/153453392): Test more asset types (e.g. streams with frame reordering).
 
 namespace android {
 
 #define SEC_TO_USEC(s) ((s)*1000 * 1000)
 
+/** Helper class for comparing samples by flags, timestamp, size and MD5 checksum of the data. */
+class Sample {
+public:
+    Sample(uint32_t flags, int64_t timestamp, size_t size, const uint8_t* buffer)
+          : mFlags{flags}, mTimestamp{timestamp}, mSize{size} {
+        initChecksum(buffer);
+    }
+    /** Captures the extractor's current sample; does not advance the extractor. */
+    explicit Sample(AMediaExtractor* extractor) {
+        mFlags = AMediaExtractor_getSampleFlags(extractor);
+        mTimestamp = AMediaExtractor_getSampleTime(extractor);
+        const ssize_t sampleSize = AMediaExtractor_getSampleSize(extractor);
+        mSize = sampleSize > 0 ? static_cast<size_t>(sampleSize) : 0;  // Negative: no sample.
+        auto buffer = std::make_unique<uint8_t[]>(mSize);
+        AMediaExtractor_readSampleData(extractor, buffer.get(), mSize);
+
+        initChecksum(buffer.get());
+    }
+    /** Stores the MD5 digest of the first mSize bytes of |buffer| in mChecksum. */
+    void initChecksum(const uint8_t* buffer) {
+        MD5_CTX md5Ctx;
+        MD5_Init(&md5Ctx);
+        MD5_Update(&md5Ctx, buffer, mSize);
+        MD5_Final(mChecksum, &md5Ctx);
+    }
+    /** Equal when size, flags, timestamp and data checksum all match. */
+    bool operator==(const Sample& rhs) const {
+        return mSize == rhs.mSize && mFlags == rhs.mFlags && mTimestamp == rhs.mTimestamp &&
+               memcmp(mChecksum, rhs.mChecksum, MD5_DIGEST_LENGTH) == 0;
+    }
+
+    uint32_t mFlags;
+    int64_t mTimestamp;
+    size_t mSize;
+    uint8_t mChecksum[MD5_DIGEST_LENGTH];
+};
+
+/** Sample count that makes readSamplesAsync() read until EOS (or another error). */
+static constexpr int SAMPLE_COUNT_ALL = -1;
+
+/**
+ * Utility class to test different sample access patterns combined with sequential or parallel
+ * sample access modes. Each track is read on its own thread; samples are recorded in mSamples.
+ */
+class SampleAccessTester {
+public:
+    SampleAccessTester(int sourceFd, size_t fileSize) {
+        mSampleReader = MediaSampleReaderNDK::createFromFd(sourceFd, 0, fileSize);
+        EXPECT_TRUE(mSampleReader);
+        // NOTE(review): EXPECT_TRUE does not abort, so a null reader is dereferenced below.
+        mTrackCount = mSampleReader->getTrackCount();
+        // Select every track of the source.
+        for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
+            EXPECT_EQ(mSampleReader->selectTrack(trackIndex), AMEDIA_OK);
+        }
+        // One sample list and one (not yet started) thread slot per track.
+        mSamples.resize(mTrackCount);
+        mTrackThreads.resize(mTrackCount);
+    }
+    /** Queries sample info of |trackIndex| (result discarded) and expects AMEDIA_OK. */
+    void getSampleInfo(int trackIndex) {
+        MediaSampleInfo info;
+        media_status_t status = mSampleReader->getSampleInfoForTrack(trackIndex, &info);
+        EXPECT_EQ(status, AMEDIA_OK);
+    }
+    /** Reads up to |sampleCount| samples on a new thread; any old thread must be joined first. */
+    void readSamplesAsync(int trackIndex, int sampleCount) {
+        mTrackThreads[trackIndex] = std::thread{[this, trackIndex, sampleCount] {
+            int samplesRead = 0;
+            MediaSampleInfo info;
+            while (samplesRead < sampleCount || sampleCount == SAMPLE_COUNT_ALL) {
+                media_status_t status = mSampleReader->getSampleInfoForTrack(trackIndex, &info);
+                if (status != AMEDIA_OK) {
+                    EXPECT_EQ(status, AMEDIA_ERROR_END_OF_STREAM);
+                    EXPECT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) != 0);
+                    break;
+                }
+                ASSERT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
+                // No advanceTrack() call here; reading the data presumably advances the track.
+                auto buffer = std::make_unique<uint8_t[]>(info.size);
+                status = mSampleReader->readSampleDataForTrack(trackIndex, buffer.get(), info.size);
+                EXPECT_EQ(status, AMEDIA_OK);
+                // All track threads append into mSamples; serialize with mSampleMutex.
+                mSampleMutex.lock();
+                const uint8_t* bufferPtr = buffer.get();
+                mSamples[trackIndex].emplace_back(info.flags, info.presentationTimeUs, info.size,
+                                                  bufferPtr);
+                mSampleMutex.unlock();
+                ++samplesRead;
+            }
+        }};
+    }
+    /** Starts one reader thread per track, each reading up to |sampleCount| samples. */
+    void readSamplesAsync(int sampleCount) {
+        for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
+            readSamplesAsync(trackIndex, sampleCount);
+        }
+    }
+    /** Joins the reader thread of |trackIndex|; fails if it was not started or already joined. */
+    void waitForTrack(int trackIndex) {
+        ASSERT_TRUE(mTrackThreads[trackIndex].joinable());
+        mTrackThreads[trackIndex].join();
+    }
+    /** Joins the reader threads of all tracks. */
+    void waitForTracks() {
+        for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
+            waitForTrack(trackIndex);
+        }
+    }
+    /** Enables (true) or disables (false) sequential access on the reader; expects AMEDIA_OK. */
+    void setEnforceSequentialAccess(bool enforce) {
+        media_status_t status = mSampleReader->setEnforceSequentialAccess(enforce);
+        EXPECT_EQ(status, AMEDIA_OK);
+    }
+    /** Returns the samples read so far, indexed by track. */
+    std::vector<std::vector<Sample>>& getSamples() { return mSamples; }
+    // NOTE(review): threads still joinable at destruction call std::terminate; always wait.
+    std::shared_ptr<MediaSampleReader> mSampleReader;
+    size_t mTrackCount;
+    std::mutex mSampleMutex;
+    std::vector<std::thread> mTrackThreads;
+    std::vector<std::vector<Sample>> mSamples;
+};
+
 class MediaSampleReaderNDKTests : public ::testing::Test {
 public:
     MediaSampleReaderNDKTests() { LOG(DEBUG) << "MediaSampleReaderNDKTests created"; }
 
     void SetUp() override {
         LOG(DEBUG) << "MediaSampleReaderNDKTests set up";
+
+        // Need to start a thread pool to prevent AMediaExtractor binder calls from starving
+        // (b/155663561).
+        ABinderProcess_startThreadPool();
+
         const char* sourcePath =
                 "/data/local/tmp/TranscodingTestAssets/cubicle_avc_480x240_aac_24KHz.mp4";
 
-        mExtractor = AMediaExtractor_new();
-        ASSERT_NE(mExtractor, nullptr);
-
         mSourceFd = open(sourcePath, O_RDONLY);
         ASSERT_GT(mSourceFd, 0);
 
         mFileSize = lseek(mSourceFd, 0, SEEK_END);
         lseek(mSourceFd, 0, SEEK_SET);
 
+        mExtractor = AMediaExtractor_new();
+        ASSERT_NE(mExtractor, nullptr);
+
         media_status_t status =
                 AMediaExtractor_setDataSourceFd(mExtractor, mSourceFd, 0, mFileSize);
         ASSERT_EQ(status, AMEDIA_OK);
@@ -68,14 +195,14 @@
         }
     }
 
-    void initExtractorTimestamps() {
-        // Save all sample timestamps, per track, as reported by the extractor.
-        mExtractorTimestamps.resize(mTrackCount);
+    void initExtractorSamples() {
+        if (mExtractorSamples.size() == mTrackCount) return;
+
+        // Save sample information, per track, as reported by the extractor.
+        mExtractorSamples.resize(mTrackCount);
         do {
             const int trackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
-            const int64_t sampleTime = AMediaExtractor_getSampleTime(mExtractor);
-
-            mExtractorTimestamps[trackIndex].push_back(sampleTime);
+            mExtractorSamples[trackIndex].emplace_back(mExtractor);
         } while (AMediaExtractor_advance(mExtractor));
 
         AMediaExtractor_seekTo(mExtractor, 0, AMEDIAEXTRACTOR_SEEK_PREVIOUS_SYNC);
@@ -104,6 +231,22 @@
         return bitrates;
     }
 
+    void compareSamples(std::vector<std::vector<Sample>>& readerSamples) {
+        // Checks |readerSamples| track by track against the extractor's reference samples.
+        initExtractorSamples();
+        ASSERT_EQ(readerSamples.size(), mTrackCount);
+        for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
+            const std::vector<Sample>& reader = readerSamples[trackIndex];
+            const std::vector<Sample>& extractor = mExtractorSamples[trackIndex];
+            LOG(DEBUG) << "Track " << trackIndex << ", comparing " << reader.size() << " samples.";
+            EXPECT_EQ(reader.size(), extractor.size());
+            // Compare only the common prefix so a size mismatch cannot index out of bounds.
+            for (size_t i = 0; i < reader.size() && i < extractor.size(); i++) {
+                EXPECT_EQ(reader[i], extractor[i]);
+            }
+        }
+    }
+
     void TearDown() override {
         LOG(DEBUG) << "MediaSampleReaderNDKTests tear down";
         AMediaExtractor_delete(mExtractor);
@@ -116,58 +259,91 @@
     size_t mTrackCount;
     int mSourceFd;
     size_t mFileSize;
-    std::vector<std::vector<int64_t>> mExtractorTimestamps;
+    std::vector<std::vector<Sample>> mExtractorSamples;
 };
 
-TEST_F(MediaSampleReaderNDKTests, TestSampleTimes) {
-    LOG(DEBUG) << "TestSampleTimes Starts";
+/** Reads all samples from all tracks in parallel, one thread per track, and verifies them. */
+TEST_F(MediaSampleReaderNDKTests, TestParallelSampleAccess) {
+    LOG(DEBUG) << "TestParallelSampleAccess Starts";
 
-    std::shared_ptr<MediaSampleReader> sampleReader =
-            MediaSampleReaderNDK::createFromFd(mSourceFd, 0, mFileSize);
-    ASSERT_TRUE(sampleReader);
+    SampleAccessTester tester{mSourceFd, mFileSize};
+    tester.readSamplesAsync(SAMPLE_COUNT_ALL);
+    tester.waitForTracks();
+    compareSamples(tester.getSamples());
+}
 
-    for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
-        EXPECT_EQ(sampleReader->selectTrack(trackIndex), AMEDIA_OK);
-    }
+/** Reads all samples from all tracks with sequential access enforced, and verifies them. */
+TEST_F(MediaSampleReaderNDKTests, TestSequentialSampleAccess) {
+    LOG(DEBUG) << "TestSequentialSampleAccess Starts";
 
-    // Initialize the extractor timestamps.
-    initExtractorTimestamps();
+    SampleAccessTester tester{mSourceFd, mFileSize};
+    tester.setEnforceSequentialAccess(true);
+    tester.readSamplesAsync(SAMPLE_COUNT_ALL);
+    tester.waitForTracks();
+    compareSamples(tester.getSamples());
+}
 
-    std::mutex timestampMutex;
-    std::vector<std::thread> trackThreads;
-    std::vector<std::vector<int64_t>> readerTimestamps(mTrackCount);
+/** Reads all samples from one track in parallel mode before switching to sequential mode. */
+TEST_F(MediaSampleReaderNDKTests, TestMixedSampleAccessTrackEOS) {
+    LOG(DEBUG) << "TestMixedSampleAccessTrackEOS Starts";
 
-    for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
-        trackThreads.emplace_back([sampleReader, trackIndex, &timestampMutex, &readerTimestamps] {
-            MediaSampleInfo info;
-            while (true) {
-                media_status_t status = sampleReader->getSampleInfoForTrack(trackIndex, &info);
-                if (status != AMEDIA_OK) {
-                    EXPECT_EQ(status, AMEDIA_ERROR_END_OF_STREAM);
-                    EXPECT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) != 0);
-                    break;
-                }
-                ASSERT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
-                timestampMutex.lock();
-                readerTimestamps[trackIndex].push_back(info.presentationTimeUs);
-                timestampMutex.unlock();
-                sampleReader->advanceTrack(trackIndex);
+    for (int readSampleInfoFlag = 0; readSampleInfoFlag <= 1; readSampleInfoFlag++) {
+        for (int trackIndToEOS = 0; trackIndToEOS < mTrackCount; ++trackIndToEOS) {
+            LOG(DEBUG) << "Testing EOS of track " << trackIndToEOS;
+            // Use a fresh reader for every combination.
+            SampleAccessTester tester{mSourceFd, mFileSize};
+
+            // If the flag is set, read sample info from a different track before draining the track
+            // under test to force the reader to save the extractor position.
+            if (readSampleInfoFlag) {
+                tester.getSampleInfo((trackIndToEOS + 1) % mTrackCount);
             }
-        });
-    }
 
-    for (auto& thread : trackThreads) {
-        thread.join();
-    }
+            // Read all samples from one track before enabling sequential access.
+            tester.readSamplesAsync(trackIndToEOS, SAMPLE_COUNT_ALL);
+            tester.waitForTrack(trackIndToEOS);
+            tester.setEnforceSequentialAccess(true);
 
-    for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
-        LOG(DEBUG) << "Track " << trackIndex << ", comparing "
-                   << readerTimestamps[trackIndex].size() << " samples.";
-        EXPECT_EQ(readerTimestamps[trackIndex].size(), mExtractorTimestamps[trackIndex].size());
-        for (size_t sampleIndex = 0; sampleIndex < readerTimestamps[trackIndex].size();
-             sampleIndex++) {
-            EXPECT_EQ(readerTimestamps[trackIndex][sampleIndex],
-                      mExtractorTimestamps[trackIndex][sampleIndex]);
+            for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
+                if (trackIndex == trackIndToEOS) continue;
+                // Sequential access was enabled above, so these reads are sequential.
+                tester.readSamplesAsync(trackIndex, SAMPLE_COUNT_ALL);
+                tester.waitForTrack(trackIndex);
+            }
+
+            compareSamples(tester.getSamples());
+        }
+    }
+}
+
+/**
+ * Reads different combinations of sample counts from all tracks in parallel mode before switching
+ * to sequential mode and reading the rest of the samples.
+ */
+TEST_F(MediaSampleReaderNDKTests, TestMixedSampleAccess) {
+    LOG(DEBUG) << "TestMixedSampleAccess Starts";
+    initExtractorSamples();
+    // Sweep sampleCount from 0 up to one more than the tested track's sample count.
+    for (int trackIndToTest = 0; trackIndToTest < mTrackCount; ++trackIndToTest) {
+        for (int sampleCount = 0; sampleCount <= (mExtractorSamples[trackIndToTest].size() + 1);
+             ++sampleCount) {
+            SampleAccessTester tester{mSourceFd, mFileSize};
+            // Parallel phase: the tested track reads |sampleCount| samples, others read half.
+            for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
+                if (trackIndex == trackIndToTest) {
+                    tester.readSamplesAsync(trackIndex, sampleCount);
+                } else {
+                    tester.readSamplesAsync(trackIndex, mExtractorSamples[trackIndex].size() / 2);
+                }
+            }
+            // Wait for the parallel phase, then drain every track in sequential mode.
+            tester.waitForTracks();
+            tester.setEnforceSequentialAccess(true);
+
+            tester.readSamplesAsync(SAMPLE_COUNT_ALL);
+            tester.waitForTracks();
+
+            compareSamples(tester.getSamples());
         }
     }
 }