Transcoder: Add support for progress updates.

Progress updates are delivered as long as at least one track
in the file has a valid (positive) duration. Progress is computed
from the track with the longest duration.

Fixes: 160277443
Test: MediaSampleWriter and MediaTranscoder unit tests.
Change-Id: I52bbf55cfd2445b98dfc4d9c9ae09bcf7de86a86
diff --git a/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp b/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
index 91dbf78..3676d73 100644
--- a/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
+++ b/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
@@ -78,14 +78,14 @@
     }
 }
 
-bool MediaSampleWriter::init(int fd, const OnWritingFinishedCallback& callback) {
-    return init(DefaultMuxer::create(fd), callback);
+bool MediaSampleWriter::init(int fd, const std::weak_ptr<CallbackInterface>& callbacks) {
+    return init(DefaultMuxer::create(fd), callbacks);
 }
 
 bool MediaSampleWriter::init(const std::shared_ptr<MediaSampleWriterMuxerInterface>& muxer,
-                             const OnWritingFinishedCallback& callback) {
-    if (callback == nullptr) {
-        LOG(ERROR) << "Callback cannot be null";
+                             const std::weak_ptr<CallbackInterface>& callbacks) {
+    if (callbacks.lock() == nullptr) {
+        LOG(ERROR) << "Callback object cannot be null";
         return false;
     } else if (muxer == nullptr) {
         LOG(ERROR) << "Muxer cannot be null";
@@ -100,7 +100,7 @@
 
     mState = INITIALIZED;
     mMuxer = muxer;
-    mWritingFinishedCallback = callback;
+    mCallbacks = callbacks;
     return true;
 }
 
@@ -127,7 +127,11 @@
         durationUs = 0;
     }
 
-    mTracks.emplace_back(sampleQueue, static_cast<size_t>(trackIndex), durationUs);
+    const char* mime = nullptr;
+    const bool isVideo = AMediaFormat_getString(trackFormat.get(), AMEDIAFORMAT_KEY_MIME, &mime) &&
+                         (strncmp(mime, "video/", 6) == 0);
+
+    mTracks.emplace_back(sampleQueue, static_cast<size_t>(trackIndex), durationUs, isVideo);
     return true;
 }
 
@@ -144,7 +148,9 @@
 
     mThread = std::thread([this] {
         media_status_t status = writeSamples();
-        mWritingFinishedCallback(status);
+        if (auto callbacks = mCallbacks.lock()) {
+            callbacks->onFinished(this, status);
+        }
     });
     mState = STARTED;
     return true;
@@ -191,6 +197,18 @@
     AMediaCodecBufferInfo bufferInfo;
     uint32_t segmentEndTimeUs = mTrackSegmentLengthUs;
     bool samplesLeft = true;
+    int32_t lastProgressUpdate = 0;
+
+    // Progress is derived from the "primary" track: the one with the longest positive duration.
+    // If no track has a positive duration, the index stays -1 and progress is not reported.
+    int primaryTrackIndex = -1;
+    int64_t longestDurationUs = 0;
+    for (size_t trackIndex = 0; trackIndex < mTracks.size(); ++trackIndex) {
+        if (mTracks[trackIndex].mDurationUs > longestDurationUs) {
+            primaryTrackIndex = static_cast<int>(trackIndex);
+            longestDurationUs = mTracks[trackIndex].mDurationUs;
+        }
+    }
 
     while (samplesLeft) {
         samplesLeft = false;
@@ -216,9 +234,10 @@
                     samplesLeft = true;
                 }
 
-                // Record the first sample's timestamp in order to translate duration to EOS time
-                // for tracks that does not start at 0.
+                track.mPrevSampleTimeUs = sample->info.presentationTimeUs;
                 if (!track.mFirstSampleTimeSet) {
+                    // Record the first sample's timestamp in order to translate duration to EOS
+                    // time for tracks that do not start at 0.
                     track.mFirstSampleTimeUs = sample->info.presentationTimeUs;
                     track.mFirstSampleTimeSet = true;
                 }
@@ -238,6 +257,22 @@
             } while (sample->info.presentationTimeUs < segmentEndTimeUs && !track.mReachedEos);
         }
 
