Transcoder: Refactor sample writer to not block clients.

This commit fixes an issue with hangs in the transcoder
by letting samples from all tracks go directly to the
backing muxer instead of blocking clients on per-track
queues. This relies on tracks being synchronized by the
sample reader, and on the muxer buffering and interleaving
samples internally.

Test: Transcoder unit tests.
Fixes: 165374867
Change-Id: I99d2dbfa4eb094b7364848a1a8aa3d3d8742140d
diff --git a/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp b/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
index bb0da88..afa5021 100644
--- a/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
+++ b/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
@@ -72,6 +72,11 @@
     AMediaMuxer* mMuxer;
 };
 
+// static. The writer must be shared_ptr-owned since addTrack() calls shared_from_this().
+std::shared_ptr<MediaSampleWriter> MediaSampleWriter::Create() {
+    return std::shared_ptr<MediaSampleWriter>{new MediaSampleWriter()};
+}
+
 MediaSampleWriter::~MediaSampleWriter() {
     if (mState == STARTED) {
         stop();  // Join thread.
@@ -92,7 +97,7 @@
         return false;
     }
 
-    std::scoped_lock lock(mStateMutex);
+    std::scoped_lock lock(mMutex);
     if (mState != UNINITIALIZED) {
         LOG(ERROR) << "Sample writer is already initialized";
         return false;
@@ -104,39 +109,58 @@
     return true;
 }
 
-bool MediaSampleWriter::addTrack(const std::shared_ptr<MediaSampleQueue>& sampleQueue,
-                                 const std::shared_ptr<AMediaFormat>& trackFormat) {
-    if (sampleQueue == nullptr || trackFormat == nullptr) {
-        LOG(ERROR) << "Sample queue and track format must be non-null";
-        return false;
+MediaSampleWriter::MediaSampleConsumerFunction MediaSampleWriter::addTrack(
+        const std::shared_ptr<AMediaFormat>& trackFormat) {  // Returns nullptr on failure.
+    if (trackFormat == nullptr) {
+        LOG(ERROR) << "Track format must be non-null";
+        return nullptr;
     }
 
-    std::scoped_lock lock(mStateMutex);
+    std::scoped_lock lock(mMutex);
     if (mState != INITIALIZED) {
         LOG(ERROR) << "Muxer needs to be initialized when adding tracks.";
-        return false;
+        return nullptr;
     }
-    ssize_t trackIndex = mMuxer->addTrack(trackFormat.get());
-    if (trackIndex < 0) {
-        LOG(ERROR) << "Failed to add media track to muxer: " << trackIndex;
-        return false;
+    ssize_t trackIndexOrError = mMuxer->addTrack(trackFormat.get());
+    if (trackIndexOrError < 0) {
+        LOG(ERROR) << "Failed to add media track to muxer: " << trackIndexOrError;
+        return nullptr;
     }
+    const size_t trackIndex = static_cast<size_t>(trackIndexOrError);
 
     int64_t durationUs;
     if (!AMediaFormat_getInt64(trackFormat.get(), AMEDIAFORMAT_KEY_DURATION, &durationUs)) {
         durationUs = 0;
     }
 
-    mAllTracks.push_back(std::make_unique<TrackRecord>(sampleQueue, static_cast<size_t>(trackIndex),
-                                                       durationUs));
-    mSortedTracks.insert(mAllTracks.back().get());
-    return true;
+    mTracks.emplace(trackIndex, durationUs);
+
+    // The returned consumer keeps the writer alive via a strong reference (shared_from_this).
+    return [self = shared_from_this(), trackIndex](const std::shared_ptr<MediaSample>& sample) {
+        self->addSampleToTrack(trackIndex, sample);
+    };
+}
+
+void MediaSampleWriter::addSampleToTrack(size_t trackIndex,
+                                         const std::shared_ptr<MediaSample>& sample) {
+    if (sample == nullptr) return;  // Null samples are dropped.
+    // mMutex guards mSampleQueue, which the writer thread drains in runWriterLoop().
+    bool wasEmpty = false;
+    {
+        std::scoped_lock lock(mMutex);
+        wasEmpty = mSampleQueue.empty();
+        mSampleQueue.push(std::make_pair(trackIndex, sample));
+    }
+    // The writer only blocks on an empty queue, so notify on the empty -> non-empty edge.
+    if (wasEmpty) {
+        mSampleSignal.notify_one();
+    }
 }
 
 bool MediaSampleWriter::start() {
-    std::scoped_lock lock(mStateMutex);
+    std::scoped_lock lock(mMutex);  // Same lock addTrack() holds while filling mTracks.
 
-    if (mAllTracks.size() == 0) {
+    if (mTracks.size() == 0) {
         LOG(ERROR) << "No tracks to write.";
         return false;
     } else if (mState != INITIALIZED) {
@@ -144,30 +168,28 @@
         return false;
     }
 
+    mState = STARTED;  // Set before spawning: the writer loop exits unless mState == STARTED.
     mThread = std::thread([this] {
         media_status_t status = writeSamples();
         if (auto callbacks = mCallbacks.lock()) {
             callbacks->onFinished(this, status);
         }
     });
-    mState = STARTED;
     return true;
 }
 
 bool MediaSampleWriter::stop() {
-    std::scoped_lock lock(mStateMutex);
-
-    if (mState != STARTED) {
-        LOG(ERROR) << "Sample writer is not started.";
-        return false;
+    {  // Release mMutex before join(); the writer must reacquire it to leave its wait.
+        std::scoped_lock lock(mMutex);
+        if (mState != STARTED) {
+            LOG(ERROR) << "Sample writer is not started.";
+            return false;
+        }
+        mState = STOPPED;  // The writer loop re-checks mState under mMutex after each wake-up.
     }
 
-    // Stop the sources, and wait for thread to join.
-    for (auto& track : mAllTracks) {
-        track->mSampleQueue->abort();
-    }
+    mSampleSignal.notify_all();  // Wake the writer if it is blocked waiting for samples.
     mThread.join();
-    mState = STOPPED;
     return true;
 }
 
@@ -191,83 +213,69 @@
     return writeStatus != AMEDIA_OK ? writeStatus : muxerStatus;
 }
 
-std::multiset<MediaSampleWriter::TrackRecord*>::iterator MediaSampleWriter::getNextOutputTrack() {
-    // Find the first track that has samples ready in its queue AND is not more than
-    // mMaxTrackDivergenceUs ahead of the slowest track. If no such track exists then return the
-    // slowest track and let the writer wait for samples to become ready. Note that mSortedTracks is
-    // sorted by each track's previous sample timestamp in ascending order.
-    auto slowestTrack = mSortedTracks.begin();
-    if (slowestTrack == mSortedTracks.end() || !(*slowestTrack)->mSampleQueue->isEmpty()) {
-        return slowestTrack;
-    }
-
-    const int64_t slowestTimeUs = (*slowestTrack)->mPrevSampleTimeUs;
-    int64_t divergenceUs;
-
-    for (auto it = std::next(slowestTrack); it != mSortedTracks.end(); ++it) {
-        // If the current track has diverged then the rest will have too, so we can stop the search.
-        // If not and it has samples ready then return it, otherwise keep looking.
-        if (__builtin_sub_overflow((*it)->mPrevSampleTimeUs, slowestTimeUs, &divergenceUs) ||
-            divergenceUs >= mMaxTrackDivergenceUs) {
-            break;
-        } else if (!(*it)->mSampleQueue->isEmpty()) {
-            return it;
-        }
-    }
-
-    // No track with pending samples within acceptable time interval was found, so let the writer
-    // wait for the slowest track to produce a new sample.
-    return slowestTrack;
-}
-
-media_status_t MediaSampleWriter::runWriterLoop() {
+media_status_t MediaSampleWriter::runWriterLoop() NO_THREAD_SAFETY_ANALYSIS {
     AMediaCodecBufferInfo bufferInfo;
     int32_t lastProgressUpdate = 0;
+    int trackEosCount = 0;
 
     // Set the "primary" track that will be used to determine progress to the track with longest
     // duration.
     int primaryTrackIndex = -1;
     int64_t longestDurationUs = 0;
-    for (auto& track : mAllTracks) {
-        if (track->mDurationUs > longestDurationUs) {
-            primaryTrackIndex = track->mTrackIndex;
-            longestDurationUs = track->mDurationUs;
+    for (auto it = mTracks.begin(); it != mTracks.end(); ++it) {
+        if (it->second.mDurationUs > longestDurationUs) {
+            primaryTrackIndex = it->first;
+            longestDurationUs = it->second.mDurationUs;
         }
     }
 
     while (true) {
-        auto outputTrackIter = getNextOutputTrack();
-
-        // Exit if all tracks have reached end of stream.
-        if (outputTrackIter == mSortedTracks.end()) {
+        if (trackEosCount >= mTracks.size()) {
             break;
         }
 
-        // Remove the track from the set, update it, and then reinsert it to keep the set in order.
-        TrackRecord* track = *outputTrackIter;
-        mSortedTracks.erase(outputTrackIter);
-
+        size_t trackIndex;
         std::shared_ptr<MediaSample> sample;
-        if (track->mSampleQueue->dequeue(&sample)) {
-            // Track queue was aborted.
-            return AMEDIA_ERROR_UNKNOWN;  // TODO(lnilsson): Custom error code.
-        } else if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
+        {
+            std::unique_lock lock(mMutex);
+            while (mSampleQueue.empty() && mState == STARTED) {
+                mSampleSignal.wait(lock);
+            }
+
+            if (mState != STARTED) {
+                return AMEDIA_ERROR_UNKNOWN;  // TODO(lnilsson): Custom error code.
+            }
+
+            auto& topEntry = mSampleQueue.top();
+            trackIndex = topEntry.first;
+            sample = topEntry.second;
+            mSampleQueue.pop();
+        }
+
+        TrackRecord& track = mTracks[trackIndex];
+
+        if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
+            if (track.mReachedEos) {
+                continue;
+            }
+
             // Track reached end of stream.
-            track->mReachedEos = true;
+            track.mReachedEos = true;
+            trackEosCount++;
 
             // Preserve source track duration by setting the appropriate timestamp on the
             // empty End-Of-Stream sample.
-            if (track->mDurationUs > 0 && track->mFirstSampleTimeSet) {
-                sample->info.presentationTimeUs = track->mDurationUs + track->mFirstSampleTimeUs;
+            if (track.mDurationUs > 0 && track.mFirstSampleTimeSet) {
+                sample->info.presentationTimeUs = track.mDurationUs + track.mFirstSampleTimeUs;
             }
         }
 
-        track->mPrevSampleTimeUs = sample->info.presentationTimeUs;
-        if (!track->mFirstSampleTimeSet) {
+        track.mPrevSampleTimeUs = sample->info.presentationTimeUs;
+        if (!track.mFirstSampleTimeSet) {
             // Record the first sample's timestamp in order to translate duration to EOS
             // time for tracks that does not start at 0.
-            track->mFirstSampleTimeUs = sample->info.presentationTimeUs;
-            track->mFirstSampleTimeSet = true;
+            track.mFirstSampleTimeUs = sample->info.presentationTimeUs;
+            track.mFirstSampleTimeSet = true;
         }
 
         bufferInfo.offset = sample->dataOffset;
@@ -275,8 +283,7 @@
         bufferInfo.flags = sample->info.flags;
         bufferInfo.presentationTimeUs = sample->info.presentationTimeUs;
 
-        media_status_t status =
-                mMuxer->writeSampleData(track->mTrackIndex, sample->buffer, &bufferInfo);
+        media_status_t status = mMuxer->writeSampleData(trackIndex, sample->buffer, &bufferInfo);
         if (status != AMEDIA_OK) {
             LOG(ERROR) << "writeSampleData returned " << status;
             return status;
@@ -284,9 +291,9 @@
         sample.reset();
 
         // TODO(lnilsson): Add option to toggle progress reporting on/off.
-        if (track->mTrackIndex == primaryTrackIndex) {
-            const int64_t elapsed = track->mPrevSampleTimeUs - track->mFirstSampleTimeUs;
-            int32_t progress = (elapsed * 100) / track->mDurationUs;
+        if (trackIndex == primaryTrackIndex) {
+            const int64_t elapsed = track.mPrevSampleTimeUs - track.mFirstSampleTimeUs;
+            int32_t progress = (elapsed * 100) / track.mDurationUs;
             progress = std::clamp(progress, 0, 100);
 
             if (progress > lastProgressUpdate) {
@@ -296,10 +303,6 @@
                 lastProgressUpdate = progress;
             }
         }
-
-        if (!track->mReachedEos) {
-            mSortedTracks.insert(track);
-        }
     }
 
     return AMEDIA_OK;