Skip to content

Commit

Permalink
Merge pull request #69 from chainbound/test/integration-tests
Browse files Browse the repository at this point in the history
Add `pubsub` integration tests
  • Loading branch information
mempirate authored Jan 24, 2024
2 parents 3799474 + 1cec564 commit 7a4d594
Show file tree
Hide file tree
Showing 11 changed files with 509 additions and 57 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions msg-common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use futures::future::BoxFuture;
use std::{
task::{Context, Poll},
time::SystemTime,
};

use tokio::sync::mpsc::{
self,
error::{SendError, TryRecvError, TrySendError},
Receiver, Sender,
};

use futures::future::BoxFuture;
pub mod task;

/// Returns the current UNIX timestamp in microseconds.
#[inline]
Expand Down
108 changes: 108 additions & 0 deletions msg-common/src/task.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use futures::{future::poll_fn, Future};
use std::{
collections::HashSet,
task::{ready, Context, Poll},
};
use tokio::task::{JoinError, JoinSet};

/// A collection of keyed tasks spawned on a Tokio runtime.
/// Hacky implementation of a join set that allows for a key to be associated with each task by having
/// the task return a tuple of (key, value).
#[derive(Debug, Default)]
pub struct JoinMap<K, V> {
keys: HashSet<K>,
joinset: JoinSet<(K, V)>,
}

impl<K, V> JoinMap<K, V> {
/// Create a new `JoinSet`.
pub fn new() -> Self {
Self {
keys: HashSet::new(),
joinset: JoinSet::new(),
}
}

/// Returns the number of tasks currently in the `JoinSet`.
pub fn len(&self) -> usize {
self.joinset.len()
}

/// Returns whether the `JoinSet` is empty.
pub fn is_empty(&self) -> bool {
self.joinset.is_empty()
}
}