+        // TODO(lnilsson): Add option to toggle progress reporting on/off.
+        if (primaryTrackIndex >= 0) {
+            const TrackRecord& track = mTracks[primaryTrackIndex];
+            // Clamp in 64 bits before narrowing so out-of-range timestamps cannot wrap around.
+            const int64_t elapsed = track.mPrevSampleTimeUs - track.mFirstSampleTimeUs;
+            const int32_t progress = static_cast<int32_t>(
+                    std::clamp<int64_t>((elapsed * 100) / track.mDurationUs, 0, 100));
+
+            if (progress > lastProgressUpdate) {
+                if (auto callbacks = mCallbacks.lock()) {
+                    callbacks->onProgressUpdate(this, progress);
+                }
+                lastProgressUpdate = progress;
+            }
+        }
+
         segmentEndTimeUs += mTrackSegmentLengthUs;
     }
 
diff --git a/media/libmediatranscoding/transcoder/MediaTranscoder.cpp b/media/libmediatranscoding/transcoder/MediaTranscoder.cpp
index bde1cf6..49bfdfe 100644
--- a/media/libmediatranscoding/transcoder/MediaTranscoder.cpp
+++ b/media/libmediatranscoding/transcoder/MediaTranscoder.cpp
@@ -151,11 +151,16 @@
     sendCallback(status);
 }
 
-void MediaTranscoder::onSampleWriterFinished(media_status_t status) {
+void MediaTranscoder::onFinished(const MediaSampleWriter* writer __unused, media_status_t status) {
     LOG((status != AMEDIA_OK) ? ERROR : DEBUG) << "Sample writer finished with status " << status;
     sendCallback(status);
 }
 
+void MediaTranscoder::onProgressUpdate(const MediaSampleWriter* writer __unused, int32_t progress) {
+    // Forward the sample writer's progress update (in percent) to the transcoder client.
+    mCallbacks->onProgressUpdate(this, progress);
+}
+
 MediaTranscoder::MediaTranscoder(const std::shared_ptr<CallbackInterface>& callbacks)
       : mCallbacks(callbacks) {}
 
@@ -288,8 +293,7 @@
     }
 
     mSampleWriter = std::make_unique<MediaSampleWriter>();
