Transcoder: Refactor sample writer to not block clients.

This commit fixes an issue with hangs in the transcoder
by not letting samples from all tracks go directly to the
backing muxer. This relies on tracks being synchronized by
the sample reader and that the muxer buffers and interleaves
samples internally.

Test: Transcoder unit tests.
Fixes: 165374867
Change-Id: I99d2dbfa4eb094b7364848a1a8aa3d3d8742140d
diff --git a/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp b/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
index bb0da88..afa5021 100644
--- a/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
+++ b/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
@@ -72,6 +72,11 @@
     AMediaMuxer* mMuxer;
 };
 
+// static
+std::shared_ptr<MediaSampleWriter> MediaSampleWriter::Create() {
+    return std::shared_ptr<MediaSampleWriter>(new MediaSampleWriter());
+}
+
 MediaSampleWriter::~MediaSampleWriter() {
     if (mState == STARTED) {
         stop();  // Join thread.
@@ -92,7 +97,7 @@
         return false;
     }
 
-    std::scoped_lock lock(mStateMutex);
+    std::scoped_lock lock(mMutex);
     if (mState != UNINITIALIZED) {
         LOG(ERROR) << "Sample writer is already initialized";
         return false;
@@ -104,39 +109,58 @@
     return true;
 }
 
-bool MediaSampleWriter::addTrack(const std::shared_ptr<MediaSampleQueue>& sampleQueue,
-                                 const std::shared_ptr<AMediaFormat>& trackFormat) {
-    if (sampleQueue == nullptr || trackFormat == nullptr) {
-        LOG(ERROR) << "Sample queue and track format must be non-null";
-        return false;
+MediaSampleWriter::MediaSampleConsumerFunction MediaSampleWriter::addTrack(
+        const std::shared_ptr<AMediaFormat>& trackFormat) {
+    if (trackFormat == nullptr) {
+        LOG(ERROR) << "Track format must be non-null";
+        return nullptr;
     }
 
-    std::scoped_lock lock(mStateMutex);
+    std::scoped_lock lock(mMutex);
     if (mState != INITIALIZED) {
         LOG(ERROR) << "Muxer needs to be initialized when adding tracks.";
-        return false;
+        return nullptr;
     }
-    ssize_t trackIndex = mMuxer->addTrack(trackFormat.get());
-    if (trackIndex < 0) {
-        LOG(ERROR) << "Failed to add media track to muxer: " << trackIndex;
-        return false;
+    ssize_t trackIndexOrError = mMuxer->addTrack(trackFormat.get());
+    if (trackIndexOrError < 0) {
+        LOG(ERROR) << "Failed to add media track to muxer: " << trackIndexOrError;
+        return nullptr;
     }
+    const size_t trackIndex = static_cast<size_t>(trackIndexOrError);
 
     int64_t durationUs;
     if (!AMediaFormat_getInt64(trackFormat.get(), AMEDIAFORMAT_KEY_DURATION, &durationUs)) {
         durationUs = 0;
     }
 
-    mAllTracks.push_back(std::make_unique<TrackRecord>(sampleQueue, static_cast<size_t>(trackIndex),
-                                                       durationUs));
-    mSortedTracks.insert(mAllTracks.back().get());
-    return true;
+    mTracks.emplace(trackIndex, durationUs);
+    std::shared_ptr<MediaSampleWriter> thisWriter = shared_from_this();
+
+    return [self = shared_from_this(), trackIndex](const std::shared_ptr<MediaSample>& sample) {
+        self->addSampleToTrack(trackIndex, sample);
+    };
+}
+
+void MediaSampleWriter::addSampleToTrack(size_t trackIndex,
+                                         const std::shared_ptr<MediaSample>& sample) {
+    if (sample == nullptr) return;
+
+    bool wasEmpty;
+    {
+        std::scoped_lock lock(mMutex);
+        wasEmpty = mSampleQueue.empty();
+        mSampleQueue.push(std::make_pair(trackIndex, sample));
+    }
+
+    if (wasEmpty) {
+        mSampleSignal.notify_one();
+    }
 }
 
 bool MediaSampleWriter::start() {
-    std::scoped_lock lock(mStateMutex);
+    std::scoped_lock lock(mMutex);
 
-    if (mAllTracks.size() == 0) {
+    if (mTracks.size() == 0) {
         LOG(ERROR) << "No tracks to write.";
         return false;
     } else if (mState != INITIALIZED) {
@@ -144,30 +168,28 @@
         return false;
     }
 
+    mState = STARTED;
     mThread = std::thread([this] {
         media_status_t status = writeSamples();
         if (auto callbacks = mCallbacks.lock()) {
             callbacks->onFinished(this, status);
         }
     });
-    mState = STARTED;
     return true;
 }
 
 bool MediaSampleWriter::stop() {
-    std::scoped_lock lock(mStateMutex);
-
-    if (mState != STARTED) {
-        LOG(ERROR) << "Sample writer is not started.";
-        return false;
+    {
+        std::scoped_lock lock(mMutex);
+        if (mState != STARTED) {
+            LOG(ERROR) << "Sample writer is not started.";
+            return false;
+        }
+        mState = STOPPED;
     }
 
-    // Stop the sources, and wait for thread to join.
-    for (auto& track : mAllTracks) {
-        track->mSampleQueue->abort();
-    }
+    mSampleSignal.notify_all();
     mThread.join();
-    mState = STOPPED;
     return true;
 }
 
@@ -191,83 +213,69 @@
     return writeStatus != AMEDIA_OK ? writeStatus : muxerStatus;
 }
 
-std::multiset<MediaSampleWriter::TrackRecord*>::iterator MediaSampleWriter::getNextOutputTrack() {
-    // Find the first track that has samples ready in its queue AND is not more than
-    // mMaxTrackDivergenceUs ahead of the slowest track. If no such track exists then return the
-    // slowest track and let the writer wait for samples to become ready. Note that mSortedTracks is
-    // sorted by each track's previous sample timestamp in ascending order.
-    auto slowestTrack = mSortedTracks.begin();
-    if (slowestTrack == mSortedTracks.end() || !(*slowestTrack)->mSampleQueue->isEmpty()) {
-        return slowestTrack;
-    }
-
-    const int64_t slowestTimeUs = (*slowestTrack)->mPrevSampleTimeUs;
-    int64_t divergenceUs;
-
-    for (auto it = std::next(slowestTrack); it != mSortedTracks.end(); ++it) {
-        // If the current track has diverged then the rest will have too, so we can stop the search.
-        // If not and it has samples ready then return it, otherwise keep looking.
-        if (__builtin_sub_overflow((*it)->mPrevSampleTimeUs, slowestTimeUs, &divergenceUs) ||
-            divergenceUs >= mMaxTrackDivergenceUs) {
-            break;
-        } else if (!(*it)->mSampleQueue->isEmpty()) {
-            return it;
-        }
-    }
-
-    // No track with pending samples within acceptable time interval was found, so let the writer
-    // wait for the slowest track to produce a new sample.
-    return slowestTrack;
-}
-
-media_status_t MediaSampleWriter::runWriterLoop() {
+media_status_t MediaSampleWriter::runWriterLoop() NO_THREAD_SAFETY_ANALYSIS {
     AMediaCodecBufferInfo bufferInfo;
     int32_t lastProgressUpdate = 0;
+    int trackEosCount = 0;
 
     // Set the "primary" track that will be used to determine progress to the track with longest
     // duration.
     int primaryTrackIndex = -1;
     int64_t longestDurationUs = 0;
-    for (auto& track : mAllTracks) {
-        if (track->mDurationUs > longestDurationUs) {
-            primaryTrackIndex = track->mTrackIndex;
-            longestDurationUs = track->mDurationUs;
+    for (auto it = mTracks.begin(); it != mTracks.end(); ++it) {
+        if (it->second.mDurationUs > longestDurationUs) {
+            primaryTrackIndex = it->first;
+            longestDurationUs = it->second.mDurationUs;
         }
     }
 
     while (true) {
-        auto outputTrackIter = getNextOutputTrack();
-
-        // Exit if all tracks have reached end of stream.
-        if (outputTrackIter == mSortedTracks.end()) {
+        if (trackEosCount >= mTracks.size()) {
             break;
         }
 
-        // Remove the track from the set, update it, and then reinsert it to keep the set in order.
-        TrackRecord* track = *outputTrackIter;
-        mSortedTracks.erase(outputTrackIter);
-
+        size_t trackIndex;
         std::shared_ptr<MediaSample> sample;
-        if (track->mSampleQueue->dequeue(&sample)) {
-            // Track queue was aborted.
-            return AMEDIA_ERROR_UNKNOWN;  // TODO(lnilsson): Custom error code.
-        } else if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
+        {
+            std::unique_lock lock(mMutex);
+            while (mSampleQueue.empty() && mState == STARTED) {
+                mSampleSignal.wait(lock);
+            }
+
+            if (mState != STARTED) {
+                return AMEDIA_ERROR_UNKNOWN;  // TODO(lnilsson): Custom error code.
+            }
+
+            auto& topEntry = mSampleQueue.top();
+            trackIndex = topEntry.first;
+            sample = topEntry.second;
+            mSampleQueue.pop();
+        }
+
+        TrackRecord& track = mTracks[trackIndex];
+
+        if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
+            if (track.mReachedEos) {
+                continue;
+            }
+
             // Track reached end of stream.
-            track->mReachedEos = true;
+            track.mReachedEos = true;
+            trackEosCount++;
 
             // Preserve source track duration by setting the appropriate timestamp on the
             // empty End-Of-Stream sample.
-            if (track->mDurationUs > 0 && track->mFirstSampleTimeSet) {
-                sample->info.presentationTimeUs = track->mDurationUs + track->mFirstSampleTimeUs;
+            if (track.mDurationUs > 0 && track.mFirstSampleTimeSet) {
+                sample->info.presentationTimeUs = track.mDurationUs + track.mFirstSampleTimeUs;
             }
         }
 
-        track->mPrevSampleTimeUs = sample->info.presentationTimeUs;
-        if (!track->mFirstSampleTimeSet) {
+        track.mPrevSampleTimeUs = sample->info.presentationTimeUs;
+        if (!track.mFirstSampleTimeSet) {
             // Record the first sample's timestamp in order to translate duration to EOS
             // time for tracks that does not start at 0.
-            track->mFirstSampleTimeUs = sample->info.presentationTimeUs;
-            track->mFirstSampleTimeSet = true;
+            track.mFirstSampleTimeUs = sample->info.presentationTimeUs;
+            track.mFirstSampleTimeSet = true;
         }
 
         bufferInfo.offset = sample->dataOffset;
@@ -275,8 +283,7 @@
         bufferInfo.flags = sample->info.flags;
         bufferInfo.presentationTimeUs = sample->info.presentationTimeUs;
 
-        media_status_t status =
-                mMuxer->writeSampleData(track->mTrackIndex, sample->buffer, &bufferInfo);
+        media_status_t status = mMuxer->writeSampleData(trackIndex, sample->buffer, &bufferInfo);
         if (status != AMEDIA_OK) {
             LOG(ERROR) << "writeSampleData returned " << status;
             return status;
@@ -284,9 +291,9 @@
         sample.reset();
 
         // TODO(lnilsson): Add option to toggle progress reporting on/off.
-        if (track->mTrackIndex == primaryTrackIndex) {
-            const int64_t elapsed = track->mPrevSampleTimeUs - track->mFirstSampleTimeUs;
-            int32_t progress = (elapsed * 100) / track->mDurationUs;
+        if (trackIndex == primaryTrackIndex) {
+            const int64_t elapsed = track.mPrevSampleTimeUs - track.mFirstSampleTimeUs;
+            int32_t progress = (elapsed * 100) / track.mDurationUs;
             progress = std::clamp(progress, 0, 100);
 
             if (progress > lastProgressUpdate) {
@@ -296,10 +303,6 @@
                 lastProgressUpdate = progress;
             }
         }
-
-        if (!track->mReachedEos) {
-            mSortedTracks.insert(track);
-        }
     }
 
     return AMEDIA_OK;
diff --git a/media/libmediatranscoding/transcoder/MediaTrackTranscoder.cpp b/media/libmediatranscoding/transcoder/MediaTrackTranscoder.cpp
index 92ce60a..698594f 100644
--- a/media/libmediatranscoding/transcoder/MediaTrackTranscoder.cpp
+++ b/media/libmediatranscoding/transcoder/MediaTrackTranscoder.cpp
@@ -94,7 +94,10 @@
         abortTranscodeLoop();
         mMediaSampleReader->setEnforceSequentialAccess(false);
         mTranscodingThread.join();
-        mOutputQueue->abort();  // Wake up any threads waiting for samples.
+        {
+            std::scoped_lock lock{mSampleMutex};
+            mSampleQueue.abort();  // Release any buffered samples.
+        }
         mState = STOPPED;
         return true;
     }
@@ -109,8 +112,24 @@
     }
 }
 
-std::shared_ptr<MediaSampleQueue> MediaTrackTranscoder::getOutputQueue() const {
-    return mOutputQueue;
+void MediaTrackTranscoder::onOutputSampleAvailable(const std::shared_ptr<MediaSample>& sample) {
+    std::scoped_lock lock{mSampleMutex};
+    if (mSampleConsumer == nullptr) {
+        mSampleQueue.enqueue(sample);
+    } else {
+        mSampleConsumer(sample);
+    }
+}
+
+void MediaTrackTranscoder::setSampleConsumer(
+        const MediaSampleWriter::MediaSampleConsumerFunction& sampleConsumer) {
+    std::scoped_lock lock{mSampleMutex};
+    mSampleConsumer = sampleConsumer;
+
+    std::shared_ptr<MediaSample> sample;
+    while (!mSampleQueue.isEmpty() && !mSampleQueue.dequeue(&sample)) {
+        mSampleConsumer(sample);
+    }
 }
 
 }  // namespace android
diff --git a/media/libmediatranscoding/transcoder/MediaTranscoder.cpp b/media/libmediatranscoding/transcoder/MediaTranscoder.cpp
index fbed5c2..61cc459 100644
--- a/media/libmediatranscoding/transcoder/MediaTranscoder.cpp
+++ b/media/libmediatranscoding/transcoder/MediaTranscoder.cpp
@@ -123,14 +123,16 @@
     }
 
     // Add track to the writer.
-    const bool ok =
-            mSampleWriter->addTrack(transcoder->getOutputQueue(), transcoder->getOutputFormat());
-    if (!ok) {
+    auto consumer = mSampleWriter->addTrack(transcoder->getOutputFormat());
+    if (consumer == nullptr) {
         LOG(ERROR) << "Unable to add track to sample writer.";
         sendCallback(AMEDIA_ERROR_UNKNOWN);
         return;
     }
 
+    MediaTrackTranscoder* mutableTranscoder = const_cast<MediaTrackTranscoder*>(transcoder);
+    mutableTranscoder->setSampleConsumer(consumer);
+
     mTracksAdded.insert(transcoder);
     if (mTracksAdded.size() == mTrackTranscoders.size()) {
         // Enable sequential access mode on the sample reader to achieve optimal read performance.
@@ -304,7 +306,7 @@
         return AMEDIA_ERROR_INVALID_OPERATION;
     }
 
-    mSampleWriter = std::make_unique<MediaSampleWriter>();
+    mSampleWriter = MediaSampleWriter::Create();
     const bool initOk = mSampleWriter->init(fd, shared_from_this());
 
     if (!initOk) {
diff --git a/media/libmediatranscoding/transcoder/PassthroughTrackTranscoder.cpp b/media/libmediatranscoding/transcoder/PassthroughTrackTranscoder.cpp
index e7c0271..35b1d33 100644
--- a/media/libmediatranscoding/transcoder/PassthroughTrackTranscoder.cpp
+++ b/media/libmediatranscoding/transcoder/PassthroughTrackTranscoder.cpp
@@ -138,10 +138,7 @@
         }
 
         sample->info = info;
-        if (mOutputQueue->enqueue(sample)) {
-            LOG(ERROR) << "Output queue aborted";
-            return AMEDIA_ERROR_IO;
-        }
+        onOutputSampleAvailable(sample);
     }
 
     if (mStopRequested && !mEosFromSource) {
diff --git a/media/libmediatranscoding/transcoder/VideoTrackTranscoder.cpp b/media/libmediatranscoding/transcoder/VideoTrackTranscoder.cpp
index b0bf59f..c7d775c 100644
--- a/media/libmediatranscoding/transcoder/VideoTrackTranscoder.cpp
+++ b/media/libmediatranscoding/transcoder/VideoTrackTranscoder.cpp
@@ -375,12 +375,7 @@
         sample->info.flags = bufferInfo.flags;
         sample->info.presentationTimeUs = bufferInfo.presentationTimeUs;
 
-        const bool aborted = mOutputQueue->enqueue(sample);
-        if (aborted) {
-            LOG(ERROR) << "Output sample queue was aborted. Stopping transcode.";
-            mStatus = AMEDIA_ERROR_IO;  // TODO: Define custom error codes?
-            return;
-        }
+        onOutputSampleAvailable(sample);
     } else if (bufferIndex == AMEDIACODEC_INFO_OUTPUT_FORMAT_CHANGED) {
         AMediaFormat* newFormat = AMediaCodec_getOutputFormat(mEncoder->getCodec());
         LOG(DEBUG) << "Encoder output format changed: " << AMediaFormat_toString(newFormat);
diff --git a/media/libmediatranscoding/transcoder/include/media/MediaSampleReaderNDK.h b/media/libmediatranscoding/transcoder/include/media/MediaSampleReaderNDK.h
index 5f9822d..2032def 100644
--- a/media/libmediatranscoding/transcoder/include/media/MediaSampleReaderNDK.h
+++ b/media/libmediatranscoding/transcoder/include/media/MediaSampleReaderNDK.h
@@ -58,7 +58,6 @@
     virtual ~MediaSampleReaderNDK() override;
 
 private:
-
     /**
      * SamplePosition describes the position of a single sample in the media file using its
      * timestamp and index in the file.
diff --git a/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h b/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h
index d4b1fcf..f762556 100644
--- a/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h
+++ b/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h
@@ -17,17 +17,19 @@
 #ifndef ANDROID_MEDIA_SAMPLE_WRITER_H
 #define ANDROID_MEDIA_SAMPLE_WRITER_H
 
-#include <media/MediaSampleQueue.h>
+#include <media/MediaSample.h>
 #include <media/NdkMediaCodec.h>
 #include <media/NdkMediaError.h>
 #include <media/NdkMediaFormat.h>
 #include <utils/Mutex.h>
 
+#include <condition_variable>
 #include <functional>
 #include <memory>
 #include <mutex>
-#include <set>
+#include <queue>
 #include <thread>
+#include <unordered_map>
 
 namespace android {
 
@@ -62,18 +64,16 @@
 };
 
 /**
- * MediaSampleWriter writes samples to a muxer while keeping its input sources synchronized. Each
- * source track have its own MediaSampleQueue from which samples are dequeued by the sample writer
- * and written to the muxer. The sample writer always prioritizes dequeueing samples from the source
- * track that is farthest behind by comparing sample timestamps. If the slowest track does not have
- * any samples pending the writer moves on to the next track but never allows tracks to diverge more
- * than a configurable duration of time. The default muxer interface implementation is based
+ * MediaSampleWriter is a wrapper around a muxer. The sample writer puts samples on a queue that
+ * is serviced by an internal thread to minimize blocking time for clients. MediaSampleWriter also
+ * provides progress reporting. The default muxer interface implementation is based
  * directly on AMediaMuxer.
  */
-class MediaSampleWriter {
+class MediaSampleWriter : public std::enable_shared_from_this<MediaSampleWriter> {
 public:
-    /** The default maximum track divergence in microseconds. */
-    static constexpr uint32_t kDefaultMaxTrackDivergenceUs = 1 * 1000 * 1000;  // 1 second.
+    /** Function prototype for delivering media samples to the writer. */
+    using MediaSampleConsumerFunction =
+            std::function<void(const std::shared_ptr<MediaSample>& sample)>;
 
     /** Callback interface. */
     class CallbackInterface {
@@ -90,18 +90,7 @@
         virtual ~CallbackInterface() = default;
     };
 
-    /**
-     * Constructor with custom maximum track divergence.
-     * @param maxTrackDivergenceUs The maximum track divergence in microseconds.
-     */
-    MediaSampleWriter(uint32_t maxTrackDivergenceUs)
-          : mMaxTrackDivergenceUs(maxTrackDivergenceUs), mMuxer(nullptr), mState(UNINITIALIZED){};
-
-    /** Constructor using the default maximum track divergence. */
-    MediaSampleWriter() : MediaSampleWriter(kDefaultMaxTrackDivergenceUs){};
-
-    /** Destructor. */
-    ~MediaSampleWriter();
+    static std::shared_ptr<MediaSampleWriter> Create();
 
     /**
      * Initializes the sample writer with its default muxer implementation. MediaSampleWriter needs
@@ -125,12 +114,12 @@
     /**
      * Adds a new track to the sample writer. Tracks must be added after the sample writer has been
      * initialized and before it is started.
-     * @param sampleQueue The MediaSampleQueue to pull samples from.
      * @param trackFormat The format of the track to add.
-     * @return True if the track was successfully added.
+     * @return A sample consumer to add samples to if the track was successfully added, or nullptr
+     * if the track could not be added.
      */
-    bool addTrack(const std::shared_ptr<MediaSampleQueue>& sampleQueue /* nonnull */,
-                  const std::shared_ptr<AMediaFormat>& trackFormat /* nonnull */);
+    MediaSampleConsumerFunction addTrack(
+            const std::shared_ptr<AMediaFormat>& trackFormat /* nonnull */);
 
     /**
      * Starts the sample writer. The sample writer will start processing samples and writing them to
@@ -150,51 +139,69 @@
      */
     bool stop();
 
+    /** Destructor. */
+    ~MediaSampleWriter();
+
 private:
     struct TrackRecord {
-        TrackRecord(const std::shared_ptr<MediaSampleQueue>& sampleQueue, size_t trackIndex,
-                    int64_t durationUs)
-              : mSampleQueue(sampleQueue),
-                mTrackIndex(trackIndex),
-                mDurationUs(durationUs),
+        TrackRecord(int64_t durationUs)
+              : mDurationUs(durationUs),
                 mFirstSampleTimeUs(0),
                 mPrevSampleTimeUs(INT64_MIN),
                 mFirstSampleTimeSet(false),
-                mReachedEos(false) {}
+                mReachedEos(false){};
 
-        std::shared_ptr<MediaSampleQueue> mSampleQueue;
-        const size_t mTrackIndex;
+        TrackRecord() : TrackRecord(0){};
+
         int64_t mDurationUs;
         int64_t mFirstSampleTimeUs;
         int64_t mPrevSampleTimeUs;
         bool mFirstSampleTimeSet;
         bool mReachedEos;
-
-        struct compare {
-            bool operator()(const TrackRecord* lhs, const TrackRecord* rhs) const {
-                return lhs->mPrevSampleTimeUs < rhs->mPrevSampleTimeUs;
-            }
-        };
     };
 
-    const uint32_t mMaxTrackDivergenceUs;
+    // Track index and sample.
+    using SampleEntry = std::pair<size_t, std::shared_ptr<MediaSample>>;
+
+    struct SampleComparator {
+        // Return true if lhs should come after rhs in the sample queue.
+        bool operator()(const SampleEntry& lhs, const SampleEntry& rhs) {
+            const bool lhsEos = lhs.second->info.flags & SAMPLE_FLAG_END_OF_STREAM;
+            const bool rhsEos = rhs.second->info.flags & SAMPLE_FLAG_END_OF_STREAM;
+
+            if (lhsEos && !rhsEos) {
+                return true;
+            } else if (!lhsEos && rhsEos) {
+                return false;
+            } else if (lhsEos && rhsEos) {
+                return lhs.first > rhs.first;
+            }
+
+            return lhs.second->info.presentationTimeUs > rhs.second->info.presentationTimeUs;
+        }
+    };
+
     std::weak_ptr<CallbackInterface> mCallbacks;
     std::shared_ptr<MediaSampleWriterMuxerInterface> mMuxer;
-    std::vector<std::unique_ptr<TrackRecord>> mAllTracks;
-    std::multiset<TrackRecord*, TrackRecord::compare> mSortedTracks;
-    std::thread mThread;
 
-    std::mutex mStateMutex;
+    std::mutex mMutex;  // Protects sample queue and state.
+    std::condition_variable mSampleSignal;
+    std::thread mThread;
+    std::unordered_map<size_t, TrackRecord> mTracks;
+    std::priority_queue<SampleEntry, std::vector<SampleEntry>, SampleComparator> mSampleQueue
+            GUARDED_BY(mMutex);
+
     enum : int {
         UNINITIALIZED,
         INITIALIZED,
         STARTED,
         STOPPED,
-    } mState GUARDED_BY(mStateMutex);
+    } mState GUARDED_BY(mMutex);
 
+    MediaSampleWriter() : mState(UNINITIALIZED){};
+    void addSampleToTrack(size_t trackIndex, const std::shared_ptr<MediaSample>& sample);
     media_status_t writeSamples();
     media_status_t runWriterLoop();
-    std::multiset<TrackRecord*>::iterator getNextOutputTrack();
 };
 
 }  // namespace android
diff --git a/media/libmediatranscoding/transcoder/include/media/MediaTrackTranscoder.h b/media/libmediatranscoding/transcoder/include/media/MediaTrackTranscoder.h
index 60a9139..c5e161c 100644
--- a/media/libmediatranscoding/transcoder/include/media/MediaTrackTranscoder.h
+++ b/media/libmediatranscoding/transcoder/include/media/MediaTrackTranscoder.h
@@ -19,6 +19,7 @@
 
 #include <media/MediaSampleQueue.h>
 #include <media/MediaSampleReader.h>
+#include <media/MediaSampleWriter.h>
 #include <media/NdkMediaError.h>
 #include <media/NdkMediaFormat.h>
 #include <utils/Mutex.h>
@@ -75,10 +76,13 @@
     bool stop();
 
     /**
-     * Retrieves the track transcoder's output sample queue.
-     * @return The output sample queue.
+     * Set the sample consumer function. The MediaTrackTranscoder will deliver transcoded samples to
+     * this function. If the MediaTrackTranscoder is started before a consumer is set the transcoder
+     * will buffer a limited number of samples internally before stalling. Once a consumer has been
+     * set the internally buffered samples will be delivered to the consumer.
+     * @param sampleConsumer The sample consumer function.
      */
-    std::shared_ptr<MediaSampleQueue> getOutputQueue() const;
+    void setSampleConsumer(const MediaSampleWriter::MediaSampleConsumerFunction& sampleConsumer);
 
     /**
       * Retrieves the track transcoder's final output format. The output is available after the
@@ -91,12 +95,14 @@
 
 protected:
     MediaTrackTranscoder(const std::weak_ptr<MediaTrackTranscoderCallback>& transcoderCallback)
-          : mOutputQueue(std::make_shared<MediaSampleQueue>()),
-            mTranscoderCallback(transcoderCallback){};
+          : mTranscoderCallback(transcoderCallback){};
 
     // Called by subclasses when the actual track format becomes available.
     void notifyTrackFormatAvailable();
 
+    // Called by subclasses when a transcoded sample is available.
+    void onOutputSampleAvailable(const std::shared_ptr<MediaSample>& sample);
+
     // configureDestinationFormat needs to be implemented by subclasses, and gets called on an
     // external thread before start.
     virtual media_status_t configureDestinationFormat(
@@ -110,12 +116,14 @@
     // be aborted as soon as possible. It should be safe to call abortTranscodeLoop multiple times.
     virtual void abortTranscodeLoop() = 0;
 
-    std::shared_ptr<MediaSampleQueue> mOutputQueue;
     std::shared_ptr<MediaSampleReader> mMediaSampleReader;
     int mTrackIndex;
     std::shared_ptr<AMediaFormat> mSourceFormat;
 
 private:
+    std::mutex mSampleMutex;
+    MediaSampleQueue mSampleQueue GUARDED_BY(mSampleMutex);
+    MediaSampleWriter::MediaSampleConsumerFunction mSampleConsumer GUARDED_BY(mSampleMutex);
     const std::weak_ptr<MediaTrackTranscoderCallback> mTranscoderCallback;
     std::mutex mStateMutex;
     std::thread mTranscodingThread GUARDED_BY(mStateMutex);
diff --git a/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h b/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h
index 8d96867..9a367ca 100644
--- a/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h
+++ b/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h
@@ -138,7 +138,7 @@
 
     std::shared_ptr<CallbackInterface> mCallbacks;
     std::shared_ptr<MediaSampleReader> mSampleReader;
-    std::unique_ptr<MediaSampleWriter> mSampleWriter;
+    std::shared_ptr<MediaSampleWriter> mSampleWriter;
     std::vector<std::shared_ptr<AMediaFormat>> mSourceTrackFormats;
     std::vector<std::shared_ptr<MediaTrackTranscoder>> mTrackTranscoders;
     std::mutex mTracksAddedMutex;
diff --git a/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp b/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp
index e8acd48..9c9c8b5 100644
--- a/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp
@@ -26,6 +26,7 @@
 #include <gtest/gtest.h>
 #include <media/MediaSampleReaderNDK.h>
 #include <utils/Timers.h>
+
 #include <cmath>
 #include <mutex>
 #include <thread>
diff --git a/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp b/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp
index 64240d4..46f3e9b 100644
--- a/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp
@@ -274,102 +274,95 @@
     void SetUp() override {
         LOG(DEBUG) << "MediaSampleWriterTests set up";
         mTestMuxer = std::make_shared<TestMuxer>();
-        mSampleQueue = std::make_shared<MediaSampleQueue>();
     }
 
     void TearDown() override {
         LOG(DEBUG) << "MediaSampleWriterTests tear down";
         mTestMuxer.reset();
-        mSampleQueue.reset();
     }
 
 protected:
     std::shared_ptr<TestMuxer> mTestMuxer;
-    std::shared_ptr<MediaSampleQueue> mSampleQueue;
     std::shared_ptr<TestCallbacks> mTestCallbacks = std::make_shared<TestCallbacks>();
 };
 
 TEST_F(MediaSampleWriterTests, TestAddTrackWithoutInit) {
     const TestMediaSource& mediaSource = getMediaSource();
 
-    MediaSampleWriter writer{};
-    EXPECT_FALSE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_EQ(writer->addTrack(mediaSource.mTrackFormats[0]), nullptr);
 }
 
 TEST_F(MediaSampleWriterTests, TestStartWithoutInit) {
-    MediaSampleWriter writer{};
-    EXPECT_FALSE(writer.start());
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_FALSE(writer->start());
 }
 
 TEST_F(MediaSampleWriterTests, TestStartWithoutTracks) {
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
-    EXPECT_FALSE(writer.start());
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(mTestMuxer, mTestCallbacks));
+    EXPECT_FALSE(writer->start());
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
 }
 
 TEST_F(MediaSampleWriterTests, TestAddInvalidTrack) {
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(mTestMuxer, mTestCallbacks));
 
-    EXPECT_FALSE(writer.addTrack(mSampleQueue, nullptr));
-    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
-
-    const TestMediaSource& mediaSource = getMediaSource();
-    EXPECT_FALSE(writer.addTrack(nullptr, mediaSource.mTrackFormats[0]));
+    EXPECT_EQ(writer->addTrack(nullptr), nullptr);
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
 }
 
 TEST_F(MediaSampleWriterTests, TestDoubleStartStop) {
-    MediaSampleWriter writer{};
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
 
     std::shared_ptr<TestCallbacks> callbacks =
             std::make_shared<TestCallbacks>(false /* expectSuccess */);
-    EXPECT_TRUE(writer.init(mTestMuxer, callbacks));
+    EXPECT_TRUE(writer->init(mTestMuxer, callbacks));
 
     const TestMediaSource& mediaSource = getMediaSource();
-    EXPECT_TRUE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
+    EXPECT_NE(writer->addTrack(mediaSource.mTrackFormats[0]), nullptr);
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(mediaSource.mTrackFormats[0].get()));
 
-    ASSERT_TRUE(writer.start());
-    EXPECT_FALSE(writer.start());
+    ASSERT_TRUE(writer->start());
+    EXPECT_FALSE(writer->start());
 
-    EXPECT_TRUE(writer.stop());
+    EXPECT_TRUE(writer->stop());
     EXPECT_TRUE(callbacks->hasFinished());
-    EXPECT_FALSE(writer.stop());
+    EXPECT_FALSE(writer->stop());
 }
 
 TEST_F(MediaSampleWriterTests, TestStopWithoutStart) {
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(mTestMuxer, mTestCallbacks));
 
     const TestMediaSource& mediaSource = getMediaSource();
-    EXPECT_TRUE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
+    EXPECT_NE(writer->addTrack(mediaSource.mTrackFormats[0]), nullptr);
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(mediaSource.mTrackFormats[0].get()));
 
-    EXPECT_FALSE(writer.stop());
+    EXPECT_FALSE(writer->stop());
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
 }
 
 TEST_F(MediaSampleWriterTests, TestStartWithoutCallback) {
-    MediaSampleWriter writer{};
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
 
     std::weak_ptr<MediaSampleWriter::CallbackInterface> unassignedWp;
-    EXPECT_FALSE(writer.init(mTestMuxer, unassignedWp));
+    EXPECT_FALSE(writer->init(mTestMuxer, unassignedWp));
 
     std::shared_ptr<MediaSampleWriter::CallbackInterface> unassignedSp;
-    EXPECT_FALSE(writer.init(mTestMuxer, unassignedSp));
+    EXPECT_FALSE(writer->init(mTestMuxer, unassignedSp));
 
     const TestMediaSource& mediaSource = getMediaSource();
-    EXPECT_FALSE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
-    ASSERT_FALSE(writer.start());
+    EXPECT_EQ(writer->addTrack(mediaSource.mTrackFormats[0]), nullptr);
+    ASSERT_FALSE(writer->start());
 }
 
 TEST_F(MediaSampleWriterTests, TestProgressUpdate) {
     const TestMediaSource& mediaSource = getMediaSource();
 
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(mTestMuxer, mTestCallbacks));
 
     std::shared_ptr<AMediaFormat> videoFormat =
             std::shared_ptr<AMediaFormat>(AMediaFormat_new(), &AMediaFormat_delete);
@@ -377,42 +370,41 @@
                       mediaSource.mTrackFormats[mediaSource.mVideoTrackIndex].get());
 
     AMediaFormat_setInt64(videoFormat.get(), AMEDIAFORMAT_KEY_DURATION, 100);
-    EXPECT_TRUE(writer.addTrack(mSampleQueue, videoFormat));
-    ASSERT_TRUE(writer.start());
+    auto sampleConsumer = writer->addTrack(videoFormat);
+    EXPECT_NE(sampleConsumer, nullptr);
+    ASSERT_TRUE(writer->start());
 
     for (int64_t pts = 0; pts < 100; ++pts) {
-        mSampleQueue->enqueue(newSampleWithPts(pts));
+        sampleConsumer(newSampleWithPts(pts));
     }
-    mSampleQueue->enqueue(newSampleEos());
+    sampleConsumer(newSampleEos());
     mTestCallbacks->waitForWritingFinished();
 
     EXPECT_EQ(mTestCallbacks->getProgressUpdateCount(), 100);
 }
 
 TEST_F(MediaSampleWriterTests, TestInterleaving) {
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(mTestMuxer, mTestCallbacks));
 
     // Use two tracks for this test.
     static constexpr int kNumTracks = 2;
-    std::shared_ptr<MediaSampleQueue> sampleQueues[kNumTracks];
-    std::vector<std::pair<std::shared_ptr<MediaSample>, size_t>> interleavedSamples;
+    MediaSampleWriter::MediaSampleConsumerFunction sampleConsumers[kNumTracks];
+    std::vector<std::pair<std::shared_ptr<MediaSample>, size_t>> addedSamples;
     const TestMediaSource& mediaSource = getMediaSource();
 
     for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
-        sampleQueues[trackIdx] = std::make_shared<MediaSampleQueue>();
-
         auto trackFormat = mediaSource.mTrackFormats[trackIdx % mediaSource.mTrackCount];
-        EXPECT_TRUE(writer.addTrack(sampleQueues[trackIdx], trackFormat));
+        sampleConsumers[trackIdx] = writer->addTrack(trackFormat);
+        EXPECT_NE(sampleConsumers[trackIdx], nullptr);
         EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(trackFormat.get()));
     }
 
     // Create samples in the expected interleaved order for easy verification.
