Refactor callback and make payload stream duplex

Guest VMs now directly call onPayloadStarted to tell the host that their
payload started. And the stream passed by onPayloadStarted is now duplex
so it can also be used as an input stream, which will be fed to the
payload's stdin.

Bug: 191845268
Bug: 195381416
Test: run MicrodroidDemoApp and see output
Test: atest MicrodroidHostTestCases ComposHostTestCases AuthFsHostTest
Change-Id: Ic72045b4e4d11ab1efb14cb2e95de319ca8f9f97
diff --git a/compos/compos_key_cmd/compos_key_cmd.cpp b/compos/compos_key_cmd/compos_key_cmd.cpp
index 04ba1d0..8b1e2b1 100644
--- a/compos/compos_key_cmd/compos_key_cmd.cpp
+++ b/compos/compos_key_cmd/compos_key_cmd.cpp
@@ -100,7 +100,7 @@
 class Callback : public BnVirtualMachineCallback {
 public:
     ::ndk::ScopedAStatus onPayloadStarted(
-            int32_t in_cid, const ::ndk::ScopedFileDescriptor& /*in_stdout*/) override {
+            int32_t in_cid, const ::ndk::ScopedFileDescriptor& /*in_stream*/) override {
         // TODO: Consider copying stdout somewhere useful?
         LOG(INFO) << "Payload started! cid = " << in_cid;
         {
diff --git a/demo/java/com/android/microdroid/demo/MainActivity.java b/demo/java/com/android/microdroid/demo/MainActivity.java
index 9374f5d..044a55d 100644
--- a/demo/java/com/android/microdroid/demo/MainActivity.java
+++ b/demo/java/com/android/microdroid/demo/MainActivity.java
@@ -24,6 +24,7 @@
 import android.system.virtualmachine.VirtualMachineConfig;
 import android.system.virtualmachine.VirtualMachineException;
 import android.system.virtualmachine.VirtualMachineManager;
+import android.util.Log;
 import android.view.View;
 import android.widget.Button;
 import android.widget.CheckBox;
@@ -50,6 +51,8 @@
  * the virtual machine to the UI.
  */
 public class MainActivity extends AppCompatActivity {
+    private static final String TAG = "MicrodroidDemo";
+
     @Override
     protected void onCreate(Bundle savedInstanceState) {
         super.onCreate(savedInstanceState);
@@ -139,19 +142,24 @@
                         new VirtualMachineCallback() {
                             @Override
                             public void onPayloadStarted(
-                                    VirtualMachine vm, ParcelFileDescriptor out) {
+                                    VirtualMachine vm, ParcelFileDescriptor stream) {
+                                if (stream == null) {
+                                    mPayloadOutput.postValue("(no output available)");
+                                    return;
+                                }
                                 try {
                                     BufferedReader reader =
                                             new BufferedReader(
                                                     new InputStreamReader(
                                                             new FileInputStream(
-                                                                    out.getFileDescriptor())));
+                                                                    stream.getFileDescriptor())));
                                     String line;
                                     while ((line = reader.readLine()) != null) {
                                         mPayloadOutput.postValue(line);
                                     }
                                 } catch (IOException e) {
-                                    // Consume
+                                    Log.e(TAG, "IOException while reading payload: "
+                                            + e.getMessage());
                                 }
                             }
 
diff --git a/javalib/src/android/system/virtualmachine/VirtualMachineCallback.java b/javalib/src/android/system/virtualmachine/VirtualMachineCallback.java
index 07af4a1..89bb260 100644
--- a/javalib/src/android/system/virtualmachine/VirtualMachineCallback.java
+++ b/javalib/src/android/system/virtualmachine/VirtualMachineCallback.java
@@ -17,6 +17,7 @@
 package android.system.virtualmachine;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.os.ParcelFileDescriptor;
 
 /**
@@ -28,7 +29,7 @@
 public interface VirtualMachineCallback {
 
     /** Called when the payload starts in the VM. */
-    void onPayloadStarted(@NonNull VirtualMachine vm, @NonNull ParcelFileDescriptor stdout);
+    void onPayloadStarted(@NonNull VirtualMachine vm, @Nullable ParcelFileDescriptor stream);
 
     /** Called when the VM died. */
     void onDied(@NonNull VirtualMachine vm);
diff --git a/microdroid_manager/Android.bp b/microdroid_manager/Android.bp
index 95a7014..5fae7b1 100644
--- a/microdroid_manager/Android.bp
+++ b/microdroid_manager/Android.bp
@@ -20,6 +20,7 @@
         "liblog_rust",
         "libmicrodroid_metadata",
         "libmicrodroid_payload_config",
+        "libnix",
         "libprotobuf",
         "librustutils",
         "libserde",
diff --git a/microdroid_manager/src/main.rs b/microdroid_manager/src/main.rs
index 2fb7fdd..ee0e797 100644
--- a/microdroid_manager/src/main.rs
+++ b/microdroid_manager/src/main.rs
@@ -23,9 +23,10 @@
 use binder::{FromIBinder, Strong};
 use log::{error, info, warn};
 use microdroid_payload_config::{Task, TaskType, VmPayloadConfig};
+use nix::ioctl_read_bad;
 use rustutils::system_properties::PropertyWatcher;
-use std::fs::{self, File};
-use std::os::unix::io::{FromRawFd, IntoRawFd};
+use std::fs::{self, File, OpenOptions};
+use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
 use std::path::Path;
 use std::process::{Command, Stdio};
 use std::str;
@@ -61,6 +62,27 @@
     }
 }
 
+const IOCTL_VM_SOCKETS_GET_LOCAL_CID: usize = 0x7b9;
+ioctl_read_bad!(
+    /// Gets local cid from /dev/vsock
+    vm_sockets_get_local_cid,
+    IOCTL_VM_SOCKETS_GET_LOCAL_CID,
+    u32
+);
+
+// TODO: remove this after VS can check the peer addresses of binder clients
+fn get_local_cid() -> Result<u32> {
+    let f = OpenOptions::new()
+        .read(true)
+        .write(false)
+        .open("/dev/vsock")
+        .context("failed to open /dev/vsock")?;
+    let mut ret = 0;
+    // SAFETY: the kernel only modifies the given u32 integer.
+    unsafe { vm_sockets_get_local_cid(f.as_raw_fd(), &mut ret) }?;
+    Ok(ret)
+}
+
 fn main() -> Result<()> {
     kernlog::init()?;
     info!("started.");
@@ -72,10 +94,7 @@
         return Err(err);
     }
 
-    // TODO(b/191845268): microdroid_manager should use this binder to communicate with the host
-    if let Err(err) = get_vms_rpc_binder() {
-        error!("cannot connect to VirtualMachineService: {}", err);
-    }
+    let service = get_vms_rpc_binder().expect("cannot connect to VirtualMachineService");
 
     if !metadata.payload_config_path.is_empty() {
         let config = load_config(Path::new(&metadata.payload_config_path))?;
@@ -87,7 +106,7 @@
 
         // TODO(jooyung): wait until sys.boot_completed?
         if let Some(main_task) = &config.task {
-            exec_task(main_task).map_err(|e| {
+            exec_task(main_task, &service).map_err(|e| {
                 error!("failed to execute task: {}", e);
                 e
             })?;
@@ -118,29 +137,13 @@
 
 /// Executes the given task. Stdout of the task is piped into the vsock stream to the
 /// virtualizationservice in the host side.
-fn exec_task(task: &Task) -> Result<()> {
-    const VMADDR_CID_HOST: u32 = 2;
-    const PORT_VIRT_SVC: u32 = 3000;
-    let stdout = match VsockStream::connect_with_cid_port(VMADDR_CID_HOST, PORT_VIRT_SVC) {
-        Ok(stream) => {
-            // SAFETY: the ownership of the underlying file descriptor is transferred from stream
-            // to the file object, and then into the Command object. When the command is finished,
-            // the file descriptor is closed.
-            let f = unsafe { File::from_raw_fd(stream.into_raw_fd()) };
-            Stdio::from(f)
-        }
-        Err(e) => {
-            error!("failed to connect to virtualization service: {}", e);
-            // Don't fail hard here. Even if we failed to connect to the virtualizationservice,
-            // we keep executing the task. This can happen if the owner of the VM doesn't register
-            // callback to accept the stream. Use /dev/null as the stdout so that the task can
-            // make progress without waiting for someone to consume the output.
-            Stdio::null()
-        }
-    };
+fn exec_task(task: &Task, service: &Strong<dyn IVirtualMachineService>) -> Result<()> {
     info!("executing main task {:?}...", task);
-    // TODO(jiyong): consider piping the stream into stdio (and probably stderr) as well.
-    let mut child = build_command(task)?.stdout(stdout).spawn()?;
+    let mut child = build_command(task)?.spawn()?;
+
+    info!("notifying payload started");
+    service.notifyPayloadStarted(get_local_cid()? as i32)?;
+
     match child.wait()?.code() {
         Some(0) => {
             info!("task successfully finished");
@@ -152,7 +155,10 @@
 }
 
 fn build_command(task: &Task) -> Result<Command> {
-    Ok(match task.type_ {
+    const VMADDR_CID_HOST: u32 = 2;
+    const PORT_VIRT_SVC: u32 = 3000;
+
+    let mut command = match task.type_ {
         TaskType::Executable => {
             let mut command = Command::new(&task.command);
             command.args(&task.args);
@@ -163,7 +169,30 @@
             command.arg(find_library_path(&task.command)?).args(&task.args);
             command
         }
-    })
+    };
+
+    match VsockStream::connect_with_cid_port(VMADDR_CID_HOST, PORT_VIRT_SVC) {
+        Ok(stream) => {
+            // SAFETY: the ownership of the underlying file descriptor is transferred from stream
+            // to the file object, and then into the Command object. When the command is finished,
+            // the file descriptor is closed.
+            let file = unsafe { File::from_raw_fd(stream.into_raw_fd()) };
+            command
+                .stdin(Stdio::from(file.try_clone()?))
+                .stdout(Stdio::from(file.try_clone()?))
+                .stderr(Stdio::from(file));
+        }
+        Err(e) => {
+            error!("failed to connect to virtualization service: {}", e);
+            // Don't fail hard here. Even if we failed to connect to the virtualizationservice,
+            // we keep executing the task. This can happen if the owner of the VM doesn't register
+            // callback to accept the stream. Use /dev/null as the stream so that the task can
+            // make progress without waiting for someone to consume the output.
+            command.stdin(Stdio::null()).stdout(Stdio::null()).stderr(Stdio::null());
+        }
+    }
+
+    Ok(command)
 }
 
 fn find_library_path(name: &str) -> Result<String> {
diff --git a/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachineCallback.aidl b/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachineCallback.aidl
index 7bb18a4..c7a1471 100644
--- a/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachineCallback.aidl
+++ b/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachineCallback.aidl
@@ -23,14 +23,13 @@
  */
 oneway interface IVirtualMachineCallback {
     /**
-     * Called when the payload starts in the VM. `stdout` is the stdout of the payload.
+     * Called when the payload starts in the VM. `stream` is the input/output port of the payload.
      *
      * <p>Note: when the virtual machine object is shared to multiple processes and they register
-     * this callback to the same virtual machine object, the processes will compete to read from the
-     * same payload stdout. As a result, each process might get only a part of the entire output
-     * stream. To avoid such a case, keep only one process to read from the stdout.
+     * this callback to the same virtual machine object, the processes will compete to access the
+     * same payload stream. Keep only one process to access the stream.
      */
-    void onPayloadStarted(int cid, in ParcelFileDescriptor stdout);
+    void onPayloadStarted(int cid, in @nullable ParcelFileDescriptor stream);
 
     /**
      * Called when the VM dies.
diff --git a/virtualizationservice/src/aidl.rs b/virtualizationservice/src/aidl.rs
index e85ac2c..ab410d3 100644
--- a/virtualizationservice/src/aidl.rs
+++ b/virtualizationservice/src/aidl.rs
@@ -52,7 +52,7 @@
 use std::path::{Path, PathBuf};
 use std::sync::{Arc, Mutex, Weak};
 use vmconfig::VmConfig;
-use vsock::{VsockListener, SockAddr, VsockStream};
+use vsock::{SockAddr, VsockListener, VsockStream};
 use zip::ZipArchive;
 
 pub const BINDER_SERVICE_IDENTIFIER: &str = "android.system.virtualizationservice";
@@ -64,8 +64,8 @@
 const VMADDR_CID_HOST: u32 = 2;
 
 /// Port number that virtualizationservice listens on connections from the guest VMs for the
-/// payload output
-const PORT_VIRT_SERVICE: u32 = 3000;
+/// payload input and output
+const PORT_VIRT_STREAM_SERVICE: u32 = 3000;
 
 /// Port number that virtualizationservice listens on connections from the guest VMs for the
 /// VirtualMachineService binder service
@@ -286,7 +286,7 @@
         // server for payload output
         let state = service.state.clone(); // reference to state (not the state itself) is copied
         std::thread::spawn(move || {
-            handle_connection_from_vm(state).unwrap();
+            handle_stream_connection_from_vm(state).unwrap();
         });
 
         // binder server for vm
@@ -316,8 +316,8 @@
 
 /// Waits for incoming connections from VM. If a new connection is made, notify the event to the
 /// client via the callback (if registered).
-fn handle_connection_from_vm(state: Arc<Mutex<State>>) -> Result<()> {
-    let listener = VsockListener::bind_with_cid_port(VMADDR_CID_HOST, PORT_VIRT_SERVICE)?;
+fn handle_stream_connection_from_vm(state: Arc<Mutex<State>>) -> Result<()> {
+    let listener = VsockListener::bind_with_cid_port(VMADDR_CID_HOST, PORT_VIRT_STREAM_SERVICE)?;
     for stream in listener.incoming() {
         let stream = match stream {
             Err(e) => {
@@ -329,14 +329,11 @@
         if let Ok(SockAddr::Vsock(addr)) = stream.peer_addr() {
             let cid = addr.cid();
             let port = addr.port();
-            info!("connected from cid={}, port={}", cid, port);
-            if cid < FIRST_GUEST_CID {
-                warn!("connection is not from a guest VM");
-                continue;
-            }
-            // TODO(b/191845268): handle this with VirtualMachineService
+            info!("payload stream connected from cid={}, port={}", cid, port);
             if let Some(vm) = state.lock().unwrap().get_vm(cid) {
-                vm.callbacks.notify_payload_started(cid, stream);
+                vm.stream.lock().unwrap().insert(stream);
+            } else {
+                error!("connection from cid={} is not from a guest VM", cid);
             }
         }
     }
@@ -603,11 +600,11 @@
 
 impl VirtualMachineCallbacks {
     /// Call all registered callbacks to notify that the payload has started.
-    pub fn notify_payload_started(&self, cid: Cid, stream: VsockStream) {
+    pub fn notify_payload_started(&self, cid: Cid, stream: Option<VsockStream>) {
         let callbacks = &*self.0.lock().unwrap();
-        let pfd = vsock_stream_to_pfd(stream);
+        let pfd = stream.map(vsock_stream_to_pfd);
         for callback in callbacks {
-            if let Err(e) = callback.onPayloadStarted(cid as i32, &pfd) {
+            if let Err(e) = callback.onPayloadStarted(cid as i32, pfd.as_ref()) {
                 error!("Error notifying payload start event from VM CID {}: {}", cid, e);
             }
         }
@@ -748,15 +745,18 @@
 impl IVirtualMachineService for VirtualMachineService {
     fn notifyPayloadStarted(&self, cid: i32) -> binder::Result<()> {
         let cid = cid as Cid;
-        if self.state.lock().unwrap().get_vm(cid).is_none() {
+        if let Some(vm) = self.state.lock().unwrap().get_vm(cid) {
+            info!("VM having CID {} started payload", cid);
+            let stream = vm.stream.lock().unwrap().take();
+            vm.callbacks.notify_payload_started(cid, stream);
+            Ok(())
+        } else {
             error!("notifyPayloadStarted is called from an unknown cid {}", cid);
-            return Err(new_binder_exception(
+            Err(new_binder_exception(
                 ExceptionCode::SERVICE_SPECIFIC,
                 format!("cannot find a VM with cid {}", cid),
-            ));
+            ))
         }
-        info!("VM having CID {} started payload", cid);
-        Ok(())
     }
 }
 
diff --git a/virtualizationservice/src/crosvm.rs b/virtualizationservice/src/crosvm.rs
index 5873cd9..5984ff0 100644
--- a/virtualizationservice/src/crosvm.rs
+++ b/virtualizationservice/src/crosvm.rs
@@ -26,8 +26,9 @@
 use std::path::PathBuf;
 use std::process::Command;
 use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 use std::thread;
+use vsock::VsockStream;
 
 const CROSVM_PATH: &str = "/apex/com.android.virt/bin/crosvm";
 
@@ -73,6 +74,8 @@
     running: AtomicBool,
     /// Callbacks to clients of the VM.
     pub callbacks: VirtualMachineCallbacks,
+    /// Input/output stream of the payload run in the VM.
+    pub stream: Mutex<Option<VsockStream>>,
 }
 
 impl VmInstance {
@@ -96,6 +99,7 @@
             requester_debug_pid,
             running: AtomicBool::new(true),
             callbacks: Default::default(),
+            stream: Mutex::new(None),
         }
     }
 
diff --git a/vm/src/run.rs b/vm/src/run.rs
index 8db43fb..6eb88e9 100644
--- a/vm/src/run.rs
+++ b/vm/src/run.rs
@@ -163,16 +163,22 @@
 impl Interface for VirtualMachineCallback {}
 
 impl IVirtualMachineCallback for VirtualMachineCallback {
-    fn onPayloadStarted(&self, _cid: i32, stdout: &ParcelFileDescriptor) -> BinderResult<()> {
-        // Show the stdout of the payload
-        let mut reader = BufReader::new(stdout.as_ref());
-        loop {
-            let mut s = String::new();
-            match reader.read_line(&mut s) {
-                Ok(0) => break,
-                Ok(_) => print!("{}", s),
-                Err(e) => eprintln!("error reading from virtual machine: {}", e),
-            };
+    fn onPayloadStarted(
+        &self,
+        _cid: i32,
+        stream: Option<&ParcelFileDescriptor>,
+    ) -> BinderResult<()> {
+        // Show the output of the payload
+        if let Some(stream) = stream {
+            let mut reader = BufReader::new(stream.as_ref());
+            loop {
+                let mut s = String::new();
+                match reader.read_line(&mut s) {
+                    Ok(0) => break,
+                    Ok(_) => print!("{}", s),
+                    Err(e) => eprintln!("error reading from virtual machine: {}", e),
+                };
+            }
         }
         Ok(())
     }