-    const bool initOk = mSampleWriter->init(
-            fd, std::bind(&MediaTranscoder::onSampleWriterFinished, this, std::placeholders::_1));
+    const bool initOk = mSampleWriter->init(fd, shared_from_this());
 
     if (!initOk) {
         LOG(ERROR) << "Unable to initialize sample writer with destination fd: " << fd;
diff --git a/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h b/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h
index d971f3e..92ddc2f 100644
--- a/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h
+++ b/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h
@@ -71,18 +71,27 @@
     /** The default segment length. */
     static constexpr uint32_t kDefaultTrackSegmentLengthUs = 1 * 1000 * 1000;  // 1 sec.
 
-    /** Client callback for when the writer is finished. */
-    using OnWritingFinishedCallback = std::function<void(media_status_t)>;
+    /** Callback interface. */
+    class CallbackInterface {
+    public:
+        /**
+         * Sample writer finished. The finished callback is only called after the sample writer has
+         * been successfully started.
+         */
+        virtual void onFinished(const MediaSampleWriter* writer, media_status_t status) = 0;
+
+        /** Sample writer progress update in percent. */
+        virtual void onProgressUpdate(const MediaSampleWriter* writer, int32_t progress) = 0;
+
+        virtual ~CallbackInterface() = default;
+    };
 
     /**
      * Constructor with custom segment length.
      * @param trackSegmentLengthUs The segment length to use for this MediaSampleWriter.
      */
     MediaSampleWriter(uint32_t trackSegmentLengthUs)
-          : mTrackSegmentLengthUs(trackSegmentLengthUs),
-            mWritingFinishedCallback(nullptr),
-            mMuxer(nullptr),
-            mState(UNINITIALIZED){};
+          : mTrackSegmentLengthUs(trackSegmentLengthUs), mMuxer(nullptr), mState(UNINITIALIZED){};
 
     /** Constructor using the default segment length. */
     MediaSampleWriter() : MediaSampleWriter(kDefaultTrackSegmentLengthUs){};
@@ -95,21 +104,19 @@
      * to be initialized before tracks are added and can only be initialized once.
      * @param fd An open file descriptor to write to. The caller is responsible for closing this
      *        file descriptor and it is safe to do so once this method returns.
-     * @param callback Client callback that gets called when the sample writer has finished, after
-     *        it was successfully started.
+     * @param callbacks Client callback object that gets called by the sample writer.
      * @return True if the writer was successfully initialized.
      */
-    bool init(int fd, const OnWritingFinishedCallback& callback /* nonnull */);
+    bool init(int fd, const std::weak_ptr<CallbackInterface>& callbacks /* nonnull */);
 
     /**
      * Initializes the sample writer with a custom muxer interface implementation.
      * @param muxer The custom muxer interface implementation.
-     * @param callback Client callback that gets called when the sample writer has finished, after
-     *        it was successfully started.
+     * @param callbacks Client callback object that gets called by the sample writer.
      * @return True if the writer was successfully initialized.
      */
     bool init(const std::shared_ptr<MediaSampleWriterMuxerInterface>& muxer /* nonnull */,
-              const OnWritingFinishedCallback& callback /* nonnull */);
+              const std::weak_ptr<CallbackInterface>& callbacks /* nonnull */);
 
     /**
      * Adds a new track to the sample writer. Tracks must be added after the sample writer has been
@@ -145,24 +152,28 @@
 
     struct TrackRecord {
         TrackRecord(const std::shared_ptr<MediaSampleQueue>& sampleQueue, size_t trackIndex,
-                    int64_t durationUs)
+                    int64_t durationUs, bool isVideo)
               : mSampleQueue(sampleQueue),
                 mTrackIndex(trackIndex),
                 mDurationUs(durationUs),
                 mFirstSampleTimeUs(0),
+                mPrevSampleTimeUs(0),
                 mFirstSampleTimeSet(false),
-                mReachedEos(false) {}
+                mReachedEos(false),
+                mIsVideo(isVideo) {}
 
         std::shared_ptr<MediaSampleQueue> mSampleQueue;
         const size_t mTrackIndex;
         int64_t mDurationUs;
         int64_t mFirstSampleTimeUs;
+        int64_t mPrevSampleTimeUs;
         bool mFirstSampleTimeSet;
         bool mReachedEos;
+        bool mIsVideo;
     };
 
     const uint32_t mTrackSegmentLengthUs;
-    OnWritingFinishedCallback mWritingFinishedCallback;
+    std::weak_ptr<CallbackInterface> mCallbacks;
     std::shared_ptr<MediaSampleWriterMuxerInterface> mMuxer;
     std::vector<TrackRecord> mTracks;
     std::thread mThread;
diff --git a/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h b/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h
index 7a36c8c..33bd9d4 100644
--- a/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h
+++ b/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h
@@ -17,6 +17,9 @@
 #ifndef ANDROID_MEDIA_TRANSCODER_H
 #define ANDROID_MEDIA_TRANSCODER_H
 
+#include <binder/Parcel.h>
+#include <binder/Parcelable.h>
+#include <media/MediaSampleWriter.h>
 #include <media/MediaTrackTranscoderCallback.h>
 #include <media/NdkMediaError.h>
 #include <media/NdkMediaFormat.h>
@@ -30,11 +33,11 @@
 namespace android {
 
 class MediaSampleReader;
-class MediaSampleWriter;
 class Parcel;
 
 class MediaTranscoder : public std::enable_shared_from_this<MediaTranscoder>,
-                        public MediaTrackTranscoderCallback {
+                        public MediaTrackTranscoderCallback,
+                        public MediaSampleWriter::CallbackInterface {
 public:
     /** Callbacks from transcoder to client. */
     class CallbackInterface {
@@ -126,6 +129,11 @@
     virtual void onTrackError(const MediaTrackTranscoder* transcoder,
                               media_status_t status) override;
     // ~MediaTrackTranscoderCallback
+
+    // MediaSampleWriter::CallbackInterface
+    virtual void onFinished(const MediaSampleWriter* writer, media_status_t status) override;
+    virtual void onProgressUpdate(const MediaSampleWriter* writer, int32_t progress) override;
+    // ~MediaSampleWriter::CallbackInterface
+
-    void onSampleWriterFinished(media_status_t status);
     void sendCallback(media_status_t status);
 
diff --git a/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp b/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp
index e3cb192..c82ec28 100644
--- a/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp
@@ -32,28 +32,6 @@
 
 namespace android {
 
-/** Minimal one-shot semaphore */
-class SimpleSemaphore {
-public:
-    void signal() {
-        std::unique_lock<std::mutex> lock(mMutex);
-        mSignaled = true;
-        mCondition.notify_all();
-    }
-
-    void wait() {
-        std::unique_lock<std::mutex> lock(mMutex);
-        while (!mSignaled) {
-            mCondition.wait(lock);
-        }
-    }
-
-private:
-    std::mutex mMutex;
-    std::condition_variable mCondition;
-    bool mSignaled = false;
-};
-
 /** Muxer interface to enable MediaSampleWriter testing. */
 class TestMuxer : public MediaSampleWriterMuxerInterface {
 public:
@@ -151,11 +129,22 @@
         for (size_t trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
             AMediaFormat* trackFormat = AMediaExtractor_getTrackFormat(mExtractor, trackIndex);
             ASSERT_NE(trackFormat, nullptr);
+
+            const char* mime = nullptr;
+            const bool hasMime = AMediaFormat_getString(trackFormat, AMEDIAFORMAT_KEY_MIME, &mime);
+            if (hasMime && strncmp(mime, "video/", 6) == 0) {
+                mVideoTrackIndex = trackIndex;
+            } else if (hasMime && strncmp(mime, "audio/", 6) == 0) {
+                mAudioTrackIndex = trackIndex;
+            }
+
             mTrackFormats.push_back(
                     std::shared_ptr<AMediaFormat>(trackFormat, &AMediaFormat_delete));
 
             AMediaExtractor_selectTrack(mExtractor, trackIndex);
         }
+        EXPECT_GE(mVideoTrackIndex, 0);
+        EXPECT_GE(mAudioTrackIndex, 0);
     }
 
     void reset() const {
@@ -167,6 +156,60 @@
     AMediaExtractor* mExtractor = nullptr;
     size_t mTrackCount = 0;
     std::vector<std::shared_ptr<AMediaFormat>> mTrackFormats;
+    int mVideoTrackIndex = -1;
+    int mAudioTrackIndex = -1;
+};
+
+class TestCallbacks : public MediaSampleWriter::CallbackInterface {
+public:
+    TestCallbacks(bool expectSuccess = true) : mExpectSuccess(expectSuccess) {}
+
+    bool hasFinished() {
+        std::unique_lock<std::mutex> lock(mMutex);
+        return mFinished;
+    }
+
+    // MediaSampleWriter::CallbackInterface
+    virtual void onFinished(const MediaSampleWriter* writer __unused,
+                            media_status_t status) override {
+        std::unique_lock<std::mutex> lock(mMutex);
+        EXPECT_FALSE(mFinished);
+        if (mExpectSuccess) {
+            EXPECT_EQ(status, AMEDIA_OK);
+        } else {
+            EXPECT_NE(status, AMEDIA_OK);
+        }
+        mFinished = true;
+        mCondition.notify_all();
+    }
+
+    virtual void onProgressUpdate(const MediaSampleWriter* writer __unused,
+                                  int32_t progress) override {
+        EXPECT_GT(progress, mLastProgress);
+        EXPECT_GE(progress, 0);
+        EXPECT_LE(progress, 100);
+
+        mLastProgress = progress;
+        mProgressUpdateCount++;
+    }
+    // ~MediaSampleWriter::CallbackInterface
+
+    void waitForWritingFinished() {
+        std::unique_lock<std::mutex> lock(mMutex);
+        while (!mFinished) {
+            mCondition.wait(lock);
+        }
+    }
+
+    uint32_t getProgressUpdateCount() const { return mProgressUpdateCount; }
+
+private:
+    std::mutex mMutex;
+    std::condition_variable mCondition;
+    bool mFinished = false;
+    bool mExpectSuccess;
+    int32_t mLastProgress = -1;
+    uint32_t mProgressUpdateCount = 0;
 };
 
 class MediaSampleWriterTests : public ::testing::Test {
@@ -222,7 +265,7 @@
 protected:
     std::shared_ptr<TestMuxer> mTestMuxer;
     std::shared_ptr<MediaSampleQueue> mSampleQueue;
-    const MediaSampleWriter::OnWritingFinishedCallback mEmptyCallback = [](media_status_t) {};
+    std::shared_ptr<TestCallbacks> mTestCallbacks = std::make_shared<TestCallbacks>();
 };
 
 TEST_F(MediaSampleWriterTests, TestAddTrackWithoutInit) {
@@ -239,14 +282,14 @@
 
 TEST_F(MediaSampleWriterTests, TestStartWithoutTracks) {
     MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mEmptyCallback));
+    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
     EXPECT_FALSE(writer.start());
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
 }
 
 TEST_F(MediaSampleWriterTests, TestAddInvalidTrack) {
     MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mEmptyCallback));
