diff --git a/audio/aidl/default/Module.cpp b/audio/aidl/default/Module.cpp
index a8f3b9b..9dbd61c 100644
--- a/audio/aidl/default/Module.cpp
+++ b/audio/aidl/default/Module.cpp
@@ -97,6 +97,7 @@
 }
 
 ndk::ScopedAStatus Module::createStreamContext(int32_t in_portConfigId, int64_t in_bufferSizeFrames,
+                                               std::shared_ptr<IStreamCallback> asyncCallback,
                                                StreamContext* out_context) {
     if (in_bufferSizeFrames <= 0) {
         LOG(ERROR) << __func__ << ": non-positive buffer size " << in_bufferSizeFrames;
@@ -136,7 +137,7 @@
                 std::make_unique<StreamContext::CommandMQ>(1, true /*configureEventFlagWord*/),
                 std::make_unique<StreamContext::ReplyMQ>(1, true /*configureEventFlagWord*/),
                 frameSize, std::make_unique<StreamContext::DataMQ>(frameSize * in_bufferSizeFrames),
-                mDebug.streamTransientStateDelayMs);
+                asyncCallback, mDebug.streamTransientStateDelayMs);
         if (temp.isValid()) {
             *out_context = std::move(temp);
         } else {
@@ -461,7 +462,8 @@
         return ndk::ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
     }
     StreamContext context;
-    if (auto status = createStreamContext(in_args.portConfigId, in_args.bufferSizeFrames, &context);
+    if (auto status = createStreamContext(in_args.portConfigId, in_args.bufferSizeFrames, nullptr,
+                                          &context);
         !status.isOk()) {
         return status;
     }
@@ -501,8 +503,16 @@
                    << " has COMPRESS_OFFLOAD flag set, requires offload info";
         return ndk::ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
     }
+    const bool isNonBlocking = isBitPositionFlagSet(port->flags.get<AudioIoFlags::Tag::output>(),
+                                                    AudioOutputFlags::NON_BLOCKING);
+    if (isNonBlocking && in_args.callback == nullptr) {
+        LOG(ERROR) << __func__ << ": port id " << port->id
+                   << " has NON_BLOCKING flag set, requires async callback";
+        return ndk::ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
+    }
     StreamContext context;
-    if (auto status = createStreamContext(in_args.portConfigId, in_args.bufferSizeFrames, &context);
+    if (auto status = createStreamContext(in_args.portConfigId, in_args.bufferSizeFrames,
+                                          isNonBlocking ? in_args.callback : nullptr, &context);
         !status.isOk()) {
         return status;
     }
diff --git a/audio/aidl/default/Stream.cpp b/audio/aidl/default/Stream.cpp
index d1efb02..d7c352f 100644
--- a/audio/aidl/default/Stream.cpp
+++ b/audio/aidl/default/Stream.cpp
@@ -178,13 +178,18 @@
             }
             break;
         case Tag::drain:
