Skip to content

Commit

Permalink
feat(socket): initial sub socket implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mempirate committed Oct 16, 2023
1 parent 2f8f99e commit dfad88c
Show file tree
Hide file tree
Showing 6 changed files with 611 additions and 7 deletions.
11 changes: 11 additions & 0 deletions msg-socket/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite};

mod rep;
mod req;
mod sub;

use bytes::Bytes;
pub use rep::*;
pub use req::*;
pub use sub::*;

pub struct RequestId(u32);

Expand All @@ -24,3 +29,9 @@ impl RequestId {
pub trait Authenticator: Send + Sync + Unpin + 'static {
fn authenticate(&self, id: &Bytes) -> bool;
}

pub(crate) struct AuthResult<S: AsyncRead + AsyncWrite> {
id: Bytes,
addr: SocketAddr,
stream: S,
}
225 changes: 225 additions & 0 deletions msg-socket/src/sub/driver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
use bytes::Bytes;
use futures::{Future, Stream, StreamExt};
use std::collections::{HashSet, VecDeque};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc::error::TrySendError;
use tokio::{sync::mpsc, task::JoinSet};
use tokio_stream::StreamMap;
use tokio_util::codec::Framed;
use tracing::{debug, error};

use super::stream::TopicMessage;
use super::{stream::PublisherStream, Command, SubMessage, SubOptions};
use msg_transport::ClientTransport;
use msg_wire::pubsub;

type ConnectionResult<Io, E> = Result<(SocketAddr, Io), E>;

pub(crate) struct SubDriver<T: ClientTransport> {
/// Options shared with the socket.
pub(super) options: Arc<SubOptions>,
/// The transport for this socket driver.
pub(super) transport: Arc<T>,
/// Commands from the socket.
pub(super) from_socket: mpsc::Receiver<Command>,
/// Messages to the socket.
pub(super) to_socket: mpsc::Sender<SubMessage>,
/// A joinset of authentication tasks.
pub(super) connection_tasks: JoinSet<ConnectionResult<T::Io, T::Error>>,
/// The set of subscribed topics.
pub(super) subscribed_topics: HashSet<String>,
/// All active publisher sessions for this subscriber socket.
pub(super) publishers: StreamMap<SocketAddr, PublisherSession<T::Io>>,
}

impl<T> Future for SubDriver<T>
where
T: ClientTransport + Send + Sync + 'static,
{
type Output = ();

/// This poll implementation prioritizes incoming messages over commands.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();

loop {
if let Poll::Ready(Some((addr, result))) = this.publishers.poll_next_unpin(cx) {
match result {
Ok(msg) => {
this.on_message(SubMessage::new(addr, msg.topic, msg.payload));
}
Err(e) => {
error!(source = %addr, "Error receiving message from publisher: {:?}", e);
}
}

continue;
}

if let Poll::Ready(Some(cmd)) = this.from_socket.poll_recv(cx) {
this.on_command(cmd);

continue;
}

if let Poll::Ready(Some(Ok(result))) = this.connection_tasks.poll_join_next(cx) {
match result {
Ok((addr, io)) => {
this.on_connection(addr, io);
}
Err(e) => {
error!("Error connecting to publisher: {:?}", e);
}
}

continue;
}

return Poll::Pending;
}
}
}

impl<T> SubDriver<T>
where
T: ClientTransport + Send + Sync + 'static,
{
fn on_command(&mut self, cmd: Command) {
debug!("Received command: {:?}", cmd);
match cmd {
Command::Subscribe { topic } => {
if !self.subscribed_topics.contains(&topic) {
self.subscribed_topics.insert(topic.clone());
// Subscribe to the topic on all publishers
for session in self.publishers.values_mut() {
session.subscribe(topic.clone());
}
} else {
debug!(topic = topic.as_str(), "Already subscribed to topic");
}
}
Command::Unsubscribe { topic } => {
if self.subscribed_topics.remove(&topic) {
for session in self.publishers.values_mut() {
session.unsubscribe(topic.clone());
}
} else {
debug!(topic = topic.as_str(), "Not subscribed to topic");
}
}
Command::Connect { endpoint } => {
let id = self.options.client_id.clone();
let transport = Arc::clone(&self.transport);

// NOTE: don't know if this is gonna work
self.connection_tasks.spawn(async move {
let io = transport.connect_with_auth(endpoint, id).await?;

Ok((endpoint, io))
});
}
Command::Shutdown => todo!(),
}
}

fn on_message(&self, msg: SubMessage) {
debug!(source = %msg.source, "New message: {:?}", msg);
// TODO: queuing
if let Err(TrySendError::Full(msg)) = self.to_socket.try_send(msg) {
error!(topic = msg.topic, "Slow subsriber socket, dropping message");
}
}

fn on_connection(&mut self, addr: SocketAddr, io: T::Io) {
// This should spawn a new task tied to this connection, and
debug!("Connection to {} established, spawning session", addr);
let framed = Framed::new(io, pubsub::Codec::new());
let mut publisher_session = PublisherSession::new(addr, PublisherStream::new(framed));

for topic in self.subscribed_topics.iter() {
publisher_session.subscribe(topic.clone());
}

self.publishers.insert(addr, publisher_session);
}
}

/// Manages the state of a single publisher, represented as a [`Stream`].
#[must_use = "streams do nothing unless polled"]
pub(super) struct PublisherSession<Io> {
/// The addr of the publisher
addr: SocketAddr,
/// The egress queue (for subscribe / unsubscribe messages)
egress: VecDeque<pubsub::Message>,
/// The inner stream
stream: PublisherStream<Io>,
}

impl<Io: AsyncRead + AsyncWrite + Send + Unpin> PublisherSession<Io> {
fn new(addr: SocketAddr, stream: PublisherStream<Io>) -> Self {
Self {
addr,
stream,
egress: VecDeque::with_capacity(4),
}
}

/// Queues a subscribe message for this publisher.
/// On the next poll, the message will be attempted to be sent.
fn subscribe(&mut self, topic: String) {
self.egress
.push_back(pubsub::Message::new_sub(Bytes::from(topic)));
}

/// Queues an unsubscribe message for this publisher.
/// On the next poll, the message will be attempted to be sent.
fn unsubscribe(&mut self, topic: String) {
self.egress
.push_back(pubsub::Message::new_unsub(Bytes::from(topic)));
}
}

impl<Io: AsyncRead + AsyncWrite + Unpin> Stream for PublisherSession<Io> {
type Item = Result<TopicMessage, pubsub::Error>;

/// This poll implementation prioritizes incoming messages over outgoing messages.
#[inline]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

loop {
match this.stream.poll_next_unpin(cx) {
Poll::Ready(Some(result)) => {
return Poll::Ready(Some(result));
}
Poll::Ready(None) => {
error!(addr = %this.addr, "Publisher stream closed");
return Poll::Ready(None);
}
Poll::Pending => {}
}

let mut progress = false;
while let Some(msg) = this.egress.pop_front() {
// TODO(perf): do we need to clone the message here?
if this.stream.poll_send(cx, msg.clone()).is_ready() {
progress = true;
debug!("Queued message for sending: {:?}", msg);
} else {
this.egress.push_back(msg);
break;
}
}

if progress {
continue;
}

return Poll::Pending;
}
}
}
Loading

0 comments on commit dfad88c

Please sign in to comment.