+    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
 
     EXPECT_FALSE(writer.addTrack(mSampleQueue, nullptr));
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
@@ -259,31 +302,25 @@
 TEST_F(MediaSampleWriterTests, TestDoubleStartStop) {
     MediaSampleWriter writer{};
 
-    bool callbackFired = false;
-    MediaSampleWriter::OnWritingFinishedCallback stoppedCallback =
-            [&callbackFired](media_status_t status) {
-                EXPECT_NE(status, AMEDIA_OK);
-                EXPECT_FALSE(callbackFired);
-                callbackFired = true;
-            };
-
-    EXPECT_TRUE(writer.init(mTestMuxer, stoppedCallback));
+    std::shared_ptr<TestCallbacks> callbacks =
+            std::make_shared<TestCallbacks>(false /* expectSuccess */);
+    EXPECT_TRUE(writer.init(mTestMuxer, callbacks));
 
     const TestMediaSource& mediaSource = getMediaSource();
     EXPECT_TRUE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(mediaSource.mTrackFormats[0].get()));
 
-    EXPECT_TRUE(writer.start());
+    ASSERT_TRUE(writer.start());
     EXPECT_FALSE(writer.start());
 
     EXPECT_TRUE(writer.stop());
-    EXPECT_TRUE(callbackFired);
+    EXPECT_TRUE(callbacks->hasFinished());
     EXPECT_FALSE(writer.stop());
 }
 
 TEST_F(MediaSampleWriterTests, TestStopWithoutStart) {
     MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mEmptyCallback));