-    auto addSampleToTrackWithPts = [&interleavedSamples, &sampleQueues](int trackIndex,
-                                                                        int64_t pts) {
+    auto addSampleToTrackWithPts = [&addedSamples, &sampleConsumers](int trackIndex, int64_t pts) {
         auto sample = newSampleWithPts(pts);
-        sampleQueues[trackIndex]->enqueue(sample);
-        interleavedSamples.emplace_back(sample, trackIndex);
+        sampleConsumers[trackIndex](sample);
+        addedSamples.emplace_back(sample, trackIndex);
     };
 
     addSampleToTrackWithPts(0, 0);
@@ -431,18 +423,24 @@
     addSampleToTrackWithPts(1, 13);
 
     for (int trackIndex = 0; trackIndex < kNumTracks; ++trackIndex) {
-        sampleQueues[trackIndex]->enqueue(newSampleEos());
+        sampleConsumers[trackIndex](newSampleEos());
     }
 
     // Start the writer.
-    ASSERT_TRUE(writer.start());
+    ASSERT_TRUE(writer->start());
 
     // Wait for writer to complete.
     mTestCallbacks->waitForWritingFinished();
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Start());
 
+    std::sort(addedSamples.begin(), addedSamples.end(),
+              [](const std::pair<std::shared_ptr<MediaSample>, size_t>& left,
+                 const std::pair<std::shared_ptr<MediaSample>, size_t>& right) {
+                  return left.first->info.presentationTimeUs < right.first->info.presentationTimeUs;
+              });
+
     // Verify sample order.
