From 5c08a99b17d5c6ebd30fa0ac57e396db9efa5b7c Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Mon, 29 Jan 2024 15:42:36 +0100 Subject: [PATCH] test: max-clients reached --- msg-socket/src/rep/mod.rs | 8 +++-- msg-socket/src/req/driver.rs | 61 ++++++++++++++---------------------- msg-socket/src/req/socket.rs | 5 +-- 3 files changed, 31 insertions(+), 43 deletions(-) diff --git a/msg-socket/src/rep/mod.rs b/msg-socket/src/rep/mod.rs index 81f6ec5..282c1d8 100644 --- a/msg-socket/src/rep/mod.rs +++ b/msg-socket/src/rep/mod.rs @@ -237,13 +237,15 @@ mod tests { let _ = tracing_subscriber::fmt::try_init(); let mut rep = RepSocket::with_options(Tcp::default(), RepOptions::default().max_clients(1)); rep.bind("127.0.0.1:0").await.unwrap(); + let addr = rep.local_addr().unwrap(); let mut req1 = ReqSocket::new(Tcp::default()); - req1.connect(rep.local_addr().unwrap()).await.unwrap(); + req1.connect(addr).await.unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; + assert_eq!(rep.stats().active_clients(), 1); let mut req2 = ReqSocket::new(Tcp::default()); - req2.connect(rep.local_addr().unwrap()).await.unwrap(); - + req2.connect(addr).await.unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; assert_eq!(rep.stats().active_clients(), 1); } diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index 0edebb8..94c78e0 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -1,6 +1,5 @@ use bytes::Bytes; use futures::{Future, FutureExt, SinkExt, StreamExt}; -use msg_transport::{PeerAddress, Transport}; use rustc_hash::FxHashMap; use std::{ collections::VecDeque, @@ -9,14 +8,14 @@ use std::{ pin::Pin, sync::Arc, task::{ready, Context, Poll}, - time::Duration, + time::{Duration, Instant}, }; use tokio::{ sync::{mpsc, oneshot}, - task::JoinHandle, + time::Interval, }; use tokio_util::codec::Framed; -use tracing::{debug, error}; +use tracing::{debug, error, trace}; use crate::{ connection::{ConnectionState, ExponentialBackoff}, @@ -24,13 +23,14 @@ use crate::{ }; use super::{Command, ReqError, ReqOptions}; +use msg_transport::Transport; use msg_wire::{ auth, compression::{try_decompress_payload, Compressor}, reqrep, }; -use std::time::Instant; -use tokio::time::Interval; + +type ConnectionTask = Pin> + Send>>; /// The request socket driver. Endless future that drives /// the the socket forward. @@ -46,8 +46,10 @@ pub(crate) struct ReqDriver { pub(crate) from_socket: mpsc::Receiver, /// The transport for this socket. pub(crate) transport: T, + /// The address of the server. + pub(crate) addr: SocketAddr, /// The connection task which handles the connection to the server. - pub(crate) conn_task: Option>>, + pub(crate) conn_task: Option>, /// The transport controller, wrapped in a [`ConnectionState`] for backoff. /// The [`Framed`] object can send and receive messages from the socket. pub(crate) conn_state: ConnectionState, ExponentialBackoff>, @@ -81,12 +83,12 @@ where /// Start the connection task to the server, handling authentication if necessary. /// The result will be polled by the driver and re-tried according to the backoff policy. fn try_connect(&mut self, addr: SocketAddr) { - tracing::trace!("try_connect"); + trace!("Trying to connect to {}", addr); + let connect = self.transport.connect(addr); let token = self.options.auth_token.clone(); - self.conn_task = Some(tokio::spawn(async move { - tracing::trace!("conn_task start"); + self.conn_task = Some(Box::pin(async move { let mut io = match connect.await { Ok(io) => io, Err(e) => { @@ -95,8 +97,6 @@ where } }; - tracing::trace!("io got"); - // Perform the authentication handshake if let Some(token) = token { let mut conn = Framed::new(&mut io, auth::Codec::new_client()); @@ -244,13 +244,8 @@ where #[inline] fn reset_connection(&mut self) { - let addr = match self.conn_state { - ConnectionState::Active { ref channel } => channel.get_ref().peer_addr().unwrap(), - ConnectionState::Inactive { addr, .. } => addr, - }; - self.conn_state = ConnectionState::Inactive { - addr, + addr: self.addr, backoff: ExponentialBackoff::new(Duration::from_millis(20), 16), }; } @@ -277,27 +272,17 @@ where // Poll the active connection task, if any if let Some(ref mut conn_task) = this.conn_task { - match conn_task.poll_unpin(cx) { - Poll::Ready(Ok(result)) => { - tracing::trace!("conn_task ready"); - - // As soon as the connection task finishes, set it to `None`. - // If it succeeds, the connection will be active, otherwise it will be - // re-tried until the backoff limit is reached. - this.conn_task = None; - - if let Ok(io) = result { - let mut framed = Framed::new(io, reqrep::Codec::new()); - framed.set_backpressure_boundary(this.options.backpressure_boundary); - this.conn_state = ConnectionState::Active { channel: framed }; - - continue; - } - } - Poll::Ready(Err(e)) => { - error!("Connection task failed: {:?}", e); + if let Poll::Ready(result) = conn_task.poll_unpin(cx) { + // As soon as the connection task finishes, set it to `None`. + // - If it was successful, set the connection to active + // - If it failed, it will be re-tried until the backoff limit is reached. + this.conn_task = None; + + if let Ok(io) = result { + let mut framed = Framed::new(io, reqrep::Codec::new()); + framed.set_backpressure_boundary(this.options.backpressure_boundary); + this.conn_state = ConnectionState::Active { channel: framed }; } - Poll::Pending => {} } } diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index f5a10b9..dcbd518 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -93,8 +93,8 @@ where .take() .expect("Transport has been moved already"); - // We initialize the connection as inactive, and let it be activated by the backend task - // as soon as the driver is spawned. + // We initialize the connection as inactive, and let it be activated + // by the backend task as soon as the driver is spawned. let conn_state = ConnectionState::Inactive { addr: endpoint, backoff: ExponentialBackoff::new(Duration::from_millis(20), 16), @@ -112,6 +112,7 @@ where // Create the socket backend let driver: ReqDriver = ReqDriver { + addr: endpoint, options: Arc::clone(&self.options), socket_state: Arc::clone(&self.state), id_counter: 0,