+    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
 
     const TestMediaSource& mediaSource = getMediaSource();
     EXPECT_TRUE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
@@ -295,22 +332,48 @@
 
 TEST_F(MediaSampleWriterTests, TestStartWithoutCallback) {
     MediaSampleWriter writer{};
-    EXPECT_FALSE(writer.init(mTestMuxer, nullptr));
+
+    std::weak_ptr<MediaSampleWriter::CallbackInterface> unassignedWp;
+    EXPECT_FALSE(writer.init(mTestMuxer, unassignedWp));
+
+    std::shared_ptr<MediaSampleWriter::CallbackInterface> unassignedSp;
+    EXPECT_FALSE(writer.init(mTestMuxer, unassignedSp));
 
     const TestMediaSource& mediaSource = getMediaSource();
     EXPECT_FALSE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
     ASSERT_FALSE(writer.start());
 }
 
+TEST_F(MediaSampleWriterTests, TestProgressUpdate) {
+    static constexpr uint32_t kSegmentLengthUs = 1;
+    const TestMediaSource& mediaSource = getMediaSource();
+    ASSERT_GE(mediaSource.mVideoTrackIndex, 0);
+    MediaSampleWriter writer{kSegmentLengthUs};
+    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
+
+    std::shared_ptr<AMediaFormat> videoFormat =
+            std::shared_ptr<AMediaFormat>(AMediaFormat_new(), &AMediaFormat_delete);
+    AMediaFormat_copy(videoFormat.get(),
+                      mediaSource.mTrackFormats[mediaSource.mVideoTrackIndex].get());
+
+    AMediaFormat_setInt64(videoFormat.get(), AMEDIAFORMAT_KEY_DURATION, 100);
+    EXPECT_TRUE(writer.addTrack(mSampleQueue, videoFormat));
+    ASSERT_TRUE(writer.start());
+
+    for (int64_t pts = 0; pts < 100; ++pts) {
+        mSampleQueue->enqueue(newSampleWithPts(pts));
+    }
+    mSampleQueue->enqueue(newSampleEos());
+    mTestCallbacks->waitForWritingFinished();
+
+    EXPECT_EQ(mTestCallbacks->getProgressUpdateCount(), 100u);
+}
+
 TEST_F(MediaSampleWriterTests, TestInterleaving) {
     static constexpr uint32_t kSegmentLength = MediaSampleWriter::kDefaultTrackSegmentLengthUs;
-    SimpleSemaphore semaphore;
 
     MediaSampleWriter writer{kSegmentLength};
-    EXPECT_TRUE(writer.init(mTestMuxer, [&semaphore](media_status_t status) {
-        EXPECT_EQ(status, AMEDIA_OK);
-        semaphore.signal();
-    }));
+    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
 
     // Use two tracks for this test.
     static constexpr int kNumTracks = 2;
@@ -356,7 +419,7 @@
     ASSERT_TRUE(writer.start());
 
     // Wait for writer to complete.
-    semaphore.wait();
+    mTestCallbacks->waitForWritingFinished();
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Start());
 
     // Verify sample order.