-    for (auto entry : interleavedSamples) {
+    for (auto entry : addedSamples) {
         auto sample = entry.first;
         auto trackIndex = entry.second;
 
@@ -470,162 +468,10 @@
     }
 
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
-    EXPECT_TRUE(writer.stop());
+    EXPECT_TRUE(writer->stop());
     EXPECT_TRUE(mTestCallbacks->hasFinished());
 }
 
-TEST_F(MediaSampleWriterTests, TestMaxDivergence) {
-    static constexpr uint32_t kMaxDivergenceUs = 10;
-
-    MediaSampleWriter writer{kMaxDivergenceUs};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
-
-    // Use two tracks for this test.
-    static constexpr int kNumTracks = 2;
-    std::shared_ptr<MediaSampleQueue> sampleQueues[kNumTracks];
-    const TestMediaSource& mediaSource = getMediaSource();
-
-    for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
-        sampleQueues[trackIdx] = std::make_shared<MediaSampleQueue>();
-
-        auto trackFormat = mediaSource.mTrackFormats[trackIdx % mediaSource.mTrackCount];
-        EXPECT_TRUE(writer.addTrack(sampleQueues[trackIdx], trackFormat));
-        EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(trackFormat.get()));
-    }
-
-    ASSERT_TRUE(writer.start());
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::Start());
-
-    // The first samples of each track can be written in any order since the writer does not have
-    // any previous timestamps to compare.
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(0));
-    sampleQueues[1]->enqueue(newSampleWithPtsOnly(1));
-    mTestMuxer->popEvent(true);
-    mTestMuxer->popEvent(true);
-
-    // The writer will now be waiting on track 0 since it has the lowest previous timestamp.
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(kMaxDivergenceUs + 1));
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(kMaxDivergenceUs + 2));
-
-    // The writer should dequeue the first sample above but not the second since track 0 now is too
-    // far ahead. Instead it should wait for track 1.
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(0, kMaxDivergenceUs + 1));
-
-    // Enqueue a sample from track 1 that puts it within acceptable divergence range again. The
-    // writer should dequeue that sample and then go back to track 0 since track 1 is empty.
-    sampleQueues[1]->enqueue(newSampleWithPtsOnly(kMaxDivergenceUs));
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(1, kMaxDivergenceUs));
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(0, kMaxDivergenceUs + 2));
-
-    // Both tracks are now empty so the writer should wait for track 1 which is farthest behind.
-    sampleQueues[1]->enqueue(newSampleWithPtsOnly(kMaxDivergenceUs + 3));
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(1, kMaxDivergenceUs + 3));
-
-    for (int trackIndex = 0; trackIndex < kNumTracks; ++trackIndex) {
-        sampleQueues[trackIndex]->enqueue(newSampleEos());
-    }
-
-    // Wait for writer to complete.
-    mTestCallbacks->waitForWritingFinished();
-
-    // Verify EOS samples.
-    for (int trackIndex = 0; trackIndex < kNumTracks; ++trackIndex) {
-        auto trackFormat = mediaSource.mTrackFormats[trackIndex % mediaSource.mTrackCount];
-        int64_t duration = 0;
-        AMediaFormat_getInt64(trackFormat.get(), AMEDIAFORMAT_KEY_DURATION, &duration);
-
-        // EOS timestamp = first sample timestamp + duration.
-        const int64_t endTime = duration + (trackIndex == 1 ? 1 : 0);
-        const AMediaCodecBufferInfo info = {0, 0, endTime, AMEDIACODEC_BUFFER_FLAG_END_OF_STREAM};
-        EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::WriteSample(trackIndex, nullptr, &info));
-    }
-
-    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
-    EXPECT_TRUE(writer.stop());
-    EXPECT_TRUE(mTestCallbacks->hasFinished());
-}
-
-TEST_F(MediaSampleWriterTests, TestTimestampDivergenceOverflow) {
-    auto testCallbacks = std::make_shared<TestCallbacks>(false /* expectSuccess */);
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, testCallbacks));
-
-    // Use two tracks for this test.
-    static constexpr int kNumTracks = 2;
-    std::shared_ptr<MediaSampleQueue> sampleQueues[kNumTracks];
-    const TestMediaSource& mediaSource = getMediaSource();
-
-    for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
-        sampleQueues[trackIdx] = std::make_shared<MediaSampleQueue>();
-
-        auto trackFormat = mediaSource.mTrackFormats[trackIdx % mediaSource.mTrackCount];
-        EXPECT_TRUE(writer.addTrack(sampleQueues[trackIdx], trackFormat));
-        EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(trackFormat.get()));
-    }
-
-    // Prime track 0 with lower end of INT64 range, and track 1 with positive timestamps making the
-    // difference larger than INT64_MAX.
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(INT64_MIN + 1));
-    sampleQueues[1]->enqueue(newSampleWithPtsOnly(1000));
-    sampleQueues[1]->enqueue(newSampleWithPtsOnly(1001));
-
-    ASSERT_TRUE(writer.start());
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::Start());
-
-    // The first sample of each track can be pulled in any order.
-    mTestMuxer->popEvent(true);
-    mTestMuxer->popEvent(true);
-
-    // Wait to make sure the writer compares track 0 empty against track 1 non-empty. The writer
-    // should handle the large timestamp differences and chose to wait for track 0 even though
-    // track 1 has a sample ready.
-    std::this_thread::sleep_for(std::chrono::milliseconds(20));
-
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(INT64_MIN + 2));
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(1000));  // <-- Close the gap between the tracks.
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(0, INT64_MIN + 2));
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(0, 1000));
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(1, 1001));
-
-    EXPECT_TRUE(writer.stop());
-    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
-    EXPECT_TRUE(testCallbacks->hasFinished());
-}
-
-TEST_F(MediaSampleWriterTests, TestAbortInputQueue) {
-    MediaSampleWriter writer{};
-    std::shared_ptr<TestCallbacks> callbacks =
-            std::make_shared<TestCallbacks>(false /* expectSuccess */);
-    EXPECT_TRUE(writer.init(mTestMuxer, callbacks));
-
-    // Use two tracks for this test.
-    static constexpr int kNumTracks = 2;
-    std::shared_ptr<MediaSampleQueue> sampleQueues[kNumTracks];
-    const TestMediaSource& mediaSource = getMediaSource();
-
-    for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
-        sampleQueues[trackIdx] = std::make_shared<MediaSampleQueue>();
-
-        auto trackFormat = mediaSource.mTrackFormats[trackIdx % mediaSource.mTrackCount];
-        EXPECT_TRUE(writer.addTrack(sampleQueues[trackIdx], trackFormat));
-        EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(trackFormat.get()));
-    }
-
-    // Start the writer.
-    ASSERT_TRUE(writer.start());
-
-    // Abort the input queues and wait for the writer to complete.
-    for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
-        sampleQueues[trackIdx]->abort();
-    }
-
-    callbacks->waitForWritingFinished();
-
-    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Start());
-    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
-    EXPECT_TRUE(writer.stop());
-}
-
 // Convenience function for reading a sample from an AMediaExtractor represented as a MediaSample.
 static std::shared_ptr<MediaSample> readSampleAndAdvance(AMediaExtractor* extractor,
                                                          size_t* trackIndexOut) {
@@ -667,36 +513,35 @@
     ASSERT_GT(destinationFd, 0);
 
     // Initialize writer.
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(destinationFd, mTestCallbacks));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(destinationFd, mTestCallbacks));
     close(destinationFd);
 
     // Add tracks.
     const TestMediaSource& mediaSource = getMediaSource();
