From 8f6127cf9d84931ff06cdd3d3d54893bb44fb655 Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Wed, 11 Oct 2023 11:24:36 +0200 Subject: [PATCH 1/8] feat(socket): clean up, add rep socket state & stats --- msg-socket/src/rep/driver.rs | 267 +++++++++++++++++++++++++++ msg-socket/src/rep/mod.rs | 340 ++--------------------------------- msg-socket/src/rep/socket.rs | 113 ++++++++++++ msg-socket/src/rep/stats.rs | 52 ++++++ msg-socket/src/req/driver.rs | 2 +- msg-socket/src/req/socket.rs | 2 +- 6 files changed, 452 insertions(+), 324 deletions(-) create mode 100644 msg-socket/src/rep/driver.rs create mode 100644 msg-socket/src/rep/socket.rs create mode 100644 msg-socket/src/rep/stats.rs diff --git a/msg-socket/src/rep/driver.rs b/msg-socket/src/rep/driver.rs new file mode 100644 index 0000000..00cb32e --- /dev/null +++ b/msg-socket/src/rep/driver.rs @@ -0,0 +1,267 @@ +use bytes::Bytes; +use futures::{Future, SinkExt, Stream, StreamExt}; +use std::{ + collections::VecDeque, + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{mpsc, oneshot}, + task::JoinSet, +}; +use tokio_stream::{StreamMap, StreamNotifyClose}; +use tokio_util::codec::Framed; + +use crate::{rep::SocketState, Authenticator, RepError, Request}; +use msg_transport::ServerTransport; +use msg_wire::{auth, reqrep}; + +pub(crate) struct PeerState { + pending_requests: JoinSet>, + conn: Framed, + addr: SocketAddr, + egress_queue: VecDeque, + state: Arc, +} + +pub(crate) struct RepDriver { + /// The server transport used to accept incoming connections. + pub(crate) transport: T, + /// The reply socket state, shared with the socket front-end. + pub(crate) state: Arc, + /// [`StreamMap`] of connected peers. The key is the peer's address. + /// Note that when the [`PeerState`] stream ends, it will be silently removed + /// from this map. + pub(crate) peer_states: StreamMap>>, + /// Sender to the socket front-end. Used to notify the socket of incoming requests. + pub(crate) to_socket: mpsc::Sender, + /// Optional connection authenticator. + pub(crate) auth: Option>, + /// A joinset of authentication tasks. + pub(crate) auth_tasks: JoinSet, RepError>>, +} + +pub(crate) struct AuthResult { + id: Bytes, + addr: SocketAddr, + stream: S, +} + +impl Future for RepDriver { + type Output = Result<(), RepError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + loop { + if let Poll::Ready(Some((peer, msg))) = this.peer_states.poll_next_unpin(cx) { + match msg { + Some(Ok(request)) => { + tracing::debug!("Received request from peer {}", peer); + this.state.stats.increment_rx(request.msg().len()); + let _ = this.to_socket.try_send(request); + } + Some(Err(e)) => { + tracing::error!("Error receiving message from peer {}: {:?}", peer, e); + } + None => { + tracing::debug!("Peer {} disconnected", peer); + this.state.stats.decrement_active_clients(); + } + } + + continue; + } + + if let Poll::Ready(Some(Ok(auth))) = this.auth_tasks.poll_join_next(cx) { + match auth { + Ok(auth) => { + // Run custom authenticator + tracing::debug!("Authentication passed for {:?} ({})", auth.id, auth.addr); + this.state.stats.increment_active_clients(); + + this.peer_states.insert( + auth.addr, + StreamNotifyClose::new(PeerState { + pending_requests: JoinSet::new(), + conn: Framed::new(auth.stream, reqrep::Codec::new()), + addr: auth.addr, + egress_queue: VecDeque::new(), + state: Arc::clone(&this.state), + }), + ); + } + Err(e) => { + tracing::error!("Error authenticating client: {:?}", e); + } + } + + continue; + } + + // Poll the transport for new incoming connections + match this.transport.poll_accept(cx) { + Poll::Ready(Ok((stream, addr))) => { + // If authentication is enabled, start the authentication process + if let Some(ref auth) = this.auth { + let authenticator = Arc::clone(auth); + tracing::debug!("New connection from {}, authenticating", addr); + this.auth_tasks.spawn(async move { + let mut conn = Framed::new(stream, auth::Codec::new_server()); + + tracing::debug!("Waiting for auth"); + // Wait for the response + let auth = conn + .next() + .await + .ok_or(RepError::SocketClosed)? + .map_err(|e| RepError::Auth(e.to_string()))?; + + tracing::debug!("Auth received: {:?}", auth); + + let auth::Message::Auth(id) = auth else { + conn.send(auth::Message::Reject).await?; + conn.flush().await?; + conn.close().await?; + return Err(RepError::Auth("Invalid auth message".to_string())); + }; + + // If authentication fails, send a reject message and close the connection + if !authenticator.authenticate(&id) { + conn.send(auth::Message::Reject).await?; + conn.flush().await?; + conn.close().await?; + return Err(RepError::Auth("Authentication failed".to_string())); + } + + // Send ack + conn.send(auth::Message::Ack).await?; + conn.flush().await?; + + Ok(AuthResult { + id, + addr, + stream: conn.into_inner(), + }) + }); + } else { + this.state.stats.increment_active_clients(); + this.peer_states.insert( + addr, + StreamNotifyClose::new(PeerState { + pending_requests: JoinSet::new(), + conn: Framed::new(stream, reqrep::Codec::new()), + addr, + // TODO: pre-allocate according to some options + egress_queue: VecDeque::with_capacity(64), + state: Arc::clone(&this.state), + }), + ); + + tracing::debug!("New connection from {}", addr); + } + + continue; + } + Poll::Ready(Err(e)) => { + // Errors here are usually about `WouldBlock` + tracing::error!("Error accepting connection: {:?}", e); + + continue; + } + Poll::Pending => {} + } + + return Poll::Pending; + } + } +} + +impl Stream for PeerState { + type Item = Result; + + /// Advances the state of the peer. + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + loop { + // Flush any messages on the outgoing buffer + let _ = this.conn.poll_flush_unpin(cx); + + // Then, try to drain the egress queue. + if this.conn.poll_ready_unpin(cx).is_ready() { + if let Some(msg) = this.egress_queue.pop_front() { + let msg_len = msg.size(); + match this.conn.start_send_unpin(msg) { + Ok(_) => { + this.state.stats.increment_tx(msg_len); + // We might be able to send more queued messages + continue; + } + Err(e) => { + tracing::error!("Failed to send message to socket: {:?}", e); + // End this stream as we can't send any more messages + return Poll::Ready(None); + } + } + } + } + + // Then we check for completed requests, and push them onto the egress queue. + match this.pending_requests.poll_join_next(cx) { + Poll::Ready(Some(Ok(Some((id, payload))))) => { + let msg = reqrep::Message::new(id, payload); + this.egress_queue.push_back(msg); + + continue; + } + Poll::Ready(Some(Ok(None))) => { + tracing::error!("Failed to respond to request"); + this.state.stats.increment_failed_requests(); + + continue; + } + Poll::Ready(Some(Err(e))) => { + tracing::error!("Error receiving response: {:?}", e); + this.state.stats.increment_failed_requests(); + + continue; + } + _ => {} + } + + // Finally we accept incoming requests from the peer. + match this.conn.poll_next_unpin(cx) { + Poll::Ready(Some(result)) => { + tracing::trace!("Received message from peer {}: {:?}", this.addr, result); + let msg = result?; + let msg_id = msg.id(); + + let (tx, rx) = oneshot::channel(); + + // Spawn a task to listen for the response. On success, return message ID and response. + this.pending_requests + .spawn(async move { rx.await.ok().map(|res| (msg_id, res)) }); + + let request = Request { + source: this.addr, + response: tx, + msg: msg.into_payload(), + }; + + return Poll::Ready(Some(Ok(request))); + } + Poll::Ready(None) => { + tracing::debug!("Connection closed"); + return Poll::Ready(None); + } + Poll::Pending => {} + } + + return Poll::Pending; + } + } +} diff --git a/msg-socket/src/rep/mod.rs b/msg-socket/src/rep/mod.rs index d997432..f4bc282 100644 --- a/msg-socket/src/rep/mod.rs +++ b/msg-socket/src/rep/mod.rs @@ -1,55 +1,16 @@ -use std::{ - collections::VecDeque, - net::SocketAddr, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - use bytes::Bytes; -use futures::{Future, SinkExt, Stream, StreamExt}; -use msg_transport::ServerTransport; -use msg_wire::{auth, reqrep}; +use std::net::SocketAddr; use thiserror::Error; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::{mpsc, oneshot}, - task::JoinSet, -}; -use tokio_stream::StreamMap; -use tokio_util::codec::Framed; +use tokio::sync::oneshot; -use crate::Authenticator; +mod driver; +mod socket; +mod stats; +pub use socket::*; +use stats::SocketStats; const DEFAULT_BUFFER_SIZE: usize = 1024; -/// A reply socket. This socket can bind multiple times. -pub struct RepSocket { - #[allow(unused)] - options: Arc, - from_backend: Option>, - transport: Option, - auth: Option>, - local_addr: Option, -} - -impl RepSocket { - pub fn local_addr(&self) -> Option { - self.local_addr - } -} - -impl Stream for RepSocket { - type Item = Request; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.from_backend - .as_mut() - .expect("Inactive socket") - .poll_recv(cx) - } -} - #[derive(Debug, Error)] pub enum RepError { #[error("IO error: {0:?}")] @@ -78,62 +39,14 @@ impl Default for RepOptions { } } -impl RepSocket { - pub fn new(transport: T) -> Self { - Self::new_with_options(transport, RepOptions::default()) - } - - pub fn new_with_options(transport: T, options: RepOptions) -> Self { - Self { - from_backend: None, - transport: Some(transport), - local_addr: None, - options: Arc::new(options), - auth: None, - } - } - - pub fn with_auth(mut self, authenticator: A) -> Self { - self.auth = Some(Arc::new(authenticator)); - self - } -} - -impl RepSocket { - pub async fn bind(&mut self, addr: &str) -> Result<(), RepError> { - let (to_socket, from_backend) = mpsc::channel(DEFAULT_BUFFER_SIZE); - - // Take the transport here, so we can move it into the backend task - let mut transport = self.transport.take().unwrap(); - - transport - .bind(addr) - .await - .map_err(|e| RepError::Transport(Box::new(e)))?; - - let local_addr = transport - .local_addr() - .map_err(|e| RepError::Transport(Box::new(e)))?; - - tracing::debug!("Listening on {}", local_addr); - - let backend = RepBackend { - transport, - peer_states: StreamMap::with_capacity(128), - to_socket, - auth: self.auth.take(), - auth_tasks: JoinSet::new(), - }; - - tokio::spawn(backend); - - self.local_addr = Some(local_addr); - self.from_backend = Some(from_backend); - - Ok(()) - } +/// The request socket state, shared between the backend task and the socket. +#[derive(Debug, Default)] +pub(crate) struct SocketState { + pub(crate) stats: SocketStats, } +/// A request received by the socket. It contains the source address, the message, +/// and a oneshot channel to respond to the request. pub struct Request { source: SocketAddr, response: oneshot::Sender, @@ -141,14 +54,17 @@ pub struct Request { } impl Request { + /// Returns the source address of the request. pub fn source(&self) -> SocketAddr { self.source } + /// Returns a reference to the message. pub fn msg(&self) -> &Bytes { &self.msg } + /// Responds to the request. pub fn respond(self, response: Bytes) -> Result<(), RepError> { self.response .send(response) @@ -156,235 +72,15 @@ impl Request { } } -struct PeerState { - pending_requests: JoinSet>, - conn: Framed, - addr: SocketAddr, - egress_queue: VecDeque, -} - -struct RepBackend { - transport: T, - /// [`StreamMap`] of connected peers. The key is the peer's address. - /// Note that when the [`PeerState`] stream ends, it will be silently removed - /// from this map. - peer_states: StreamMap>, - to_socket: mpsc::Sender, - /// Optional connection authenticator - auth: Option>, - /// Authentication tasks - auth_tasks: JoinSet, RepError>>, -} - -struct AuthResult { - id: Bytes, - addr: SocketAddr, - stream: S, -} - -impl Future for RepBackend { - type Output = Result<(), RepError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - loop { - if let Poll::Ready(Some((peer, msg))) = this.peer_states.poll_next_unpin(cx) { - match msg { - Ok(request) => { - tracing::debug!("Received message from peer {}", peer); - let _ = this.to_socket.try_send(request); - } - - Err(e) => { - tracing::error!("Error receiving message from peer {}: {:?}", peer, e); - } - } - - continue; - } - - if let Poll::Ready(Some(Ok(auth))) = this.auth_tasks.poll_join_next(cx) { - match auth { - Ok(auth) => { - // Run custom authenticator - tracing::debug!("Authentication passed for {:?} ({})", auth.id, auth.addr); - - this.peer_states.insert( - auth.addr, - PeerState { - addr: auth.addr, - pending_requests: JoinSet::new(), - conn: Framed::new(auth.stream, reqrep::Codec::new()), - egress_queue: VecDeque::new(), - }, - ); - } - Err(e) => { - tracing::error!("Error authenticating client: {:?}", e); - } - } - - continue; - } - - // Poll the transport for new incoming connections - match this.transport.poll_accept(cx) { - Poll::Ready(Ok((stream, addr))) => { - // If authentication is enabled, start the authentication process - if let Some(ref auth) = this.auth { - let authenticator = Arc::clone(auth); - tracing::debug!("New connection from {}, authenticating", addr); - this.auth_tasks.spawn(async move { - let mut conn = Framed::new(stream, auth::Codec::new_server()); - - tracing::debug!("Waiting for auth"); - // Wait for the response - let auth = conn - .next() - .await - .ok_or(RepError::SocketClosed)? - .map_err(|e| RepError::Auth(e.to_string()))?; - - tracing::debug!("Auth received: {:?}", auth); - - let auth::Message::Auth(id) = auth else { - conn.send(auth::Message::Reject).await?; - conn.flush().await?; - conn.close().await?; - return Err(RepError::Auth("Invalid auth message".to_string())); - }; - - // If authentication fails, send a reject message and close the connection - if !authenticator.authenticate(&id) { - conn.send(auth::Message::Reject).await?; - conn.flush().await?; - conn.close().await?; - return Err(RepError::Auth("Authentication failed".to_string())); - } - - // Send ack - conn.send(auth::Message::Ack).await?; - conn.flush().await?; - - Ok(AuthResult { - id, - addr, - stream: conn.into_inner(), - }) - }); - } else { - this.peer_states.insert( - addr, - PeerState { - addr, - pending_requests: JoinSet::new(), - conn: Framed::new(stream, reqrep::Codec::new()), - egress_queue: VecDeque::new(), - }, - ); - - tracing::debug!("New connection from {}", addr); - } - - continue; - } - Poll::Ready(Err(e)) => { - // Errors here are usually about `WouldBlock` - tracing::error!("Error accepting connection: {:?}", e); - - continue; - } - Poll::Pending => {} - } - - return Poll::Pending; - } - } -} - -impl Stream for PeerState { - type Item = Result; - - /// Advances the state of the peer. - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - loop { - let _ = this.conn.poll_flush_unpin(cx); - - if this.conn.poll_ready_unpin(cx).is_ready() { - if let Some(msg) = this.egress_queue.pop_front() { - match this.conn.start_send_unpin(msg) { - Ok(_) => { - // We might be able to send more queued messages - continue; - } - Err(e) => { - tracing::error!("Failed to send message to socket: {:?}", e); - // End this stream as we can't send any more messages - return Poll::Ready(None); - } - } - } - } - - // First, try to drain the egress queue. - // First check for completed requests - match this.pending_requests.poll_join_next(cx) { - Poll::Ready(Some(Ok(Some((id, payload))))) => { - let msg = reqrep::Message::new(id, payload); - this.egress_queue.push_back(msg); - - continue; - } - Poll::Ready(Some(Err(e))) => { - tracing::error!("Error receiving response: {:?}", e); - continue; - } - _ => {} - } - - match this.conn.poll_next_unpin(cx) { - Poll::Ready(Some(result)) => { - tracing::trace!("Received message from peer {}: {:?}", this.addr, result); - let msg = result?; - let msg_id = msg.id(); - - let (tx, rx) = oneshot::channel(); - - // Spawn a task to listen for the response. On success, return message ID and response. - this.pending_requests - .spawn(async move { rx.await.ok().map(|res| (msg_id, res)) }); - - let request = Request { - source: this.addr, - response: tx, - msg: msg.into_payload(), - }; - - return Poll::Ready(Some(Ok(request))); - } - Poll::Ready(None) => { - tracing::debug!("Connection closed"); - return Poll::Ready(None); - } - Poll::Pending => {} - } - - return Poll::Pending; - } - } -} - #[cfg(test)] mod tests { use std::time::Duration; + use futures::StreamExt; use msg_transport::Tcp; use rand::Rng; - use crate::{req::ReqSocket, ReqOptions}; + use crate::{req::ReqSocket, Authenticator, ReqOptions}; use super::*; diff --git a/msg-socket/src/rep/socket.rs b/msg-socket/src/rep/socket.rs new file mode 100644 index 0000000..01c3a01 --- /dev/null +++ b/msg-socket/src/rep/socket.rs @@ -0,0 +1,113 @@ +use std::{ + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use futures::Stream; +use msg_transport::ServerTransport; +use tokio::{sync::mpsc, task::JoinSet}; +use tokio_stream::StreamMap; + +use crate::{ + rep::{driver::RepDriver, DEFAULT_BUFFER_SIZE}, + rep::{SocketState, SocketStats}, + Authenticator, RepError, RepOptions, Request, +}; + +/// A reply socket. This socket implements [`Stream`] and yields incoming [`Request`]s. +pub struct RepSocket { + /// The reply socket options, shared with the driver. + options: Arc, + /// The reply socket state, shared with the driver. + state: Arc, + /// Receiver from the socket driver. + from_driver: Option>, + /// The optional transport. This is taken when the socket is bound. + transport: Option, + /// Optional connection authenticator. + auth: Option>, + /// The local address this socket is bound to. + local_addr: Option, +} + +impl RepSocket { + /// Creates a new reply socket with the default [`RepOptions`]. + pub fn new(transport: T) -> Self { + Self::new_with_options(transport, RepOptions::default()) + } + + /// Creates a new reply socket with the given [`RepOptions`]. + pub fn new_with_options(transport: T, options: RepOptions) -> Self { + Self { + from_driver: None, + transport: Some(transport), + local_addr: None, + options: Arc::new(options), + state: Arc::new(SocketState::default()), + auth: None, + } + } + + /// Sets the connection authenticator for this socket. + pub fn with_auth(mut self, authenticator: A) -> Self { + self.auth = Some(Arc::new(authenticator)); + self + } + + /// Binds the socket to the given address. This spawns the socket driver task. + pub async fn bind(&mut self, addr: &str) -> Result<(), RepError> { + let (to_socket, from_backend) = mpsc::channel(DEFAULT_BUFFER_SIZE); + + // Take the transport here, so we can move it into the backend task + let mut transport = self.transport.take().unwrap(); + + transport + .bind(addr) + .await + .map_err(|e| RepError::Transport(Box::new(e)))?; + + let local_addr = transport + .local_addr() + .map_err(|e| RepError::Transport(Box::new(e)))?; + + tracing::debug!("Listening on {}", local_addr); + + let backend = RepDriver { + transport, + state: Arc::clone(&self.state), + peer_states: StreamMap::with_capacity(self.options.max_connections.unwrap_or(64)), + to_socket, + auth: self.auth.take(), + auth_tasks: JoinSet::new(), + }; + + tokio::spawn(backend); + + self.local_addr = Some(local_addr); + self.from_driver = Some(from_backend); + + Ok(()) + } + + pub fn stats(&self) -> &SocketStats { + &self.state.stats + } + + /// Returns the local address this socket is bound to. `None` if the socket is not bound. + pub fn local_addr(&self) -> Option { + self.local_addr + } +} + +impl Stream for RepSocket { + type Item = Request; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.from_driver + .as_mut() + .expect("Inactive socket") + .poll_recv(cx) + } +} diff --git a/msg-socket/src/rep/stats.rs b/msg-socket/src/rep/stats.rs new file mode 100644 index 0000000..c414d32 --- /dev/null +++ b/msg-socket/src/rep/stats.rs @@ -0,0 +1,52 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Statistics for a reply socket. These are shared between the driver task +/// and the socket. +#[derive(Debug, Default)] +pub struct SocketStats { + /// Total bytes sent + bytes_tx: AtomicUsize, + /// Total bytes received + bytes_rx: AtomicUsize, + /// Total number of active request clients + active_clients: AtomicUsize, + /// Total number of failed requests + failed_requests: AtomicUsize, +} + +impl SocketStats { + #[inline] + pub(crate) fn increment_tx(&self, bytes: usize) { + self.bytes_tx.fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_rx(&self, bytes: usize) { + self.bytes_rx.fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_active_clients(&self) { + self.active_clients.fetch_add(1, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn decrement_active_clients(&self) { + self.active_clients.fetch_sub(1, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_failed_requests(&self) { + self.failed_requests.fetch_add(1, Ordering::Relaxed); + } + + #[inline] + pub fn bytes_tx(&self) -> usize { + self.bytes_tx.load(Ordering::Relaxed) + } + + #[inline] + pub fn bytes_rx(&self) -> usize { + self.bytes_rx.load(Ordering::Relaxed) + } +} diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index 355a094..e5ebe58 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -13,7 +13,7 @@ use tokio::{ }; use tokio_util::codec::Framed; -use crate::SocketState; +use crate::req::SocketState; use super::{Command, ReqError, ReqOptions}; use msg_wire::reqrep; diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index c0ee4e6..359e9fe 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -7,7 +7,7 @@ use rustc_hash::FxHashMap; use tokio::sync::{mpsc, oneshot}; use tokio_util::codec::Framed; -use crate::{req::stats::SocketStats, SocketState}; +use crate::{req::stats::SocketStats, req::SocketState}; use super::{Command, ReqDriver, ReqError, ReqOptions, DEFAULT_BUFFER_SIZE}; From 92dc40d35772fe6894fe414316ba0c541c6075ca Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Wed, 11 Oct 2023 11:26:50 +0200 Subject: [PATCH 2/8] refactor(socket): reorder imports --- msg-socket/src/rep/socket.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/msg-socket/src/rep/socket.rs b/msg-socket/src/rep/socket.rs index 01c3a01..9bfaf30 100644 --- a/msg-socket/src/rep/socket.rs +++ b/msg-socket/src/rep/socket.rs @@ -1,12 +1,10 @@ +use futures::Stream; use std::{ net::SocketAddr, pin::Pin, sync::Arc, task::{Context, Poll}, }; - -use futures::Stream; -use msg_transport::ServerTransport; use tokio::{sync::mpsc, task::JoinSet}; use tokio_stream::StreamMap; @@ -15,6 +13,7 @@ use crate::{ rep::{SocketState, SocketStats}, Authenticator, RepError, RepOptions, Request, }; +use msg_transport::ServerTransport; /// A reply socket. This socket implements [`Stream`] and yields incoming [`Request`]s. pub struct RepSocket { From b5bcc554a62fe22706bfff28ce5e475a6f5454a6 Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Wed, 11 Oct 2023 11:35:48 +0200 Subject: [PATCH 3/8] feat(socket): rep socket stats getters --- msg-socket/src/rep/stats.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/msg-socket/src/rep/stats.rs b/msg-socket/src/rep/stats.rs index c414d32..e3249d3 100644 --- a/msg-socket/src/rep/stats.rs +++ b/msg-socket/src/rep/stats.rs @@ -49,4 +49,14 @@ impl SocketStats { pub fn bytes_rx(&self) -> usize { self.bytes_rx.load(Ordering::Relaxed) } + + #[inline] + pub fn active_clients(&self) -> usize { + self.active_clients.load(Ordering::Relaxed) + } + + #[inline] + pub fn failed_requests(&self) -> usize { + self.failed_requests.load(Ordering::Relaxed) + } } From f65b6ba4fd00914bb42e392885b89a88915d921c Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Wed, 11 Oct 2023 12:15:50 +0200 Subject: [PATCH 4/8] doc: update README --- README.md | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ac65e4f..3d6a29c 100644 --- a/README.md +++ b/README.md @@ -20,14 +20,20 @@ It was built because we needed a Rust-native messaging library like those above. - [ ] Multiple socket types - [x] Request/Reply - - [ ] Channel - [ ] Publish/Subscribe + - [ ] Channel - [ ] Push/Pull - [ ] Survey/Respond - [ ] Stats (RTT, throughput, packet drops etc.) -- [ ] Durable transports (built-in retries and reconnections) + - [x] Request/Reply basic stats - [ ] Queuing -- [ ] Pluggable transport layer (TCP, UDP, QUIC etc.) +- [ ] Pluggable transport layer + - [x] TCP + - [ ] TLS + - [ ] IPC + - [ ] UDP + - [ ] Inproc +- [x] Durable IO abstraction (built-in retries and reconnections) - [ ] Simulation modes with [Turmoil](https://github.com/tokio-rs/turmoil) ## Socket Types @@ -65,6 +71,8 @@ async fn main() { println!("Response: {:?}", res); } ``` +## MSRV +The minimum supported Rust version is 1.70. ## Contributions & Bug Reports From 4f7168966c43902c919430de05df37cfc6cf8d6d Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Wed, 11 Oct 2023 12:57:51 +0200 Subject: [PATCH 5/8] feat(socket): change new with options API --- msg-socket/src/rep/mod.rs | 22 ++++------------------ msg-socket/src/rep/socket.rs | 13 +++++++------ msg-socket/src/req/socket.rs | 11 ++++++----- msg/examples/durable.rs | 6 ++---- msg/examples/reqrep_auth.rs | 6 ++---- 5 files changed, 21 insertions(+), 37 deletions(-) diff --git a/msg-socket/src/rep/mod.rs b/msg-socket/src/rep/mod.rs index f4bc282..5f048b8 100644 --- a/msg-socket/src/rep/mod.rs +++ b/msg-socket/src/rep/mod.rs @@ -164,10 +164,8 @@ mod tests { rep.bind("127.0.0.1:0").await.unwrap(); // Initialize socket with a client ID. This will implicitly enable authentication. - let mut req = ReqSocket::new_with_options( - Tcp::new(), - ReqOptions::default().with_id(Bytes::from("REQ")), - ); + let mut req = ReqSocket::new(Tcp::new()) + .with_options(ReqOptions::default().with_id(Bytes::from("REQ"))); req.connect(&rep.local_addr().unwrap().to_string()) .await @@ -205,22 +203,10 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn test_batch_req_rep() { let _ = tracing_subscriber::fmt::try_init(); - let mut rep = RepSocket::new_with_options( - Tcp::new(), - RepOptions { - set_nodelay: true, - ..Default::default() - }, - ); + let mut rep = RepSocket::new(Tcp::new()); rep.bind("127.0.0.1:0").await.unwrap(); - let mut req = ReqSocket::new_with_options( - Tcp::new(), - ReqOptions { - set_nodelay: true, - ..Default::default() - }, - ); + let mut req = ReqSocket::new(Tcp::new()); req.connect(&rep.local_addr().unwrap().to_string()) .await .unwrap(); diff --git a/msg-socket/src/rep/socket.rs b/msg-socket/src/rep/socket.rs index 9bfaf30..d742a1a 100644 --- a/msg-socket/src/rep/socket.rs +++ b/msg-socket/src/rep/socket.rs @@ -34,21 +34,22 @@ pub struct RepSocket { impl RepSocket { /// Creates a new reply socket with the default [`RepOptions`]. pub fn new(transport: T) -> Self { - Self::new_with_options(transport, RepOptions::default()) - } - - /// Creates a new reply socket with the given [`RepOptions`]. - pub fn new_with_options(transport: T, options: RepOptions) -> Self { Self { from_driver: None, transport: Some(transport), local_addr: None, - options: Arc::new(options), + options: Arc::new(RepOptions::default()), state: Arc::new(SocketState::default()), auth: None, } } + /// Sets the options for this socket. + pub fn with_options(mut self, options: RepOptions) -> Self { + self.options = Arc::new(options); + self + } + /// Sets the connection authenticator for this socket. pub fn with_auth(mut self, authenticator: A) -> Self { self.auth = Some(Arc::new(authenticator)); diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 359e9fe..dbb8024 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -25,18 +25,19 @@ pub struct ReqSocket { impl ReqSocket { pub fn new(transport: T) -> Self { - Self::new_with_options(transport, ReqOptions::default()) - } - - pub fn new_with_options(transport: T, options: ReqOptions) -> Self { Self { to_driver: None, transport, - options: Arc::new(options), + options: Arc::new(ReqOptions::default()), state: Arc::new(SocketState::default()), } } + pub fn with_options(mut self, options: ReqOptions) -> Self { + self.options = Arc::new(options); + self + } + pub fn stats(&self) -> &SocketStats { &self.state.stats } diff --git a/msg/examples/durable.rs b/msg/examples/durable.rs index 912cb6f..fd18768 100644 --- a/msg/examples/durable.rs +++ b/msg/examples/durable.rs @@ -37,10 +37,8 @@ async fn main() { // Initialize the request socket (client side) with a transport // and an identifier. This will implicitly turn on client authentication. - let mut req = ReqSocket::new_with_options( - Tcp::new(), - ReqOptions::default().with_id(Bytes::from("client1")), - ); + let mut req = ReqSocket::new(Tcp::new()) + .with_options(ReqOptions::default().with_id(Bytes::from("client1"))); tracing::info!("Trying to connect to rep socket..."); req.connect("0.0.0.0:4444").await.unwrap(); diff --git a/msg/examples/reqrep_auth.rs b/msg/examples/reqrep_auth.rs index 9e0cdf1..a280f7d 100644 --- a/msg/examples/reqrep_auth.rs +++ b/msg/examples/reqrep_auth.rs @@ -23,10 +23,8 @@ async fn main() { // Initialize the request socket (client side) with a transport // and an identifier. This will implicitly turn on client authentication. - let mut req = ReqSocket::new_with_options( - Tcp::new(), - ReqOptions::default().with_id(Bytes::from("client1")), - ); + let mut req = ReqSocket::new(Tcp::new()) + .with_options(ReqOptions::default().with_id(Bytes::from("client1"))); req.connect("0.0.0.0:4444").await.unwrap(); From 38d920ec05d2de548fa9e4fcc615fb5cbaed65ad Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Wed, 11 Oct 2023 13:44:21 +0200 Subject: [PATCH 6/8] feat(socket): pre-allocate rep socket egress queue --- msg-socket/src/rep/driver.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/msg-socket/src/rep/driver.rs b/msg-socket/src/rep/driver.rs index 00cb32e..985abf5 100644 --- a/msg-socket/src/rep/driver.rs +++ b/msg-socket/src/rep/driver.rs @@ -89,7 +89,8 @@ impl Future for RepDriver { pending_requests: JoinSet::new(), conn: Framed::new(auth.stream, reqrep::Codec::new()), addr: auth.addr, - egress_queue: VecDeque::new(), + // TODO: pre-allocate according to some options + egress_queue: VecDeque::with_capacity(64), state: Arc::clone(&this.state), }), ); From 9f99f77f72812185f73bc045b4a9999a638b088b Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Wed, 11 Oct 2023 13:50:38 +0200 Subject: [PATCH 7/8] refactor(transport): make Layer generic over IO --- msg-transport/src/durable/session.rs | 16 +++++++--------- msg-transport/src/lib.rs | 6 ++---- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/msg-transport/src/durable/session.rs b/msg-transport/src/durable/session.rs index a6e1534..ec373b1 100644 --- a/msg-transport/src/durable/session.rs +++ b/msg-transport/src/durable/session.rs @@ -17,14 +17,14 @@ pub type PendingIo = Pin> + Send>>; /// A layer can be applied to pre-process a newly established IO object. If you need /// multiple layers, use a single top-level layer that contains and calls the other layers. -pub trait Layer: 'static { +pub trait Layer: 'static { /// The type of the IO object that is processed. - type Io: AsyncRead + AsyncWrite; + // type Io: AsyncRead + AsyncWrite; /// The processing method. This method is called with the IO object that /// should be processed, and returns a future that resolves to a processing error /// or the processed IO object. - fn process(&mut self, io: Self::Io) -> PendingIo; + fn process(&mut self, io: Io) -> PendingIo; } struct ReconnectStatus { @@ -126,7 +126,7 @@ where /// Adds a layer to the session. The layer will be applied to all established or re-established /// sessions. - pub fn with_layer(mut self, layer: impl Layer + Send) -> Self { + pub fn with_layer(mut self, layer: impl Layer + Send) -> Self { self.layer_stack = Some(Box::new(layer)); self } @@ -218,7 +218,7 @@ pub struct DurableSession { endpoint: SocketAddr, /// Optional layer stack. If this is `None`, newly connected (or reconnected) sessions will /// be passed through without processing. - layer_stack: Option + Send>>, + layer_stack: Option + Send>>, } impl AsyncRead for DurableSession @@ -638,10 +638,8 @@ mod tests { async fn session_with_layer() { struct TestLayer; - impl Layer for TestLayer { - type Io = TcpStream; - - fn process(&mut self, io: Self::Io) -> PendingIo { + impl Layer for TestLayer { + fn process(&mut self, io: TcpStream) -> PendingIo { Box::pin(async move { let mut io = io; io.write_i32(10).await.unwrap(); diff --git a/msg-transport/src/lib.rs b/msg-transport/src/lib.rs index 00540eb..a1deca6 100644 --- a/msg-transport/src/lib.rs +++ b/msg-transport/src/lib.rs @@ -98,10 +98,8 @@ pub struct AuthLayer { id: Bytes, } -impl Layer for AuthLayer { - type Io = TcpStream; - - fn process(&mut self, io: Self::Io) -> PendingIo { +impl Layer for AuthLayer { + fn process(&mut self, io: TcpStream) -> PendingIo { let id = self.id.clone(); Box::pin(async move { let mut conn = Framed::new(io, auth::Codec::new_client()); From e74b11ab82779f58703e9be4d4a99118dfb063dc Mon Sep 17 00:00:00 2001 From: Jonas Bostoen Date: Wed, 11 Oct 2023 14:02:30 +0200 Subject: [PATCH 8/8] chore(transport): rm unused comments --- msg-transport/src/durable/session.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/msg-transport/src/durable/session.rs b/msg-transport/src/durable/session.rs index ec373b1..c02d678 100644 --- a/msg-transport/src/durable/session.rs +++ b/msg-transport/src/durable/session.rs @@ -18,9 +18,6 @@ pub type PendingIo = Pin> + Send>>; /// A layer can be applied to pre-process a newly established IO object. If you need /// multiple layers, use a single top-level layer that contains and calls the other layers. pub trait Layer: 'static { - /// The type of the IO object that is processed. - // type Io: AsyncRead + AsyncWrite; - /// The processing method. This method is called with the IO object that /// should be processed, and returns a future that resolves to a processing error /// or the processed IO object.