diff --git a/tokio/src/net/mod.rs b/tokio/src/net/mod.rs index 2d317a8a219..abc270bd0d8 100644 --- a/tokio/src/net/mod.rs +++ b/tokio/src/net/mod.rs @@ -49,6 +49,7 @@ cfg_net_unix! { pub use unix::datagram::socket::UnixDatagram; pub use unix::listener::UnixListener; pub use unix::stream::UnixStream; + pub use unix::socket::UnixSocket; } cfg_net_windows! { diff --git a/tokio/src/net/unix/datagram/socket.rs b/tokio/src/net/unix/datagram/socket.rs index d92ad5940e0..bec4bf983d5 100644 --- a/tokio/src/net/unix/datagram/socket.rs +++ b/tokio/src/net/unix/datagram/socket.rs @@ -96,6 +96,16 @@ cfg_net_unix! { } impl UnixDatagram { + pub(crate) fn from_mio(sys: mio::net::UnixDatagram) -> io::Result { + let datagram = UnixDatagram::new(sys)?; + + if let Some(e) = datagram.io.take_error()? { + return Err(e); + } + + Ok(datagram) + } + /// Waits for any of the requested ready states. /// /// This function is usually paired with `try_recv()` or `try_send()`. It diff --git a/tokio/src/net/unix/listener.rs b/tokio/src/net/unix/listener.rs index a7e9115eadd..bc7b53b3b53 100644 --- a/tokio/src/net/unix/listener.rs +++ b/tokio/src/net/unix/listener.rs @@ -50,6 +50,11 @@ cfg_net_unix! { } impl UnixListener { + pub(crate) fn new(listener: mio::net::UnixListener) -> io::Result { + let io = PollEvented::new(listener)?; + Ok(UnixListener { io }) + } + /// Creates a new `UnixListener` bound to the specified path. /// /// # Panics diff --git a/tokio/src/net/unix/mod.rs b/tokio/src/net/unix/mod.rs index a49b70af34a..a94fc7b2711 100644 --- a/tokio/src/net/unix/mod.rs +++ b/tokio/src/net/unix/mod.rs @@ -7,6 +7,8 @@ pub mod datagram; pub(crate) mod listener; +pub(crate) mod socket; + mod split; pub use split::{ReadHalf, WriteHalf}; diff --git a/tokio/src/net/unix/socket.rs b/tokio/src/net/unix/socket.rs new file mode 100644 index 00000000000..cb383b09a59 --- /dev/null +++ b/tokio/src/net/unix/socket.rs @@ -0,0 +1,271 @@ +use std::io; +use std::path::Path; + +use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd}; + +use crate::net::{UnixDatagram, UnixListener, UnixStream}; + +cfg_net_unix! { + /// A Unix socket that has not yet been converted to a [`UnixStream`], [`UnixDatagram`], or + /// [`UnixListener`]. + /// + /// `UnixSocket` wraps an operating system socket and enables the caller to + /// configure the socket before establishing a connection or accepting + /// inbound connections. The caller is able to set socket option and explicitly + /// bind the socket with a socket address. + /// + /// The underlying socket is closed when the `UnixSocket` value is dropped. + /// + /// `UnixSocket` should only be used directly if the default configuration used + /// by [`UnixStream::connect`], [`UnixDatagram::bind`], and [`UnixListener::bind`] + /// does not meet the required use case. + /// + /// Calling `UnixStream::connect(path)` effectively performs the same function as: + /// + /// ```no_run + /// use tokio::net::UnixSocket; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// let dir = tempfile::tempdir().unwrap(); + /// let path = dir.path().join("bind_path"); + /// let socket = UnixSocket::new_stream()?; + /// + /// let stream = socket.connect(path).await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// Calling `UnixDatagram::bind(path)` effectively performs the same function as: + /// + /// ```no_run + /// use tokio::net::UnixSocket; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// let dir = tempfile::tempdir().unwrap(); + /// let path = dir.path().join("bind_path"); + /// let socket = UnixSocket::new_datagram()?; + /// socket.bind(path)?; + /// + /// let datagram = socket.datagram()?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// Calling `UnixListener::bind(path)` effectively performs the same function as: + /// + /// ```no_run + /// use tokio::net::UnixSocket; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// let dir = tempfile::tempdir().unwrap(); + /// let path = dir.path().join("bind_path"); + /// let socket = UnixSocket::new_stream()?; + /// socket.bind(path)?; + /// + /// let listener = socket.listen(1024)?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// Setting socket options not explicitly provided by `UnixSocket` may be done by + /// accessing the [`RawFd`]/[`RawSocket`] using [`AsRawFd`]/[`AsRawSocket`] and + /// setting the option with a crate like [`socket2`]. + /// + /// [`RawFd`]: std::os::fd::RawFd + /// [`RawSocket`]: https://doc.rust-lang.org/std/os/windows/io/type.RawSocket.html + /// [`AsRawFd`]: std::os::fd::AsRawFd + /// [`AsRawSocket`]: https://doc.rust-lang.org/std/os/windows/io/trait.AsRawSocket.html + /// [`socket2`]: https://docs.rs/socket2/ + #[derive(Debug)] + pub struct UnixSocket { + inner: socket2::Socket, + } +} + +impl UnixSocket { + fn ty(&self) -> socket2::Type { + self.inner.r#type().unwrap() + } + + /// Creates a new Unix datagram socket. + /// + /// Calls `socket(2)` with `AF_UNIX` and `SOCK_DGRAM`. + /// + /// # Returns + /// + /// On success, the newly created [`UnixSocket`] is returned. If an error is + /// encountered, it is returned instead. + pub fn new_datagram() -> io::Result { + UnixSocket::new(socket2::Type::DGRAM) + } + + /// Creates a new Unix stream socket. + /// + /// Calls `socket(2)` with `AF_UNIX` and `SOCK_STREAM`. + /// + /// # Returns + /// + /// On success, the newly created [`UnixSocket`] is returned. If an error is + /// encountered, it is returned instead. + pub fn new_stream() -> io::Result { + UnixSocket::new(socket2::Type::STREAM) + } + + fn new(ty: socket2::Type) -> io::Result { + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" + ))] + let ty = ty.nonblocking(); + let inner = socket2::Socket::new(socket2::Domain::UNIX, ty, None)?; + #[cfg(not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" + )))] + inner.set_nonblocking(true)?; + Ok(UnixSocket { inner }) + } + + /// Binds the socket to the given address. + /// + /// This calls the `bind(2)` operating-system function. + pub fn bind(&self, path: impl AsRef) -> io::Result<()> { + let addr = socket2::SockAddr::unix(path)?; + self.inner.bind(&addr) + } + + /// Converts the socket into a `UnixListener`. + /// + /// `backlog` defines the maximum number of pending connections are queued + /// by the operating system at any given time. Connection are removed from + /// the queue with [`UnixListener::accept`]. When the queue is full, the + /// operating-system will start rejecting connections. + /// + /// Calling this function on a socket created by [`new_datagram`] will return an error. + /// + /// This calls the `listen(2)` operating-system function, marking the socket + /// as a passive socket. + /// + /// [`new_datagram`]: `UnixSocket::new_datagram` + pub fn listen(self, backlog: u32) -> io::Result { + if self.ty() == socket2::Type::DGRAM { + return Err(io::Error::new( + io::ErrorKind::Other, + "listen cannot be called on a datagram socket", + )); + } + + self.inner.listen(backlog as i32)?; + let mio = { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + + let raw_fd = self.inner.into_raw_fd(); + unsafe { mio::net::UnixListener::from_raw_fd(raw_fd) } + }; + + UnixListener::new(mio) + } + + /// Establishes a Unix connection with a peer at the specified socket address. + /// + /// The `UnixSocket` is consumed. Once the connection is established, a + /// connected [`UnixStream`] is returned. If the connection fails, the + /// encountered error is returned. + /// + /// Calling this function on a socket created by [`new_datagram`] will return an error. + /// + /// This calls the `connect(2)` operating-system function. + /// + /// [`new_datagram`]: `UnixSocket::new_datagram` + pub async fn connect(self, path: impl AsRef) -> io::Result { + if self.ty() == socket2::Type::DGRAM { + return Err(io::Error::new( + io::ErrorKind::Other, + "connect cannot be called on a datagram socket", + )); + } + + let addr = socket2::SockAddr::unix(path)?; + if let Err(err) = self.inner.connect(&addr) { + if err.raw_os_error() != Some(libc::EINPROGRESS) { + return Err(err); + } + } + let mio = { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + + let raw_fd = self.inner.into_raw_fd(); + unsafe { mio::net::UnixStream::from_raw_fd(raw_fd) } + }; + + UnixStream::connect_mio(mio).await + } + + /// Converts the socket into a [`UnixDatagram`]. + /// + /// Calling this function on a socket created by [`new_stream`] will return an error. + /// + /// [`new_stream`]: `UnixSocket::new_stream` + pub fn datagram(self) -> io::Result { + if self.ty() == socket2::Type::STREAM { + return Err(io::Error::new( + io::ErrorKind::Other, + "datagram cannot be called on a stream socket", + )); + } + let mio = { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + + let raw_fd = self.inner.into_raw_fd(); + unsafe { mio::net::UnixDatagram::from_raw_fd(raw_fd) } + }; + + UnixDatagram::from_mio(mio) + } +} + +impl AsRawFd for UnixSocket { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } +} + +impl AsFd for UnixSocket { + fn as_fd(&self) -> BorrowedFd<'_> { + unsafe { BorrowedFd::borrow_raw(self.as_raw_fd()) } + } +} + +impl FromRawFd for UnixSocket { + unsafe fn from_raw_fd(fd: RawFd) -> UnixSocket { + let inner = socket2::Socket::from_raw_fd(fd); + UnixSocket { inner } + } +} + +impl IntoRawFd for UnixSocket { + fn into_raw_fd(self) -> RawFd { + self.inner.into_raw_fd() + } +} diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs index 4821260ff6a..e1a4ff437f7 100644 --- a/tokio/src/net/unix/stream.rs +++ b/tokio/src/net/unix/stream.rs @@ -39,6 +39,24 @@ cfg_net_unix! { } impl UnixStream { + pub(crate) async fn connect_mio(sys: mio::net::UnixStream) -> io::Result { + let stream = UnixStream::new(sys)?; + + // Once we've connected, wait for the stream to be writable as + // that's when the actual connection has been initiated. Once we're + // writable we check for `take_socket_error` to see if the connect + // actually hit an error or not. + // + // If all that succeeded then we ship everything on up. + poll_fn(|cx| stream.io.registration().poll_write_ready(cx)).await?; + + if let Some(e) = stream.io.take_error()? { + return Err(e); + } + + Ok(stream) + } + /// Connects to the socket named by `path`. /// /// This function will create a new Unix socket and connect to the path diff --git a/tokio/tests/uds_socket.rs b/tokio/tests/uds_socket.rs new file mode 100644 index 00000000000..5261ffe5da3 --- /dev/null +++ b/tokio/tests/uds_socket.rs @@ -0,0 +1,118 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] +#![cfg(unix)] + +use futures::future::try_join; +use std::io; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::UnixSocket, +}; + +#[tokio::test] +async fn datagram_echo_server() -> io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let server_path = dir.path().join("server.sock"); + let client_path = dir.path().join("client.sock"); + + let server_socket = { + let socket = UnixSocket::new_datagram()?; + socket.bind(&server_path)?; + socket.datagram()? + }; + + tokio::spawn(async move { + let mut recv_buf = vec![0u8; 1024]; + loop { + let (len, peer_addr) = server_socket.recv_from(&mut recv_buf[..]).await?; + if let Some(path) = peer_addr.as_pathname() { + server_socket.send_to(&recv_buf[..len], path).await?; + } + } + + #[allow(unreachable_code)] + Ok::<(), io::Error>(()) + }); + + { + let socket = UnixSocket::new_datagram()?; + socket.bind(&client_path).unwrap(); + let socket = socket.datagram()?; + + socket.connect(server_path)?; + socket.send(b"ECHO").await?; + + let mut recv_buf = [0u8; 16]; + let len = socket.recv(&mut recv_buf[..]).await?; + assert_eq!(&recv_buf[..len], b"ECHO"); + } + + Ok(()) +} + +#[tokio::test] +async fn listen_and_stream() -> std::io::Result<()> { + let dir = tempfile::Builder::new().tempdir().unwrap(); + let sock_path = dir.path().join("connect.sock"); + let peer_path = dir.path().join("peer.sock"); + + let listener = { + let sock = UnixSocket::new_stream()?; + sock.bind(&sock_path)?; + sock.listen(1024)? + }; + + let accept = listener.accept(); + let connect = { + let sock = UnixSocket::new_stream()?; + sock.bind(&peer_path)?; + sock.connect(&sock_path) + }; + + let ((mut server, _), mut client) = try_join(accept, connect).await?; + + assert_eq!( + server.peer_addr().unwrap().as_pathname().unwrap(), + &peer_path + ); + + // Write to the client. + client.write_all(b"hello").await?; + drop(client); + + // Read from the server. + let mut buf = vec![]; + server.read_to_end(&mut buf).await?; + assert_eq!(&buf, b"hello"); + let len = server.read(&mut buf).await?; + assert_eq!(len, 0); + Ok(()) +} + +#[tokio::test] +async fn assert_usage() -> std::io::Result<()> { + let datagram_socket = UnixSocket::new_datagram()?; + let result = datagram_socket + .connect(std::path::PathBuf::new().join("invalid.sock")) + .await; + assert_eq!( + result.unwrap_err().to_string(), + "connect cannot be called on a datagram socket" + ); + + let datagram_socket = UnixSocket::new_datagram()?; + let result = datagram_socket.listen(1024); + assert_eq!( + result.unwrap_err().to_string(), + "listen cannot be called on a datagram socket" + ); + + let stream_socket = UnixSocket::new_stream()?; + let result = stream_socket.datagram(); + assert_eq!( + result.unwrap_err().to_string(), + "datagram cannot be called on a stream socket" + ); + + Ok(()) +}