-    std::vector<std::shared_ptr<MediaSampleQueue>> inputQueues;
+    std::vector<MediaSampleWriter::MediaSampleConsumerFunction> sampleConsumers;
 
     for (size_t trackIndex = 0; trackIndex < mediaSource.mTrackCount; trackIndex++) {
-        inputQueues.push_back(std::make_shared<MediaSampleQueue>());
-        EXPECT_TRUE(
-                writer.addTrack(inputQueues[trackIndex], mediaSource.mTrackFormats[trackIndex]));
+        auto consumer = writer->addTrack(mediaSource.mTrackFormats[trackIndex]);
+        sampleConsumers.push_back(consumer);
     }
 
     // Start the writer.
-    ASSERT_TRUE(writer.start());
+    ASSERT_TRUE(writer->start());
 
     // Enqueue samples and finally End Of Stream.
     std::shared_ptr<MediaSample> sample;
     size_t trackIndex;
     while ((sample = readSampleAndAdvance(mediaSource.mExtractor, &trackIndex)) != nullptr) {
-        inputQueues[trackIndex]->enqueue(sample);
+        sampleConsumers[trackIndex](sample);
     }
     for (trackIndex = 0; trackIndex < mediaSource.mTrackCount; trackIndex++) {
-        inputQueues[trackIndex]->enqueue(newSampleEos());
+        sampleConsumers[trackIndex](newSampleEos());
     }
 
     // Wait for writer.
     mTestCallbacks->waitForWritingFinished();
