Skip to content

Commit

Permalink
Merge pull request #78 from chainbound/fix/ipc/breaking-changes
Browse files Browse the repository at this point in the history
fix(ipc): remove api breaking changes by using `Transport<A>`
  • Loading branch information
merklefruit authored Aug 27, 2024
2 parents a79265d + cbab9f3 commit 6c9a9a8
Show file tree
Hide file tree
Showing 26 changed files with 257 additions and 247 deletions.
16 changes: 9 additions & 7 deletions msg-socket/src/pub/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ use super::{
session::SubscriberSession, trie::PrefixTrie, PubError, PubMessage, PubOptions, SocketState,
};
use crate::{AuthResult, Authenticator};
use msg_transport::{PeerAddress, Transport};
use msg_transport::{Address, PeerAddress, Transport};
use msg_wire::{auth, pubsub};

#[allow(clippy::type_complexity)]
pub(crate) struct PubDriver<T: Transport> {
pub(crate) struct PubDriver<T: Transport<A>, A: Address> {
/// Session ID counter.
pub(super) id_counter: u32,
/// The server transport used to accept incoming connections.
Expand All @@ -32,14 +32,15 @@ pub(crate) struct PubDriver<T: Transport> {
/// A set of pending incoming connections, represented by [`Transport::Accept`].
pub(super) conn_tasks: FuturesUnordered<T::Accept>,
/// A joinset of authentication tasks.
pub(super) auth_tasks: JoinSet<Result<AuthResult<T::Io, T::Addr>, PubError>>,
pub(super) auth_tasks: JoinSet<Result<AuthResult<T::Io, A>, PubError>>,
/// The receiver end of the message broadcast channel. The sender half is stored by [`PubSocket`](super::PubSocket).
pub(super) from_socket_bcast: broadcast::Receiver<PubMessage>,
}

impl<T> Future for PubDriver<T>
impl<T, A> Future for PubDriver<T, A>
where
T: Transport + Unpin + 'static,
T: Transport<A> + Unpin + 'static,
A: Address,
{
type Output = Result<(), PubError>;

Expand Down Expand Up @@ -130,9 +131,10 @@ where
}
}

impl<T> PubDriver<T>
impl<T, A> PubDriver<T, A>
where
T: Transport + Unpin + 'static,
T: Transport<A> + Unpin + 'static,
A: Address,
{
/// Handles an incoming connection. If this returns an error, the active connections counter
/// should be decremented.
Expand Down
42 changes: 21 additions & 21 deletions msg-socket/src/pub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,10 @@ mod tests {

let mut sub_socket = SubSocket::with_options(Tcp::default(), SubOptions::default());

pub_socket.bind_socket("0.0.0.0:0").await.unwrap();
pub_socket.bind("0.0.0.0:0").await.unwrap();
let addr = pub_socket.local_addr().unwrap();

sub_socket.connect_socket(addr).await.unwrap();
sub_socket.connect(addr).await.unwrap();
sub_socket.subscribe("HELLO".to_string()).await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;

Expand All @@ -221,10 +221,10 @@ mod tests {
SubOptions::default().auth_token(Bytes::from("client1")),
);

pub_socket.bind_socket("0.0.0.0:0").await.unwrap();
pub_socket.bind("0.0.0.0:0").await.unwrap();
let addr = pub_socket.local_addr().unwrap();

sub_socket.connect_socket(addr).await.unwrap();
sub_socket.connect(addr).await.unwrap();
sub_socket.subscribe("HELLO".to_string()).await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;

Expand All @@ -250,10 +250,10 @@ mod tests {
SubOptions::default().auth_token(Bytes::from("client1")),
);

pub_socket.bind_socket("0.0.0.0:0").await.unwrap();
pub_socket.bind("0.0.0.0:0").await.unwrap();
let addr = pub_socket.local_addr().unwrap();

sub_socket.connect_socket(addr).await.unwrap();
sub_socket.connect(addr).await.unwrap();
sub_socket.subscribe("HELLO".to_string()).await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;

Expand All @@ -278,11 +278,11 @@ mod tests {

let mut sub2 = SubSocket::new(Tcp::default());

pub_socket.bind_socket("0.0.0.0:0").await.unwrap();
pub_socket.bind("0.0.0.0:0").await.unwrap();
let addr = pub_socket.local_addr().unwrap();

sub1.connect_socket(addr).await.unwrap();
sub2.connect_socket(addr).await.unwrap();
sub1.connect(addr).await.unwrap();
sub2.connect(addr).await.unwrap();
sub1.subscribe("HELLO".to_string()).await.unwrap();
sub2.subscribe("HELLO".to_string()).await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
Expand Down Expand Up @@ -313,11 +313,11 @@ mod tests {

let mut sub2 = SubSocket::new(Tcp::default());

pub_socket.bind_socket("0.0.0.0:0").await.unwrap();
pub_socket.bind("0.0.0.0:0").await.unwrap();
let addr = pub_socket.local_addr().unwrap();

sub1.connect_socket(addr).await.unwrap();
sub2.connect_socket(addr).await.unwrap();
sub1.connect(addr).await.unwrap();
sub2.connect(addr).await.unwrap();
sub1.subscribe("HELLO".to_string()).await.unwrap();
sub2.subscribe("HELLO".to_string()).await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
Expand Down Expand Up @@ -349,11 +349,11 @@ mod tests {
let mut sub_socket = SubSocket::new(Tcp::default());

// Try to connect and subscribe before the publisher is up
sub_socket.connect_socket("0.0.0.0:6662").await.unwrap();
sub_socket.connect("0.0.0.0:6662").await.unwrap();
sub_socket.subscribe("HELLO".to_string()).await.unwrap();
tokio::time::sleep(Duration::from_millis(500)).await;

pub_socket.bind_socket("0.0.0.0:6662").await.unwrap();
pub_socket.bind("0.0.0.0:6662").await.unwrap();
tokio::time::sleep(Duration::from_millis(2000)).await;

pub_socket
Expand All @@ -376,11 +376,11 @@ mod tests {
let mut sub_socket = SubSocket::new(Quic::default());

// Try to connect and subscribe before the publisher is up
sub_socket.connect_socket("0.0.0.0:6662").await.unwrap();
sub_socket.connect("0.0.0.0:6662").await.unwrap();
sub_socket.subscribe("HELLO".to_string()).await.unwrap();
tokio::time::sleep(Duration::from_millis(1000)).await;

pub_socket.bind_socket("0.0.0.0:6662").await.unwrap();
pub_socket.bind("0.0.0.0:6662").await.unwrap();
tokio::time::sleep(Duration::from_millis(2000)).await;

pub_socket
Expand All @@ -401,18 +401,18 @@ mod tests {
let mut pub_socket =
PubSocket::with_options(Tcp::default(), PubOptions::default().max_clients(1));

pub_socket.bind_socket("0.0.0.0:0").await.unwrap();
pub_socket.bind("0.0.0.0:0").await.unwrap();

let mut sub1 = SubSocket::<Tcp>::with_options(Tcp::default(), SubOptions::default());
let mut sub1 = SubSocket::with_options(Tcp::default(), SubOptions::default());

let mut sub2 = SubSocket::<Tcp>::with_options(Tcp::default(), SubOptions::default());
let mut sub2 = SubSocket::with_options(Tcp::default(), SubOptions::default());

let addr = pub_socket.local_addr().unwrap();

sub1.connect_socket(addr).await.unwrap();
sub1.connect(addr).await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(pub_socket.stats().active_clients(), 1);
sub2.connect_socket(addr).await.unwrap();
sub2.connect(addr).await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(pub_socket.stats().active_clients(), 1);
}
Expand Down
29 changes: 15 additions & 14 deletions msg-socket/src/pub/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ use tracing::{debug, trace, warn};

use super::{driver::PubDriver, stats::SocketStats, PubError, PubMessage, PubOptions, SocketState};
use crate::Authenticator;
use msg_transport::Transport;
use msg_transport::{Address, Transport};
use msg_wire::compression::Compressor;

/// A publisher socket. This is thread-safe and can be cloned.
#[derive(Clone, Default)]
pub struct PubSocket<T: Transport> {
pub struct PubSocket<T: Transport<A>, A: Address> {
/// The reply socket options, shared with the driver.
options: Arc<PubOptions>,
/// The reply socket state, shared with the driver.
Expand All @@ -32,39 +32,40 @@ pub struct PubSocket<T: Transport> {
// complicates the API a lot. We can always change this later for perf reasons.
compressor: Option<Arc<dyn Compressor>>,
/// The local address this socket is bound to.
local_addr: Option<T::Addr>,
local_addr: Option<A>,
}

impl<T> PubSocket<T>
impl<T> PubSocket<T, SocketAddr>
where
T: Transport<Addr = SocketAddr> + Send + Unpin + 'static,
T: Transport<SocketAddr> + Send + Unpin + 'static,
{
/// Binds the socket to the given socket addres
///
/// This method is only available for transports that support [`SocketAddr`] as address type,
/// like [`Tcp`](msg_transport::tcp::Tcp) and [`Quic`](msg_transport::quic::Quic).
pub async fn bind_socket(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> {
pub async fn bind(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> {
let addrs = lookup_host(addr).await?;
self.try_bind(addrs.collect()).await
}
}

impl<T> PubSocket<T>
impl<T> PubSocket<T, PathBuf>
where
T: Transport<Addr = PathBuf> + Send + Unpin + 'static,
T: Transport<PathBuf> + Send + Unpin + 'static,
{
/// Binds the socket to the given path.
///
/// This method is only available for transports that support [`PathBuf`] as address type,
/// like [`Ipc`](msg_transport::ipc::Ipc).
pub async fn bind_path(&mut self, path: impl Into<PathBuf>) -> Result<(), PubError> {
pub async fn bind(&mut self, path: impl Into<PathBuf>) -> Result<(), PubError> {
self.try_bind(vec![path.into()]).await
}
}

impl<T> PubSocket<T>
impl<T, A> PubSocket<T, A>
where
T: Transport + Send + Unpin + 'static,
T: Transport<A> + Send + Unpin + 'static,
A: Address,
{
/// Creates a new reply socket with the default [`PubOptions`].
pub fn new(transport: T) -> Self {
Expand All @@ -85,7 +86,7 @@ where
}

/// Sets the connection authenticator for this socket.
pub fn with_auth<A: Authenticator>(mut self, authenticator: A) -> Self {
pub fn with_auth<O: Authenticator>(mut self, authenticator: O) -> Self {
self.auth = Some(Arc::new(authenticator));
self
}
Expand All @@ -99,7 +100,7 @@ where
/// Binds the socket to the given addresses in order until one succeeds.
///
/// This also spawns the socket driver task.
pub async fn try_bind(&mut self, addresses: Vec<T::Addr>) -> Result<(), PubError> {
pub async fn try_bind(&mut self, addresses: Vec<A>) -> Result<(), PubError> {
let (to_sessions_bcast, from_socket_bcast) =
broadcast::channel(self.options.session_buffer_size);

Expand Down Expand Up @@ -219,7 +220,7 @@ where
}

/// Returns the local address this socket is bound to. `None` if the socket is not bound.
pub fn local_addr(&self) -> Option<&T::Addr> {
pub fn local_addr(&self) -> Option<&A> {
self.local_addr.as_ref()
}
}
18 changes: 10 additions & 8 deletions msg-socket/src/rep/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub(crate) struct PeerState<T: AsyncRead + AsyncWrite, A: Address> {
}

#[allow(clippy::type_complexity)]
pub(crate) struct RepDriver<T: Transport> {
pub(crate) struct RepDriver<T: Transport<A>, A: Address> {
/// The server transport used to accept incoming connections.
pub(crate) transport: T,
/// The reply socket state, shared with the socket front-end.
Expand All @@ -46,9 +46,9 @@ pub(crate) struct RepDriver<T: Transport> {
/// Options shared with socket.
pub(crate) options: Arc<RepOptions>,
/// [`StreamMap`] of connected peers. The key is the peer's address.
pub(crate) peer_states: StreamMap<T::Addr, StreamNotifyClose<PeerState<T::Io, T::Addr>>>,
pub(crate) peer_states: StreamMap<A, StreamNotifyClose<PeerState<T::Io, A>>>,
/// Sender to the socket front-end. Used to notify the socket of incoming requests.
pub(crate) to_socket: mpsc::Sender<Request<T::Addr>>,
pub(crate) to_socket: mpsc::Sender<Request<A>>,
/// Optional connection authenticator.
pub(crate) auth: Option<Arc<dyn Authenticator>>,
/// Optional message compressor. This is shared with the socket to keep
Expand All @@ -57,12 +57,13 @@ pub(crate) struct RepDriver<T: Transport> {
/// A set of pending incoming connections, represented by [`Transport::Accept`].
pub(super) conn_tasks: FuturesUnordered<T::Accept>,
/// A joinset of authentication tasks.
pub(crate) auth_tasks: JoinSet<Result<AuthResult<T::Io, T::Addr>, PubError>>,
pub(crate) auth_tasks: JoinSet<Result<AuthResult<T::Io, A>, PubError>>,
}

impl<T> Future for RepDriver<T>
impl<T, A> Future for RepDriver<T, A>
where
T: Transport + Unpin + 'static,
T: Transport<A> + Unpin + 'static,
A: Address,
{
type Output = Result<(), PubError>;

Expand Down Expand Up @@ -176,9 +177,10 @@ where
}
}

impl<T> RepDriver<T>
impl<T, A> RepDriver<T, A>
where
T: Transport + Unpin + 'static,
T: Transport<A> + Unpin + 'static,
A: Address,
{
/// Handles an incoming connection. If this returns an error, the active connections counter
/// should be decremented.
Expand Down
22 changes: 11 additions & 11 deletions msg-socket/src/rep/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ mod tests {
async fn reqrep_simple() {
let _ = tracing_subscriber::fmt::try_init();
let mut rep = RepSocket::new(Tcp::default());
rep.bind_socket(localhost()).await.unwrap();
rep.bind(localhost()).await.unwrap();

let mut req = ReqSocket::new(Tcp::default());
req.connect_socket(rep.local_addr().unwrap()).await.unwrap();
req.connect(rep.local_addr().unwrap()).await.unwrap();

tokio::spawn(async move {
loop {
Expand Down Expand Up @@ -156,15 +156,15 @@ mod tests {
// Try to connect even through the server isn't up yet
let endpoint = addr.clone();
let connection_attempt = tokio::spawn(async move {
req.connect_socket(endpoint).await.unwrap();
req.connect(endpoint).await.unwrap();

req
});

// Wait a moment to start the server
tokio::time::sleep(Duration::from_millis(500)).await;
let mut rep = RepSocket::new(Tcp::default());
rep.bind_socket(addr).await.unwrap();
rep.bind(addr).await.unwrap();

let req = connection_attempt.await.unwrap();

Expand Down Expand Up @@ -193,15 +193,15 @@ mod tests {

let _ = tracing_subscriber::fmt::try_init();
let mut rep = RepSocket::new(Tcp::default()).with_auth(Auth);
rep.bind_socket(localhost()).await.unwrap();
rep.bind(localhost()).await.unwrap();

// Initialize socket with a client ID. This will implicitly enable authentication.
let mut req = ReqSocket::with_options(
Tcp::default(),
ReqOptions::default().auth_token(Bytes::from("REQ")),
);

req.connect_socket(rep.local_addr().unwrap()).await.unwrap();
req.connect(rep.local_addr().unwrap()).await.unwrap();

tracing::info!("Connected to rep");

Expand Down Expand Up @@ -236,16 +236,16 @@ mod tests {
async fn rep_max_connections() {
let _ = tracing_subscriber::fmt::try_init();
let mut rep = RepSocket::with_options(Tcp::default(), RepOptions::default().max_clients(1));
rep.bind_socket("127.0.0.1:0").await.unwrap();
rep.bind("127.0.0.1:0").await.unwrap();
let addr = rep.local_addr().unwrap();

let mut req1 = ReqSocket::new(Tcp::default());
req1.connect_socket(addr).await.unwrap();
req1.connect(addr).await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
assert_eq!(rep.stats().active_clients(), 1);

let mut req2 = ReqSocket::new(Tcp::default());
req2.connect_socket(addr).await.unwrap();
req2.connect(addr).await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
assert_eq!(rep.stats().active_clients(), 1);
}
Expand All @@ -256,13 +256,13 @@ mod tests {
RepSocket::with_options(Tcp::default(), RepOptions::default().min_compress_size(0))
.with_compressor(SnappyCompressor);

rep.bind_socket("0.0.0.0:4445").await.unwrap();
rep.bind("0.0.0.0:4445").await.unwrap();

let mut req =
ReqSocket::with_options(Tcp::default(), ReqOptions::default().min_compress_size(0))
.with_compressor(GzipCompressor::new(6));

req.connect_socket("0.0.0.0:4445").await.unwrap();
req.connect("0.0.0.0:4445").await.unwrap();

tokio::spawn(async move {
let req = rep.next().await.unwrap();
Expand Down
Loading

0 comments on commit 6c9a9a8

Please sign in to comment.