impl<K, V> JoinMap<K, V>
where
K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
V: 'static,
{
/// Spawns a task onto the Tokio runtime that will execute the given future ONLY IF
/// there is not already a task in the set with the same key.
pub fn spawn<F>(&mut self, key: K, future: F)
where
F: Future<Output = (K, V)> + Send + 'static,
V: Send,
{
if self.keys.insert(key) {
self.joinset.spawn(future);
}
}

/// Returns `true` if the `JoinSet` contains a task for the given key.
pub fn contains_key(&self, key: &K) -> bool {
self.keys.contains(key)
}

/// Waits until one of the tasks in the set completes and returns its output.
///
/// Returns `None` if the set is empty.
///
/// # Cancel Safety
///
/// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
/// statement and some other branch completes first, it is guaranteed that no tasks were
/// removed from this `JoinSet`.
pub async fn join_next(&mut self) -> Option<Result<(K, V), JoinError>> {
poll_fn(|cx| self.poll_join_next(cx)).await
}

/// Polls for one of the tasks in the set to complete.
///
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the set.
///
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
/// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
///
/// # Returns
///
/// This function returns:
///
/// * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is
/// available right now.
/// * `Poll::Ready(Some(Ok(value)))` if one of the tasks in this `JoinSet` has completed.
/// The `value` is the return value of one of the tasks that completed.
/// * `Poll::Ready(Some(Err(err)))` if one of the tasks in this `JoinSet` has panicked or been
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
/// * `Poll::Ready(None)` if the `JoinSet` is empty.
///
/// Note that this method may return `Poll::Pending` even if one of the tasks has completed.
/// This can happen if the [coop budget] is reached.
pub fn poll_join_next(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(K, V), JoinError>>> {
match ready!(self.joinset.poll_join_next(cx)) {
Some(Ok((key, value))) => {
self.keys.remove(&key);
Poll::Ready(Some(Ok((key, value))))
}
Some(Err(err)) => Poll::Ready(Some(Err(err))),
None => Poll::Ready(None),
}
}
}
12 changes: 6 additions & 6 deletions msg-sim/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{collections::HashMap, io, net::IpAddr, time::Duration};

use protocol::Protocol;
pub use protocol::Protocol;

mod protocol;

Expand All @@ -13,13 +13,13 @@ use dummynet::{PacketFilter, Pipe};
#[allow(unused)]
pub struct SimulationConfig {
/// The latency of the connection.
latency: Option<Duration>,
pub latency: Option<Duration>,
/// The bandwidth in Kbps.
bw: Option<u64>,
pub bw: Option<u64>,
/// The packet loss rate in percent.
plr: Option<f64>,
pub plr: Option<f64>,
/// The supported protocols.
protocols: Vec<Protocol>,
pub protocols: Vec<Protocol>,
}

#[derive(Default)]
Expand All @@ -34,7 +34,7 @@ impl Simulator {
pub fn new() -> Self {
Self {
active_sims: HashMap::new(),
sim_id: 0,
sim_id: 1,
}
}

Expand Down
2 changes: 2 additions & 0 deletions msg-socket/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ rand.workspace = true
parking_lot.workspace = true

[dev-dependencies]
msg-sim.workspace = true

tracing-subscriber = "0.3"
3 changes: 2 additions & 1 deletion msg-socket/src/pub/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ where
}

/// Publishes a message to the given topic. If the topic doesn't exist, this is a no-op.
pub async fn publish(&self, topic: String, message: Bytes) -> Result<(), PubError> {
pub async fn publish(&self, topic: impl Into<String>, message: Bytes) -> Result<(), PubError> {
let topic = topic.into();
let mut msg = PubMessage::new(topic, message);

// We compress here since that way we only have to do it once.
Expand Down
96 changes: 65 additions & 31 deletions msg-socket/src/sub/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ use std::{
task::{Context, Poll},
time::Duration,
};
use tokio::sync::mpsc::error::TrySendError;
use tokio::{sync::mpsc, task::JoinSet};
use tokio::sync::mpsc::{self, error::TrySendError};
use tokio_util::codec::Framed;
use tracing::{debug, error, info, warn};

Expand All @@ -23,12 +22,10 @@ use super::{
Command, PubMessage, SocketState, SubOptions,
};

use msg_common::{channel, Channel};
use msg_common::{channel, task::JoinMap, Channel};
use msg_transport::Transport;
use msg_wire::{auth, pubsub};

type ConnectionResult<Io, E> = Result<(SocketAddr, Io), E>;

pub(crate) struct SubDriver<T: Transport> {
/// Options shared with the socket.
pub(super) options: Arc<SubOptions>,
Expand All @@ -39,7 +36,7 @@ pub(crate) struct SubDriver<T: Transport> {
/// Messages to the socket.
pub(super) to_socket: mpsc::Sender<PubMessage>,
/// A joinset of authentication tasks.
pub(super) connection_tasks: JoinSet<ConnectionResult<T::Io, T::Error>>,
pub(super) connection_tasks: JoinMap<SocketAddr, Result<T::Io, T::Error>>,
/// The set of subscribed topics.
pub(super) subscribed_topics: HashSet<String>,
/// All active publisher sessions for this subscriber socket.
Expand Down Expand Up @@ -87,13 +84,14 @@ where
continue;
}

if let Poll::Ready(Some(Ok(result))) = this.connection_tasks.poll_join_next(cx) {
if let Poll::Ready(Some(Ok((addr, result)))) = this.connection_tasks.poll_join_next(cx)
{
match result {
Ok((addr, io)) => {
Ok(io) => {
this.on_connection(addr, io);
}
Err(e) => {
error!("Error connecting to publisher: {:?}", e);
error!(%addr, "Error connecting to publisher: {:?}", e);
}
}

Expand Down Expand Up @@ -131,6 +129,10 @@ where
false
}

fn is_known(&self, addr: &SocketAddr) -> bool {
self.publishers.contains_key(addr)
}

/// Subscribes to a topic on all publishers.
fn subscribe(&mut self, topic: String) {
let mut inactive = Vec::new();
Expand Down Expand Up @@ -217,6 +219,11 @@ where
endpoint.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST));
}

if self.is_known(&endpoint) {
debug!(%endpoint, "Publisher already known, ignoring connect command");
return;
}

self.connect(endpoint);

// Also set the publisher to the disconnected state. This will make sure that if the
Expand All @@ -242,42 +249,63 @@ where
let connect = self.transport.connect(addr);
let token = self.options.auth_token.clone();

self.connection_tasks.spawn(async move {
let io = connect.await?;
self.connection_tasks.spawn(addr, async move {
let io = match connect.await {
Ok(io) => io,
Err(e) => {
return (addr, Err(e));
}
};

if let Some(token) = token {
let mut conn = Framed::new(io, auth::Codec::new_client());

tracing::debug!("Sending auth message: {:?}", token);
// Send the authentication message
conn.send(auth::Message::Auth(token))
.await
.map_err(T::Error::from)?;
conn.flush().await.map_err(T::Error::from)?;
if let Err(e) = conn.send(auth::Message::Auth(token)).await {
return (addr, Err(e.into()));
}

if let Err(e) = conn.flush().await {
return (addr, Err(e.into()));
}

tracing::debug!("Waiting for ACK from server...");

// Wait for the response
let ack = conn
.next()
.await
.ok_or(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Connection closed",
))?
.map_err(|e| io::Error::new(io::ErrorKind::PermissionDenied, e))?;
let ack = match conn.next().await {
Some(Ok(ack)) => ack,
Some(Err(e)) => {
return (
addr,
Err(io::Error::new(io::ErrorKind::PermissionDenied, e).into()),
)
}
None => {
return (
addr,
Err(
io::Error::new(io::ErrorKind::UnexpectedEof, "Connection closed")
.into(),
),
)
}
};

if matches!(ack, auth::Message::Ack) {
Ok((addr, conn.into_inner()))
(addr, Ok(conn.into_inner()))
} else {
Err(io::Error::new(
io::ErrorKind::PermissionDenied,
"Publisher denied connection",
(
addr,
Err(io::Error::new(
io::ErrorKind::PermissionDenied,
"Publisher denied connection",
)
.into()),
)
.into())
}
} else {
Ok((addr, io))
(addr, Ok(io))
}
});
}
Expand Down Expand Up @@ -382,8 +410,14 @@ where
if let Poll::Ready(item) = backoff.poll_next_unpin(cx) {
if let Some(duration) = item {
progress = true;
tracing::debug!(backoff = ?duration, "Retrying connection to {:?}", addr);
to_retry.push(*addr);

// Only retry if there are no active connection tasks
if !self.connection_tasks.contains_key(addr) {
tracing::debug!(backoff = ?duration, "Retrying connection to {:?}", addr);
to_retry.push(*addr);
} else {
tracing::debug!(backoff = ?duration, "Not retrying connection to {:?} as there is already a connection task", addr);
}
} else {
error!("Exceeded maximum number of retries for {:?}, terminating connection", addr);
to_terminate.push(*addr);
Expand Down
Loading

0 comments on commit 7a4d594

Please sign in to comment.