-    EXPECT_TRUE(writer.stop());
+    EXPECT_TRUE(writer->stop());
 
     // Compare output file with source.
     mediaSource.reset();
diff --git a/media/libmediatranscoding/transcoder/tests/MediaTrackTranscoderTests.cpp b/media/libmediatranscoding/transcoder/tests/MediaTrackTranscoderTests.cpp
index a46c2bd..83f0a4a 100644
--- a/media/libmediatranscoding/transcoder/tests/MediaTrackTranscoderTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/MediaTrackTranscoderTests.cpp
@@ -60,7 +60,6 @@
             break;
         }
         ASSERT_NE(mTranscoder, nullptr);
-        mTranscoderOutputQueue = mTranscoder->getOutputQueue();
 
         initSampleReader();
     }
@@ -115,34 +114,29 @@
     }
 
     // Drains the transcoder's output queue in a loop.
-    void drainOutputSampleQueue() {
-        mSampleQueueDrainThread = std::thread{[this] {
-            std::shared_ptr<MediaSample> sample;
-            bool aborted = false;
-            do {
-                aborted = mTranscoderOutputQueue->dequeue(&sample);
-            } while (!aborted && !(sample->info.flags & SAMPLE_FLAG_END_OF_STREAM));
-            mQueueWasAborted = aborted;
-            mGotEndOfStream =
-                    sample != nullptr && (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) != 0;
-        }};
+    void drainOutputSamples(int numSamplesToSave = 0) {
+        mTranscoder->setSampleConsumer(
+                [this, numSamplesToSave](const std::shared_ptr<MediaSample>& sample) {
+                    ASSERT_NE(sample, nullptr);
+
+                    mGotEndOfStream = (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) != 0;
+
+                    if (mSavedSamples.size() < numSamplesToSave) {
+                        mSavedSamples.push_back(sample);
+                    }
+
+                    if (mSavedSamples.size() == numSamplesToSave || mGotEndOfStream) {
+                        mSamplesSavedSemaphore.signal();
+                    }
+                });
     }
 
-    void joinDrainThread() {
-        if (mSampleQueueDrainThread.joinable()) {
-            mSampleQueueDrainThread.join();
-        }
-    }
-    void TearDown() override {
-        LOG(DEBUG) << "MediaTrackTranscoderTests tear down";
-        joinDrainThread();
-    }
+    void TearDown() override { LOG(DEBUG) << "MediaTrackTranscoderTests tear down"; }
 
     ~MediaTrackTranscoderTests() { LOG(DEBUG) << "MediaTrackTranscoderTests destroyed"; }
 
 protected:
     std::shared_ptr<MediaTrackTranscoder> mTranscoder;
-    std::shared_ptr<MediaSampleQueue> mTranscoderOutputQueue;
     std::shared_ptr<TestCallback> mCallback;
 
     std::shared_ptr<MediaSampleReader> mMediaSampleReader;
@@ -151,8 +145,8 @@
     std::shared_ptr<AMediaFormat> mSourceFormat;
     std::shared_ptr<AMediaFormat> mDestinationFormat;
 
-    std::thread mSampleQueueDrainThread;
-    bool mQueueWasAborted = false;
+    std::vector<std::shared_ptr<MediaSample>> mSavedSamples;
+    OneShotSemaphore mSamplesSavedSemaphore;
     bool mGotEndOfStream = false;
 };
 
