blob: 52ae180c6754e168c2a452315f3620b6e6aef153 [file] [log] [blame]
// Copyright 2024, The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Library for a safer conversion from `RawFd` to `OwnedFd`.
16
17use nix::fcntl::{fcntl, FdFlag, F_DUPFD, F_GETFD, F_SETFD};
18use nix::libc;
19use nix::unistd::close;
20use std::os::fd::FromRawFd;
21use std::os::fd::OwnedFd;
22use std::os::fd::RawFd;
23use std::sync::Mutex;
24use thiserror::Error;
25
26/// Errors that can occur while taking an ownership of `RawFd`
27#[derive(Debug, PartialEq, Error)]
28pub enum Error {
29 /// RawFd is not a valid file descriptor
30 #[error("{0} is not a file descriptor")]
31 Invalid(RawFd),
32
33 /// RawFd is either stdio, stdout, or stderr
34 #[error("standard IO descriptors cannot be owned")]
35 StdioNotAllowed,
36
37 /// Generic UNIX error
38 #[error("UNIX error")]
39 Errno(#[from] nix::errno::Errno),
40}
41
42static LOCK: Mutex<()> = Mutex::new(());
43
44/// Takes the ownership of `RawFd` and converts it to `OwnedFd`. It is important to know that
45/// `RawFd` is closed when this function successfully returns. The raw file descriptor of the
46/// returned `OwnedFd` is different from `RawFd`. The returned file descriptor is CLOEXEC set.
47pub fn take_fd_ownership(raw_fd: RawFd) -> Result<OwnedFd, Error> {
48 fcntl(raw_fd, F_GETFD).map_err(|_| Error::Invalid(raw_fd))?;
49
50 if [libc::STDIN_FILENO, libc::STDOUT_FILENO, libc::STDERR_FILENO].contains(&raw_fd) {
51 return Err(Error::StdioNotAllowed);
52 }
53
54 // sync is needed otherwise we can create multiple OwnedFds out of the same RawFd
55 let lock = LOCK.lock().unwrap();
56 let new_fd = fcntl(raw_fd, F_DUPFD(raw_fd))?;
57 close(raw_fd)?;
58 drop(lock);
59
60 // This is not essential, but let's follow the common practice in the Rust ecosystem
61 fcntl(new_fd, F_SETFD(FdFlag::FD_CLOEXEC)).map_err(Error::Errno)?;
62
63 // SAFETY: In this function, we have checked that RawFd is actually an open file descriptor and
64 // this is the first time to claim its ownership because we just created it by duping.
65 Ok(unsafe { OwnedFd::from_raw_fd(new_fd) })
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use nix::fcntl::{fcntl, FdFlag, F_GETFD, F_SETFD};
    use std::os::fd::AsRawFd;
    use std::os::fd::IntoRawFd;
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use tempfile::tempfile;

    /// Serialises tests that open, close, or probe file descriptors.
    ///
    /// The test harness runs tests on parallel threads, and opening a file yields the lowest
    /// unused descriptor number. Without this lock, a descriptor closed by `take_fd_ownership` in
    /// one test can be handed straight to `tempfile()` in another test. That makes the
    /// "original fd is closed" assertions flaky, and can even make one test close another
    /// test's file.
    static FD_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`FD_TEST_LOCK`], ignoring poisoning left behind by an earlier failed test.
    fn serialize_fd_tests() -> MutexGuard<'static, ()> {
        FD_TEST_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn good_fd() -> Result<()> {
        let _guard = serialize_fd_tests();
        let raw_fd = tempfile()?.into_raw_fd();
        assert!(take_fd_ownership(raw_fd).is_ok());
        Ok(())
    }

    #[test]
    fn invalid_fd() -> Result<()> {
        let _guard = serialize_fd_tests();
        let raw_fd = 12345; // randomly chosen
        assert_eq!(take_fd_ownership(raw_fd).unwrap_err(), Error::Invalid(raw_fd));
        // Negative descriptors can never be open.
        assert_eq!(take_fd_ownership(-1).unwrap_err(), Error::Invalid(-1));
        Ok(())
    }

    #[test]
    fn stdio_not_allowed() {
        let _guard = serialize_fd_tests();
        for fd in [nix::libc::STDIN_FILENO, nix::libc::STDOUT_FILENO, nix::libc::STDERR_FILENO] {
            // A closed stdio descriptor is reported as `Invalid` instead (validity is checked
            // first), so only the descriptors that are actually open are checked here.
            if fcntl(fd, F_GETFD).is_ok() {
                assert_eq!(take_fd_ownership(fd).unwrap_err(), Error::StdioNotAllowed);
            }
        }
    }

    #[test]
    fn original_fd_closed() -> Result<()> {
        let _guard = serialize_fd_tests();
        let raw_fd = tempfile()?.into_raw_fd();
        let owned_fd = take_fd_ownership(raw_fd)?;
        assert_ne!(raw_fd, owned_fd.as_raw_fd());
        assert!(fcntl(raw_fd, F_GETFD).is_err());
        Ok(())
    }

    #[test]
    fn cannot_use_same_rawfd_multiple_times() -> Result<()> {
        let _guard = serialize_fd_tests();
        let raw_fd = tempfile()?.into_raw_fd();

        let owned_fd = take_fd_ownership(raw_fd); // once
        let owned_fd2 = take_fd_ownership(raw_fd); // twice

        assert!(owned_fd.is_ok());
        assert_eq!(owned_fd2.unwrap_err(), Error::Invalid(raw_fd));
        Ok(())
    }

    #[test]
    fn cloexec() -> Result<()> {
        let _guard = serialize_fd_tests();
        let raw_fd = tempfile()?.into_raw_fd();

        // intentionally clear cloexec to see if it is set by take_fd_ownership
        fcntl(raw_fd, F_SETFD(FdFlag::empty()))?;
        let flags = fcntl(raw_fd, F_GETFD)?;
        assert_eq!(flags, FdFlag::empty().bits());

        let owned_fd = take_fd_ownership(raw_fd)?;
        let flags = fcntl(owned_fd.as_raw_fd(), F_GETFD)?;
        assert_eq!(flags, FdFlag::FD_CLOEXEC.bits());
        drop(owned_fd);
        Ok(())
    }
}