@@ -386,16 +449,14 @@
 
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
     EXPECT_TRUE(writer.stop());
+    EXPECT_TRUE(mTestCallbacks->hasFinished());
 }
 
 TEST_F(MediaSampleWriterTests, TestAbortInputQueue) {
-    SimpleSemaphore semaphore;
-
     MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, [&semaphore](media_status_t status) {
-        EXPECT_NE(status, AMEDIA_OK);
-        semaphore.signal();
-    }));
+    std::shared_ptr<TestCallbacks> callbacks =
+            std::make_shared<TestCallbacks>(false /* expectSuccess */);
+    EXPECT_TRUE(writer.init(mTestMuxer, callbacks));
 
     // Use two tracks for this test.
     static constexpr int kNumTracks = 2;
@@ -417,7 +478,8 @@
     for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
         sampleQueues[trackIdx]->abort();
     }
-    semaphore.wait();
+
+    callbacks->waitForWritingFinished();
 
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Start());
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
@@ -465,12 +527,8 @@
     ASSERT_GT(destinationFd, 0);
 
     // Initialize writer.
-    SimpleSemaphore semaphore;
     MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(destinationFd, [&semaphore](media_status_t status) {
-        EXPECT_EQ(status, AMEDIA_OK);
-        semaphore.signal();
-    }));
+    EXPECT_TRUE(writer.init(destinationFd, mTestCallbacks));
     close(destinationFd);
 
     // Add tracks.
@@ -497,7 +555,7 @@
     }
 
     // Wait for writer.
-    semaphore.wait();
+    mTestCallbacks->waitForWritingFinished();
     EXPECT_TRUE(writer.stop());
 
     // Compare output file with source.