@@ -161,11 +155,9 @@
     EXPECT_EQ(mTranscoder->configure(mMediaSampleReader, mTrackIndex, mDestinationFormat),
               AMEDIA_OK);
     ASSERT_TRUE(mTranscoder->start());
-    drainOutputSampleQueue();
+    drainOutputSamples();
     EXPECT_EQ(mCallback->waitUntilFinished(), AMEDIA_OK);
-    joinDrainThread();
     EXPECT_TRUE(mTranscoder->stop());
-    EXPECT_FALSE(mQueueWasAborted);
     EXPECT_TRUE(mGotEndOfStream);
 }
 
@@ -229,49 +221,27 @@
     EXPECT_EQ(mTranscoder->configure(mMediaSampleReader, mTrackIndex, mDestinationFormat),
               AMEDIA_OK);
     ASSERT_TRUE(mTranscoder->start());
-    drainOutputSampleQueue();
+    drainOutputSamples();
     EXPECT_EQ(mCallback->waitUntilFinished(), AMEDIA_OK);
-    joinDrainThread();
     EXPECT_TRUE(mTranscoder->stop());
     EXPECT_FALSE(mTranscoder->start());
-    EXPECT_FALSE(mQueueWasAborted);
     EXPECT_TRUE(mGotEndOfStream);
 }
 
