Refactor callback and make payload stream duplex

Guest VMs now call notifyPayloadStarted on VirtualMachineService to tell
the host that their payload has started; the host then notifies clients
via onPayloadStarted. The stream passed to onPayloadStarted is now
duplex, so it can also be used as an input stream, which is fed to the
payload's stdin.

Bug: 191845268
Bug: 195381416
Test: run MicrodroidDemoApp and see output
Test: atest MicrodroidHostTestCases ComposHostTestCases AuthFsHostTest
Change-Id: Ic72045b4e4d11ab1efb14cb2e95de319ca8f9f97
diff --git a/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachineCallback.aidl b/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachineCallback.aidl
index 7bb18a4..c7a1471 100644
--- a/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachineCallback.aidl
+++ b/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachineCallback.aidl
@@ -23,14 +23,13 @@
  */
 oneway interface IVirtualMachineCallback {
     /**
-     * Called when the payload starts in the VM. `stdout` is the stdout of the payload.
+     * Called when the payload starts in the VM. `stream` is the input/output port of the payload.
      *
      * <p>Note: when the virtual machine object is shared to multiple processes and they register
-     * this callback to the same virtual machine object, the processes will compete to read from the
-     * same payload stdout. As a result, each process might get only a part of the entire output
-     * stream. To avoid such a case, keep only one process to read from the stdout.
+     * this callback to the same virtual machine object, the processes will compete to access the
+     * same payload stream. Keep only one process to access the stream.
      */
-    void onPayloadStarted(int cid, in ParcelFileDescriptor stdout);
+    void onPayloadStarted(int cid, in @nullable ParcelFileDescriptor stream);
 
     /**
      * Called when the VM dies.
diff --git a/virtualizationservice/src/aidl.rs b/virtualizationservice/src/aidl.rs
index e85ac2c..ab410d3 100644
--- a/virtualizationservice/src/aidl.rs
+++ b/virtualizationservice/src/aidl.rs
@@ -52,7 +52,7 @@
 use std::path::{Path, PathBuf};
 use std::sync::{Arc, Mutex, Weak};
 use vmconfig::VmConfig;
-use vsock::{VsockListener, SockAddr, VsockStream};
+use vsock::{SockAddr, VsockListener, VsockStream};
 use zip::ZipArchive;
 
 pub const BINDER_SERVICE_IDENTIFIER: &str = "android.system.virtualizationservice";
@@ -64,8 +64,8 @@
 const VMADDR_CID_HOST: u32 = 2;
 
 /// Port number that virtualizationservice listens on connections from the guest VMs for the
-/// payload output
-const PORT_VIRT_SERVICE: u32 = 3000;
+/// payload input and output
+const PORT_VIRT_STREAM_SERVICE: u32 = 3000;
 
 /// Port number that virtualizationservice listens on connections from the guest VMs for the
 /// VirtualMachineService binder service
@@ -286,7 +286,7 @@
         // server for payload output
         let state = service.state.clone(); // reference to state (not the state itself) is copied
         std::thread::spawn(move || {
-            handle_connection_from_vm(state).unwrap();
+            handle_stream_connection_from_vm(state).unwrap();
         });
 
         // binder server for vm
@@ -316,8 +316,8 @@
 
 /// Waits for incoming connections from VM. If a new connection is made, notify the event to the
 /// client via the callback (if registered).
-fn handle_connection_from_vm(state: Arc<Mutex<State>>) -> Result<()> {
-    let listener = VsockListener::bind_with_cid_port(VMADDR_CID_HOST, PORT_VIRT_SERVICE)?;
+fn handle_stream_connection_from_vm(state: Arc<Mutex<State>>) -> Result<()> {
+    let listener = VsockListener::bind_with_cid_port(VMADDR_CID_HOST, PORT_VIRT_STREAM_SERVICE)?;
     for stream in listener.incoming() {
         let stream = match stream {
             Err(e) => {
@@ -329,14 +329,11 @@
         if let Ok(SockAddr::Vsock(addr)) = stream.peer_addr() {
             let cid = addr.cid();
             let port = addr.port();
-            info!("connected from cid={}, port={}", cid, port);
-            if cid < FIRST_GUEST_CID {
-                warn!("connection is not from a guest VM");
-                continue;
-            }
-            // TODO(b/191845268): handle this with VirtualMachineService
+            info!("payload stream connected from cid={}, port={}", cid, port);
             if let Some(vm) = state.lock().unwrap().get_vm(cid) {
-                vm.callbacks.notify_payload_started(cid, stream);
+                vm.stream.lock().unwrap().insert(stream);
+            } else {
+                error!("connection from cid={} is not from a guest VM", cid);
             }
         }
     }
@@ -603,11 +600,11 @@
 
 impl VirtualMachineCallbacks {
     /// Call all registered callbacks to notify that the payload has started.
-    pub fn notify_payload_started(&self, cid: Cid, stream: VsockStream) {
+    pub fn notify_payload_started(&self, cid: Cid, stream: Option<VsockStream>) {
         let callbacks = &*self.0.lock().unwrap();
-        let pfd = vsock_stream_to_pfd(stream);
+        let pfd = stream.map(vsock_stream_to_pfd);
         for callback in callbacks {
-            if let Err(e) = callback.onPayloadStarted(cid as i32, &pfd) {
+            if let Err(e) = callback.onPayloadStarted(cid as i32, pfd.as_ref()) {
                 error!("Error notifying payload start event from VM CID {}: {}", cid, e);
             }
         }
@@ -748,15 +745,18 @@
 impl IVirtualMachineService for VirtualMachineService {
     fn notifyPayloadStarted(&self, cid: i32) -> binder::Result<()> {
         let cid = cid as Cid;
-        if self.state.lock().unwrap().get_vm(cid).is_none() {
+        if let Some(vm) = self.state.lock().unwrap().get_vm(cid) {
+            info!("VM having CID {} started payload", cid);
+            let stream = vm.stream.lock().unwrap().take();
+            vm.callbacks.notify_payload_started(cid, stream);
+            Ok(())
+        } else {
             error!("notifyPayloadStarted is called from an unknown cid {}", cid);
-            return Err(new_binder_exception(
+            Err(new_binder_exception(
                 ExceptionCode::SERVICE_SPECIFIC,
                 format!("cannot find a VM with cid {}", cid),
-            ));
+            ))
         }
-        info!("VM having CID {} started payload", cid);
-        Ok(())
     }
 }
 
diff --git a/virtualizationservice/src/crosvm.rs b/virtualizationservice/src/crosvm.rs
index 5873cd9..5984ff0 100644
--- a/virtualizationservice/src/crosvm.rs
+++ b/virtualizationservice/src/crosvm.rs
@@ -26,8 +26,9 @@
 use std::path::PathBuf;
 use std::process::Command;
 use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 use std::thread;
+use vsock::VsockStream;
 
 const CROSVM_PATH: &str = "/apex/com.android.virt/bin/crosvm";
 
@@ -73,6 +74,8 @@
     running: AtomicBool,
     /// Callbacks to clients of the VM.
     pub callbacks: VirtualMachineCallbacks,
+    /// Input/output stream of the payload run in the VM.
+    pub stream: Mutex<Option<VsockStream>>,
 }
 
 impl VmInstance {
@@ -96,6 +99,7 @@
             requester_debug_pid,
             running: AtomicBool::new(true),
             callbacks: Default::default(),
+            stream: Mutex::new(None),
         }
     }