-            if (mState == StreamDescriptor::State::ACTIVE) {
-                usleep(1000);  // Simulate a blocking call into the driver.
-                populateReply(&reply, mIsConnected);
-                // Can switch the state to ERROR if a driver error occurs.
-                mState = StreamDescriptor::State::DRAINING;
+            if (command.get<Tag::drain>() == StreamDescriptor::DrainMode::DRAIN_UNSPECIFIED) {
+                if (mState == StreamDescriptor::State::ACTIVE) {
+                    usleep(1000);  // Simulate a blocking call into the driver.
+                    populateReply(&reply, mIsConnected);
+                    // Can switch the state to ERROR if a driver error occurs.
+                    mState = StreamDescriptor::State::DRAINING;
+                } else {
+                    populateReplyWrongState(&reply, command);
+                }
             } else {
-                populateReplyWrongState(&reply, command);
+                LOG(WARNING) << __func__
+                             << ": invalid drain mode: " << toString(command.get<Tag::drain>());
             }
             break;
         case Tag::standby:
@@ -262,11 +267,31 @@
 const std::string StreamOutWorkerLogic::kThreadName = "writer";
 
 StreamOutWorkerLogic::Status StreamOutWorkerLogic::cycle() {
-    if (mState == StreamDescriptor::State::DRAINING) {
+    if (mState == StreamDescriptor::State::DRAINING ||
+        mState == StreamDescriptor::State::TRANSFERRING) {
         if (auto stateDurationMs = std::chrono::duration_cast<std::chrono::milliseconds>(
                     std::chrono::steady_clock::now() - mTransientStateStart);
             stateDurationMs >= mTransientStateDelayMs) {
-            mState = StreamDescriptor::State::IDLE;
+            if (mAsyncCallback == nullptr) {
+                // In blocking mode, mState can only be DRAINING.
+                mState = StreamDescriptor::State::IDLE;
+            } else {
+                // In a real implementation, the driver should notify the HAL about
+                // drain or transfer completion. In the stub, we switch unconditionally.
+                if (mState == StreamDescriptor::State::DRAINING) {
+                    mState = StreamDescriptor::State::IDLE;
+                    ndk::ScopedAStatus status = mAsyncCallback->onDrainReady();
+                    if (!status.isOk()) {
+                        LOG(ERROR) << __func__ << ": error from onDrainReady: " << status;
+                    }
+                } else {
+                    mState = StreamDescriptor::State::ACTIVE;
+                    ndk::ScopedAStatus status = mAsyncCallback->onTransferReady();
+                    if (!status.isOk()) {
+                        LOG(ERROR) << __func__ << ": error from onTransferReady: " << status;
+                    }
+                }
+            }
             if (mTransientStateDelayMs.count() != 0) {
                 LOG(DEBUG) << __func__ << ": switched to state " << toString(mState)
                            << " after a timeout";
@@ -298,40 +323,57 @@
         case Tag::getStatus:
             populateReply(&reply, mIsConnected);
             break;
-        case Tag::start:
+        case Tag::start: {
+            bool commandAccepted = true;
             switch (mState) {
                 case StreamDescriptor::State::STANDBY:
                     mState = StreamDescriptor::State::IDLE;
-                    populateReply(&reply, mIsConnected);
                     break;
                 case StreamDescriptor::State::PAUSED:
                     mState = StreamDescriptor::State::ACTIVE;
-                    populateReply(&reply, mIsConnected);
                     break;
                 case StreamDescriptor::State::DRAIN_PAUSED:
                     switchToTransientState(StreamDescriptor::State::DRAINING);
-                    populateReply(&reply, mIsConnected);
+                    break;
+                case StreamDescriptor::State::TRANSFER_PAUSED:
+                    switchToTransientState(StreamDescriptor::State::TRANSFERRING);
                     break;
                 default:
                     populateReplyWrongState(&reply, command);
+                    commandAccepted = false;
             }
-            break;
+            if (commandAccepted) {
+                populateReply(&reply, mIsConnected);
+            }
+        } break;
         case Tag::burst:
             if (const int32_t fmqByteCount = command.get<Tag::burst>(); fmqByteCount >= 0) {
                 LOG(DEBUG) << __func__ << ": '" << toString(command.getTag()) << "' command for "
                            << fmqByteCount << " bytes";
-                if (mState !=
-                    StreamDescriptor::State::ERROR) {  // BURST can be handled in all valid states
+                if (mState != StreamDescriptor::State::ERROR &&
+                    mState != StreamDescriptor::State::TRANSFERRING &&
+                    mState != StreamDescriptor::State::TRANSFER_PAUSED) {
                     if (!write(fmqByteCount, &reply)) {
                         mState = StreamDescriptor::State::ERROR;
                     }
                     if (mState == StreamDescriptor::State::STANDBY ||
-                        mState == StreamDescriptor::State::DRAIN_PAUSED) {
-                        mState = StreamDescriptor::State::PAUSED;
+                        mState == StreamDescriptor::State::DRAIN_PAUSED ||
+                        mState == StreamDescriptor::State::PAUSED) {
+                        if (mAsyncCallback == nullptr ||
+                            mState != StreamDescriptor::State::DRAIN_PAUSED) {
+                            mState = StreamDescriptor::State::PAUSED;
+                        } else {
+                            mState = StreamDescriptor::State::TRANSFER_PAUSED;
+                        }
                     } else if (mState == StreamDescriptor::State::IDLE ||
-                               mState == StreamDescriptor::State::DRAINING) {
-                        mState = StreamDescriptor::State::ACTIVE;
-                    }  // When in 'ACTIVE' and 'PAUSED' do not need to change the state.
+                               mState == StreamDescriptor::State::DRAINING ||
+                               mState == StreamDescriptor::State::ACTIVE) {
+                        if (mAsyncCallback == nullptr || reply.fmqByteCount == fmqByteCount) {
+                            mState = StreamDescriptor::State::ACTIVE;
+                        } else {
+                            switchToTransientState(StreamDescriptor::State::TRANSFERRING);
+                        }
+                    }
                 } else {
                     populateReplyWrongState(&reply, command);
                 }
@@ -340,13 +382,23 @@
             }
             break;
         case Tag::drain:
-            if (mState == StreamDescriptor::State::ACTIVE) {
-                usleep(1000);  // Simulate a blocking call into the driver.
-                populateReply(&reply, mIsConnected);
-                // Can switch the state to ERROR if a driver error occurs.
-                switchToTransientState(StreamDescriptor::State::DRAINING);
+            if (command.get<Tag::drain>() == StreamDescriptor::DrainMode::DRAIN_ALL ||
+                command.get<Tag::drain>() == StreamDescriptor::DrainMode::DRAIN_EARLY_NOTIFY) {
+                if (mState == StreamDescriptor::State::ACTIVE ||
+                    mState == StreamDescriptor::State::TRANSFERRING) {
+                    usleep(1000);  // Simulate a blocking call into the driver.
+                    populateReply(&reply, mIsConnected);
+                    // Can switch the state to ERROR if a driver error occurs.
+                    switchToTransientState(StreamDescriptor::State::DRAINING);
+                } else if (mState == StreamDescriptor::State::TRANSFER_PAUSED) {
+                    mState = StreamDescriptor::State::DRAIN_PAUSED;
+                    populateReply(&reply, mIsConnected);
+                } else {
+                    populateReplyWrongState(&reply, command);
+                }
             } else {
-                populateReplyWrongState(&reply, command);
+                LOG(WARNING) << __func__
+                             << ": invalid drain mode: " << toString(command.get<Tag::drain>());
             }
             break;
         case Tag::standby:
@@ -359,20 +411,30 @@
                 populateReplyWrongState(&reply, command);
             }
             break;
-        case Tag::pause:
-            if (mState == StreamDescriptor::State::ACTIVE ||
-                mState == StreamDescriptor::State::DRAINING) {
-                populateReply(&reply, mIsConnected);
-                mState = mState == StreamDescriptor::State::ACTIVE
-                                 ? StreamDescriptor::State::PAUSED
-                                 : StreamDescriptor::State::DRAIN_PAUSED;
-            } else {
-                populateReplyWrongState(&reply, command);
+        case Tag::pause: {
+            bool commandAccepted = true;
+            switch (mState) {
+                case StreamDescriptor::State::ACTIVE:
+                    mState = StreamDescriptor::State::PAUSED;
+                    break;
+                case StreamDescriptor::State::DRAINING:
+                    mState = StreamDescriptor::State::DRAIN_PAUSED;
+                    break;
+                case StreamDescriptor::State::TRANSFERRING:
+                    mState = StreamDescriptor::State::TRANSFER_PAUSED;
+                    break;
+                default:
+                    populateReplyWrongState(&reply, command);
+                    commandAccepted = false;
             }
-            break;
+            if (commandAccepted) {
+                populateReply(&reply, mIsConnected);
+            }
+        } break;
         case Tag::flush:
             if (mState == StreamDescriptor::State::PAUSED ||
-                mState == StreamDescriptor::State::DRAIN_PAUSED) {
+                mState == StreamDescriptor::State::DRAIN_PAUSED ||
+                mState == StreamDescriptor::State::TRANSFER_PAUSED) {
                 populateReply(&reply, mIsConnected);
                 mState = StreamDescriptor::State::IDLE;
             } else {
diff --git a/audio/aidl/default/include/core-impl/Module.h b/audio/aidl/default/include/core-impl/Module.h
index 0086743..f7b85ed 100644
--- a/audio/aidl/default/include/core-impl/Module.h
+++ b/audio/aidl/default/include/core-impl/Module.h
@@ -86,6 +86,7 @@
     void cleanUpPatch(int32_t patchId);
     ndk::ScopedAStatus createStreamContext(
             int32_t in_portConfigId, int64_t in_bufferSizeFrames,
+            std::shared_ptr<IStreamCallback> asyncCallback,
             ::aidl::android::hardware::audio::core::StreamContext* out_context);
     ndk::ScopedAStatus findPortIdForNewStream(
             int32_t in_portConfigId, ::aidl::android::media::audio::common::AudioPort** port);
diff --git a/audio/aidl/default/include/core-impl/Stream.h b/audio/aidl/default/include/core-impl/Stream.h
index bcbabad..3c96973 100644
--- a/audio/aidl/default/include/core-impl/Stream.h
+++ b/audio/aidl/default/include/core-impl/Stream.h
@@ -29,6 +29,7 @@
 #include <aidl/android/hardware/audio/common/SourceMetadata.h>
 #include <aidl/android/hardware/audio/core/BnStreamIn.h>
 #include <aidl/android/hardware/audio/core/BnStreamOut.h>
+#include <aidl/android/hardware/audio/core/IStreamCallback.h>
 #include <aidl/android/hardware/audio/core/StreamDescriptor.h>
 #include <aidl/android/media/audio/common/AudioOffloadInfo.h>
 #include <fmq/AidlMessageQueue.h>
@@ -60,12 +61,14 @@
 
     StreamContext() = default;
     StreamContext(std::unique_ptr<CommandMQ> commandMQ, std::unique_ptr<ReplyMQ> replyMQ,
-                  size_t frameSize, std::unique_ptr<DataMQ> dataMQ, int transientStateDelayMs)
+                  size_t frameSize, std::unique_ptr<DataMQ> dataMQ,
+                  std::shared_ptr<IStreamCallback> asyncCallback, int transientStateDelayMs)
         : mCommandMQ(std::move(commandMQ)),
           mInternalCommandCookie(std::rand()),
           mReplyMQ(std::move(replyMQ)),
           mFrameSize(frameSize),
           mDataMQ(std::move(dataMQ)),
+          mAsyncCallback(std::move(asyncCallback)),  // By-value sink: move, don't re-copy.
           mTransientStateDelayMs(transientStateDelayMs) {}
     StreamContext(StreamContext&& other)
         : mCommandMQ(std::move(other.mCommandMQ)),
@@ -73,6 +76,7 @@
           mReplyMQ(std::move(other.mReplyMQ)),
           mFrameSize(other.mFrameSize),
           mDataMQ(std::move(other.mDataMQ)),
+          mAsyncCallback(std::move(other.mAsyncCallback)),  // 'other' must not keep a ref.
           mTransientStateDelayMs(other.mTransientStateDelayMs) {}
     StreamContext& operator=(StreamContext&& other) {
         mCommandMQ = std::move(other.mCommandMQ);
@@ -80,11 +84,13 @@
         mReplyMQ = std::move(other.mReplyMQ);
         mFrameSize = other.mFrameSize;
         mDataMQ = std::move(other.mDataMQ);
+        mAsyncCallback = std::move(other.mAsyncCallback);  // 'other' must not keep a ref.
         mTransientStateDelayMs = other.mTransientStateDelayMs;
         return *this;
     }
 
     void fillDescriptor(StreamDescriptor* desc);
+    std::shared_ptr<IStreamCallback> getAsyncCallback() const { return mAsyncCallback; }
     CommandMQ* getCommandMQ() const { return mCommandMQ.get(); }
     DataMQ* getDataMQ() const { return mDataMQ.get(); }
     size_t getFrameSize() const { return mFrameSize; }
@@ -100,6 +106,7 @@
     std::unique_ptr<ReplyMQ> mReplyMQ;
     size_t mFrameSize;
     std::unique_ptr<DataMQ> mDataMQ;
+    std::shared_ptr<IStreamCallback> mAsyncCallback;
     int mTransientStateDelayMs;
 };
 
@@ -118,6 +125,7 @@
           mCommandMQ(context.getCommandMQ()),
           mReplyMQ(context.getReplyMQ()),
           mDataMQ(context.getDataMQ()),
+          mAsyncCallback(context.getAsyncCallback()),
           mTransientStateDelayMs(context.getTransientStateDelayMs()) {}
     std::string init() override;
     void populateReply(StreamDescriptor::Reply* reply, bool isConnected) const;
@@ -138,6 +146,7 @@
     StreamContext::CommandMQ* mCommandMQ;
     StreamContext::ReplyMQ* mReplyMQ;
     StreamContext::DataMQ* mDataMQ;
+    std::shared_ptr<IStreamCallback> mAsyncCallback;
     const std::chrono::duration<int, std::milli> mTransientStateDelayMs;
     std::chrono::time_point<std::chrono::steady_clock> mTransientStateStart;
     // We use an array and the "size" field instead of a vector to be able to detect