-TEST_P(MediaTrackTranscoderTests, AbortOutputQueue) {
-    LOG(DEBUG) << "Testing AbortOutputQueue";
-    EXPECT_EQ(mTranscoder->configure(mMediaSampleReader, mTrackIndex, mDestinationFormat),
-              AMEDIA_OK);
-    ASSERT_TRUE(mTranscoder->start());
-    mTranscoderOutputQueue->abort();
-    drainOutputSampleQueue();
-    EXPECT_EQ(mCallback->waitUntilFinished(), AMEDIA_ERROR_IO);
-    joinDrainThread();
-    EXPECT_TRUE(mTranscoder->stop());
-    EXPECT_TRUE(mQueueWasAborted);
-    EXPECT_FALSE(mGotEndOfStream);
-}
-
 TEST_P(MediaTrackTranscoderTests, HoldSampleAfterTranscoderRelease) {
     LOG(DEBUG) << "Testing HoldSampleAfterTranscoderRelease";
     EXPECT_EQ(mTranscoder->configure(mMediaSampleReader, mTrackIndex, mDestinationFormat),
               AMEDIA_OK);
     ASSERT_TRUE(mTranscoder->start());
-
-    std::shared_ptr<MediaSample> sample;
-    EXPECT_FALSE(mTranscoderOutputQueue->dequeue(&sample));
-
-    drainOutputSampleQueue();
+    drainOutputSamples(1 /* numSamplesToSave */);
     EXPECT_EQ(mCallback->waitUntilFinished(), AMEDIA_OK);
-    joinDrainThread();
     EXPECT_TRUE(mTranscoder->stop());
-    EXPECT_FALSE(mQueueWasAborted);
     EXPECT_TRUE(mGotEndOfStream);
 
     mTranscoder.reset();
-    mTranscoderOutputQueue.reset();
+
     std::this_thread::sleep_for(std::chrono::milliseconds(20));
-    sample.reset();
+    mSavedSamples.clear();
 }
 
 TEST_P(MediaTrackTranscoderTests, HoldSampleAfterTranscoderStop) {
@@ -279,13 +249,12 @@
     EXPECT_EQ(mTranscoder->configure(mMediaSampleReader, mTrackIndex, mDestinationFormat),
               AMEDIA_OK);
     ASSERT_TRUE(mTranscoder->start());
-
-    std::shared_ptr<MediaSample> sample;
-    EXPECT_FALSE(mTranscoderOutputQueue->dequeue(&sample));
+    drainOutputSamples(1 /* numSamplesToSave */);
+    mSamplesSavedSemaphore.wait();
     EXPECT_TRUE(mTranscoder->stop());
 
     std::this_thread::sleep_for(std::chrono::milliseconds(20));
-    sample.reset();
+    mSavedSamples.clear();
 }
 
 TEST_P(MediaTrackTranscoderTests, NullSampleReader) {
diff --git a/media/libmediatranscoding/transcoder/tests/PassthroughTrackTranscoderTests.cpp b/media/libmediatranscoding/transcoder/tests/PassthroughTrackTranscoderTests.cpp
index a2ffbe4..9713e17 100644
--- a/media/libmediatranscoding/transcoder/tests/PassthroughTrackTranscoderTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/PassthroughTrackTranscoderTests.cpp
@@ -165,21 +165,23 @@
     ASSERT_TRUE(transcoder.start());
 
     // Pull transcoder's output samples and compare against input checksums.
+    bool eos = false;
     uint64_t sampleCount = 0;
-    std::shared_ptr<MediaSample> sample;
-    std::shared_ptr<MediaSampleQueue> outputQueue = transcoder.getOutputQueue();
-    while (!outputQueue->dequeue(&sample)) {
-        ASSERT_NE(sample, nullptr);
+    transcoder.setSampleConsumer(
+            [&sampleCount, &sampleChecksums, &eos](const std::shared_ptr<MediaSample>& sample) {
+                ASSERT_NE(sample, nullptr);
+                EXPECT_FALSE(eos);
 
-        if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
-            break;
-        }
+                if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
+                    eos = true;
+                } else {
+                    SampleID sampleId{sample->buffer, static_cast<ssize_t>(sample->info.size)};
+                    EXPECT_TRUE(sampleId == sampleChecksums[sampleCount]);
+                    ++sampleCount;
+                }
+            });
 
