diff --git a/msg-socket/src/sub/driver.rs b/msg-socket/src/sub/driver.rs index 1f20bd1..17e7d25 100644 --- a/msg-socket/src/sub/driver.rs +++ b/msg-socket/src/sub/driver.rs @@ -27,7 +27,7 @@ use msg_common::{channel, Channel}; use msg_transport::Transport; use msg_wire::{auth, pubsub}; -type ConnectionResult = Result<(SocketAddr, Io), E>; +type ConnectionResult = Result<(SocketAddr, Io), (SocketAddr, E)>; pub(crate) struct SubDriver { /// Options shared with the socket. @@ -92,8 +92,13 @@ where Ok((addr, io)) => { this.on_connection(addr, io); } - Err(e) => { - error!("Error connecting to publisher: {:?}", e); + // If the initial connection failed, reset the publisher to try again later. + Err((addr, e)) => { + this.reset_publisher(addr); + error!( + "Error connecting to publisher, scheduling reconnect: {:?}", + e + ); } } @@ -117,7 +122,7 @@ where addr, PublisherState::Inactive { addr, - backoff: ExponentialBackoff::new(Duration::from_millis(50), 16), + backoff: ExponentialBackoff::new(self.options.initial_backoff, 16), }, ); } @@ -218,10 +223,6 @@ where } self.connect(endpoint); - - // Also set the publisher to the disconnected state. This will make sure that if the - // initial connection attempt fails, it will be retried in `poll_publishers`. - self.reset_publisher(endpoint); } Command::Disconnect { endpoint } => { if self.publishers.remove(&endpoint).is_some() { @@ -243,7 +244,7 @@ where let token = self.options.auth_token.clone(); self.connection_tasks.spawn(async move { - let io = connect.await?; + let io = connect.await.map_err(|e| (addr, e))?; if let Some(token) = token { let mut conn = Framed::new(io, auth::Codec::new_client()); @@ -252,8 +253,8 @@ where // Send the authentication message conn.send(auth::Message::Auth(token)) .await - .map_err(T::Error::from)?; - conn.flush().await.map_err(T::Error::from)?; + .map_err(|e| (addr, T::Error::from(e)))?; + conn.flush().await.map_err(|e| (addr, T::Error::from(e)))?; tracing::debug!("Waiting for ACK from server..."); @@ -261,20 +262,28 @@ where let ack = conn .next() .await - .ok_or(io::Error::new( - io::ErrorKind::UnexpectedEof, - "Connection closed", + .ok_or(( + addr, + io::Error::new(io::ErrorKind::UnexpectedEof, "Connection closed").into(), ))? - .map_err(|e| io::Error::new(io::ErrorKind::PermissionDenied, e))?; + .map_err(|e| { + ( + addr, + io::Error::new(io::ErrorKind::PermissionDenied, e).into(), + ) + })?; if matches!(ack, auth::Message::Ack) { Ok((addr, conn.into_inner())) } else { - Err(io::Error::new( - io::ErrorKind::PermissionDenied, - "Publisher denied connection", - ) - .into()) + Err(( + addr, + io::Error::new( + io::ErrorKind::PermissionDenied, + "Publisher denied connection", + ) + .into(), + )) } } else { Ok((addr, io)) diff --git a/msg-socket/src/sub/mod.rs b/msg-socket/src/sub/mod.rs index ef7be03..01b0774 100644 --- a/msg-socket/src/sub/mod.rs +++ b/msg-socket/src/sub/mod.rs @@ -1,7 +1,7 @@ use bytes::Bytes; use core::fmt; use msg_wire::pubsub; -use std::net::SocketAddr; +use std::{net::SocketAddr, time::Duration}; use thiserror::Error; mod driver; @@ -55,6 +55,8 @@ pub struct SubOptions { ingress_buffer_size: usize, /// The read buffer size for each session. read_buffer_size: usize, + /// The initial backoff for reconnecting to a publisher. + initial_backoff: Duration, } impl SubOptions { @@ -77,6 +79,12 @@ impl SubOptions { self.read_buffer_size = read_buffer_size; self } + + /// Set the initial backoff for reconnecting to a publisher. + pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self { + self.initial_backoff = initial_backoff; + self + } } impl Default for SubOptions { @@ -85,6 +93,7 @@ impl Default for SubOptions { auth_token: None, ingress_buffer_size: DEFAULT_BUFFER_SIZE, read_buffer_size: 8192, + initial_backoff: Duration::from_millis(100), } } }