-        SampleID sampleId{sample->buffer, static_cast<ssize_t>(sample->info.size)};
-        EXPECT_TRUE(sampleId == sampleChecksums[sampleCount]);
-        ++sampleCount;
-    }
-
+    callback->waitUntilFinished();
     EXPECT_EQ(sampleCount, sampleChecksums.size());
     EXPECT_TRUE(transcoder.stop());
 }
diff --git a/media/libmediatranscoding/transcoder/tests/TrackTranscoderTestUtils.h b/media/libmediatranscoding/transcoder/tests/TrackTranscoderTestUtils.h
index a3ddd71..8d05353 100644
--- a/media/libmediatranscoding/transcoder/tests/TrackTranscoderTestUtils.h
+++ b/media/libmediatranscoding/transcoder/tests/TrackTranscoderTestUtils.h
@@ -102,4 +102,25 @@
     bool mTrackFormatAvailable = false;
 };
 
+class OneShotSemaphore {
+public:
+    void wait() {
+        std::unique_lock<std::mutex> lock(mMutex);
+        while (!mSignaled) {
+            mCondition.wait(lock);
+        }
+    }
+
+    void signal() {
+        std::unique_lock<std::mutex> lock(mMutex);
+        mSignaled = true;
+        mCondition.notify_all();
+    }
+
+private:
+    std::mutex mMutex;
+    std::condition_variable mCondition;
+    bool mSignaled = false;
+};
+
 };  // namespace android
diff --git a/media/libmediatranscoding/transcoder/tests/VideoTrackTranscoderTests.cpp b/media/libmediatranscoding/transcoder/tests/VideoTrackTranscoderTests.cpp
index e809cbd..1b5bd13 100644
--- a/media/libmediatranscoding/transcoder/tests/VideoTrackTranscoderTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/VideoTrackTranscoderTests.cpp
@@ -102,46 +102,40 @@
               AMEDIA_OK);
     ASSERT_TRUE(transcoder->start());
 
-    std::shared_ptr<MediaSampleQueue> outputQueue = transcoder->getOutputQueue();
-    std::thread sampleConsumerThread{[&outputQueue] {
-        uint64_t sampleCount = 0;
-        std::shared_ptr<MediaSample> sample;
-        while (!outputQueue->dequeue(&sample)) {
-            ASSERT_NE(sample, nullptr);
-            const uint32_t flags = sample->info.flags;
+    bool eos = false;
+    uint64_t sampleCount = 0;
+    transcoder->setSampleConsumer([&sampleCount, &eos](const std::shared_ptr<MediaSample>& sample) {
+        ASSERT_NE(sample, nullptr);
+        const uint32_t flags = sample->info.flags;
 
-            if (sampleCount == 0) {
-                // Expect first sample to be a codec config.
-                EXPECT_TRUE((flags & SAMPLE_FLAG_CODEC_CONFIG) != 0);
-                EXPECT_TRUE((flags & SAMPLE_FLAG_SYNC_SAMPLE) == 0);
-                EXPECT_TRUE((flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
-                EXPECT_TRUE((flags & SAMPLE_FLAG_PARTIAL_FRAME) == 0);
-            } else if (sampleCount == 1) {
-                // Expect second sample to be a sync sample.
-                EXPECT_TRUE((flags & SAMPLE_FLAG_CODEC_CONFIG) == 0);
-                EXPECT_TRUE((flags & SAMPLE_FLAG_SYNC_SAMPLE) != 0);
-                EXPECT_TRUE((flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
-            }
-
-            if (!(flags & SAMPLE_FLAG_END_OF_STREAM)) {
-                // Expect a valid buffer unless it is EOS.
-                EXPECT_NE(sample->buffer, nullptr);
-                EXPECT_NE(sample->bufferId, 0xBAADF00D);
-                EXPECT_GT(sample->info.size, 0);
-            }
-
-            ++sampleCount;
-            if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
-                break;
-            }
-            sample.reset();
+        if (sampleCount == 0) {
+            // Expect first sample to be a codec config.
+            EXPECT_TRUE((flags & SAMPLE_FLAG_CODEC_CONFIG) != 0);
+            EXPECT_TRUE((flags & SAMPLE_FLAG_SYNC_SAMPLE) == 0);
+            EXPECT_TRUE((flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
+            EXPECT_TRUE((flags & SAMPLE_FLAG_PARTIAL_FRAME) == 0);
+        } else if (sampleCount == 1) {
+            // Expect second sample to be a sync sample.
+            EXPECT_TRUE((flags & SAMPLE_FLAG_CODEC_CONFIG) == 0);
+            EXPECT_TRUE((flags & SAMPLE_FLAG_SYNC_SAMPLE) != 0);
+            EXPECT_TRUE((flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
         }
-    }};
+
+        if (!(flags & SAMPLE_FLAG_END_OF_STREAM)) {
+            // Expect a valid buffer unless it is EOS.
+            EXPECT_NE(sample->buffer, nullptr);
+            EXPECT_NE(sample->bufferId, 0xBAADF00D);
+            EXPECT_GT(sample->info.size, 0);
+        } else {
+            EXPECT_FALSE(eos);
+            eos = true;
+        }
+
+        ++sampleCount;
+    });
 
     EXPECT_EQ(callback->waitUntilFinished(), AMEDIA_OK);
     EXPECT_TRUE(transcoder->stop());
-
-    sampleConsumerThread.join();
 }
 
 TEST_F(VideoTrackTranscoderTests, PreserveBitrate) {
@@ -167,7 +161,6 @@
     ASSERT_NE(outputFormat, nullptr);
 
     ASSERT_TRUE(transcoder->stop());
-    transcoder->getOutputQueue()->abort();
 
     int32_t outBitrate;
     EXPECT_TRUE(AMediaFormat_getInt32(outputFormat.get(), AMEDIAFORMAT_KEY_BIT_RATE, &outBitrate));
@@ -187,25 +180,7 @@
 }
 
 TEST_F(VideoTrackTranscoderTests, LingeringEncoder) {
-    struct {
-        void wait() {
-            std::unique_lock<std::mutex> lock(mMutex);
-            while (!mSignaled) {
-                mCondition.wait(lock);
-            }
-        }
-
-        void signal() {
-            std::unique_lock<std::mutex> lock(mMutex);
-            mSignaled = true;
-            mCondition.notify_all();
-        }
-
-        std::mutex mMutex;
-        std::condition_variable mCondition;
-        bool mSignaled = false;
-    } semaphore;
-
+    OneShotSemaphore semaphore;
     auto callback = std::make_shared<TestCallback>();
     auto transcoder = VideoTrackTranscoder::create(callback);
 
@@ -214,29 +189,24 @@
               AMEDIA_OK);
     ASSERT_TRUE(transcoder->start());
 
-    std::shared_ptr<MediaSampleQueue> outputQueue = transcoder->getOutputQueue();
     std::vector<std::shared_ptr<MediaSample>> samples;
-    std::thread sampleConsumerThread([&outputQueue, &samples, &semaphore] {
-        std::shared_ptr<MediaSample> sample;
-        while (samples.size() < 4 && !outputQueue->dequeue(&sample)) {
-            ASSERT_NE(sample, nullptr);
-            samples.push_back(sample);
+    transcoder->setSampleConsumer(
+            [&samples, &semaphore](const std::shared_ptr<MediaSample>& sample) {
+                if (samples.size() >= 4) return;
 
-            if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
-                break;
-            }
-            sample.reset();
-        }
+                ASSERT_NE(sample, nullptr);
+                samples.push_back(sample);
 
-        semaphore.signal();
-    });
+                if (samples.size() == 4 || sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
+                    semaphore.signal();
+                }
+            });
 
     // Wait for the encoder to output samples before stopping and releasing the transcoder.
     semaphore.wait();
 
     EXPECT_TRUE(transcoder->stop());
     transcoder.reset();
-    sampleConsumerThread.join();
 
     // Return buffers to the codec so that it can resume processing, but keep one buffer to avoid
     // the codec being released.