Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Pub/Sub auto-decompression #45

Merged
merged 9 commits into from
Dec 20, 2023
46 changes: 37 additions & 9 deletions msg-socket/src/pub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use std::io;
use thiserror::Error;

mod driver;
use msg_wire::{compression::Compressor, pubsub};
use msg_wire::{
compression::{CompressionType, Compressor},
pubsub,
};
mod session;
mod socket;
mod stats;
Expand Down Expand Up @@ -43,6 +46,9 @@ pub struct PubOptions {
/// The maximum number of bytes that can be buffered in the session before being flushed.
/// This internally sets [`Framed::set_backpressure_boundary`](tokio_util::codec::Framed).
backpressure_boundary: usize,
/// Minimum payload size in bytes for compression to be used. If the payload is smaller than
/// this threshold, it will not be compressed.
min_compress_size: usize,
}

impl Default for PubOptions {
Expand All @@ -52,6 +58,7 @@ impl Default for PubOptions {
session_buffer_size: 1024,
flush_interval: Some(std::time::Duration::from_micros(50)),
backpressure_boundary: 8192,
min_compress_size: 1024,
mempirate marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down Expand Up @@ -83,12 +90,21 @@ impl PubOptions {
self.flush_interval = Some(flush_interval);
self
}

/// Sets the minimum payload size in bytes for compression to be used. If the payload is smaller than
/// this threshold, it will not be compressed.
pub fn min_compress_size(mut self, min_compress_size: usize) -> Self {
self.min_compress_size = min_compress_size;
self
}
}

/// A message received from a publisher.
/// Includes the source, topic, and payload.
#[derive(Debug, Clone)]
pub struct PubMessage {
/// The compression type used for the message payload.
compression_type: CompressionType,
/// The topic of the message.
topic: String,
/// The message payload.
Expand All @@ -98,7 +114,13 @@ pub struct PubMessage {
#[allow(unused)]
impl PubMessage {
pub fn new(topic: String, payload: Bytes) -> Self {
Self { topic, payload }
Self {
// Initialize the compression type to None.
// The actual compression type will be set in the `compress` method.
compression_type: CompressionType::None,
topic,
payload,
}
}

#[inline]
Expand All @@ -118,12 +140,18 @@ impl PubMessage {

#[inline]
pub fn into_wire(self, seq: u32) -> pubsub::Message {
pubsub::Message::new(seq, Bytes::from(self.topic), self.payload)
pubsub::Message::new(
seq,
Bytes::from(self.topic),
self.payload,
self.compression_type as u8,
)
}

#[inline]
pub fn compress(&mut self, compressor: &dyn Compressor) -> Result<(), io::Error> {
self.payload = compressor.compress(&self.payload)?;
self.compression_type = compressor.compression_type();

Ok(())
}
Expand All @@ -141,7 +169,7 @@ mod tests {

use futures::StreamExt;
use msg_transport::{Tcp, TcpOptions};
use msg_wire::compression::{GzipCompressor, GzipDecompressor};
use msg_wire::compression::GzipCompressor;

use crate::SubSocket;

Expand Down Expand Up @@ -216,16 +244,16 @@ mod tests {
async fn pubsub_many_compressed() {
let _ = tracing_subscriber::fmt::try_init();

let mut pub_socket = PubSocket::new(Tcp::new()).with_compressor(GzipCompressor::new(6));
let mut pub_socket =
PubSocket::with_options(Tcp::new(), PubOptions::default().min_compress_size(0))
.with_compressor(GzipCompressor::new(6));
let mut sub1 = SubSocket::new(Tcp::new_with_options(
TcpOptions::default().with_blocking_connect(),
))
.with_decompressor(GzipDecompressor::new());
));

let mut sub2 = SubSocket::new(Tcp::new_with_options(
TcpOptions::default().with_blocking_connect(),
))
.with_decompressor(GzipDecompressor::new());
));

pub_socket.bind("0.0.0.0:0").await.unwrap();
let addr = pub_socket.local_addr().unwrap();
Expand Down
24 changes: 13 additions & 11 deletions msg-socket/src/pub/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,19 @@ impl<T: ServerTransport> PubSocket<T> {
let mut msg = PubMessage::new(topic, message);

// We compress here since that way we only have to do it once.
if let Some(ref compressor) = self.compressor {
let len_before = msg.payload().len();

// For relatively small messages, this takes <100us
msg.compress(compressor.as_ref())?;

debug!(
"Compressed message from {} to {} bytes",
len_before,
msg.payload().len(),
);
// Compression is only done if the message is larger than the
// configured minimum payload size.
let len_before = msg.payload().len();
if len_before > self.options.min_compress_size {
if let Some(ref compressor) = self.compressor {
msg.compress(compressor.as_ref())?;

debug!(
"Compressed message from {} to {} bytes",
len_before,
msg.payload().len(),
);
}
}

// Broadcast the message directly to all active sessions.
Expand Down
22 changes: 7 additions & 15 deletions msg-socket/src/sub/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use super::{
};
use msg_common::unix_micros;
use msg_transport::ClientTransport;
use msg_wire::compression::Decompressor;
use msg_wire::pubsub;

type ConnectionResult<Io, E> = Result<(SocketAddr, Io), E>;
Expand All @@ -35,8 +34,6 @@ pub(crate) struct SubDriver<T: ClientTransport> {
pub(super) to_socket: mpsc::Sender<PubMessage>,
/// A joinset of authentication tasks.
pub(super) connection_tasks: JoinSet<ConnectionResult<T::Io, T::Error>>,
/// Optional payload decompressor.
pub(super) decompressor: Option<Arc<dyn Decompressor>>,
/// The set of subscribed topics.
pub(super) subscribed_topics: HashSet<String>,
/// All active publisher sessions for this subscriber socket.
Expand All @@ -59,18 +56,19 @@ where
if let Poll::Ready(Some((addr, result))) = this.publishers.poll_next_unpin(cx) {
match result {
Ok(mut msg) => {
if let Some(ref compressor) = this.decompressor {
let Ok(decompressed) = compressor.decompress(&msg.payload) else {
match msg.try_decompress() {
None => { /* No decompression necessary */ }
Some(Ok(decompressed)) => msg.payload = decompressed,
Some(Err(e)) => {
error!(
topic = msg.topic.as_str(),
"Failed to decompress message payload"
"Failed to decompress message payload: {:?}", e
);

continue;
};

msg.payload = decompressed;
}
}

this.on_message(PubMessage::new(addr, msg.topic, msg.payload));
}
Err(e) => {
Expand Down Expand Up @@ -109,12 +107,6 @@ impl<T> SubDriver<T>
where
T: ClientTransport + Send + Sync + 'static,
{
/// Sets the payload decompressor for the socket. This decompressor will be used to decompress all incoming
/// messages from the publishers.
pub fn set_decompressor<C: Decompressor>(&mut self, decompressor: C) {
self.decompressor = Some(Arc::new(decompressor));
}

fn on_command(&mut self, cmd: Command) {
debug!("Received command: {:?}", cmd);
match cmd {
Expand Down
13 changes: 0 additions & 13 deletions msg-socket/src/sub/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use tokio::{sync::mpsc, task::JoinSet};
use tokio_stream::StreamMap;

use msg_transport::ClientTransport;
use msg_wire::compression::Decompressor;

use super::{
Command, PubMessage, SocketState, SocketStats, SubDriver, SubError, SubOptions,
Expand Down Expand Up @@ -53,7 +52,6 @@ where
transport: Arc::new(transport),
from_socket,
to_socket,
decompressor: None,
connection_tasks: JoinSet::new(),
publishers: StreamMap::with_capacity(24),
subscribed_topics: HashSet::with_capacity(32),
Expand All @@ -70,17 +68,6 @@ where
}
}

/// Sets the payload decompressor for the socket. This decompressor will be used to decompress
/// all incoming messages from the publishers.
pub fn with_decompressor<C: Decompressor>(mut self, decompressor: C) -> Self {
self.driver
.as_mut()
.expect("Driver has been spawned already, cannot set compressor")
.set_decompressor(decompressor);

self
}

/// Asynchronously connects to the endpoint.
pub async fn connect(&mut self, endpoint: &str) -> Result<(), SubError> {
self.ensure_active_driver();
Expand Down
36 changes: 33 additions & 3 deletions msg-socket/src/sub/stream.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use bytes::Bytes;
use futures::{SinkExt, Stream, StreamExt};
use std::{
io,
pin::Pin,
task::{ready, Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use tracing::debug;
use tracing::{debug, trace};

use super::SubError;
use msg_wire::pubsub;
use msg_wire::{
compression::{CompressionType, Decompressor, GzipDecompressor, ZstdDecompressor},
pubsub,
};

/// Wraps a framed connection to a publisher and exposes all the PUBSUB specific methods.
pub(super) struct PublisherStream<Io> {
Expand Down Expand Up @@ -49,30 +53,56 @@ impl<Io: AsyncRead + AsyncWrite + Unpin> PublisherStream<Io> {

pub(super) struct TopicMessage {
pub timestamp: u64,
pub compression_type: u8,
pub topic: String,
pub payload: Bytes,
}

impl TopicMessage {
/// Tries to decompress the message payload if necessary.
///
/// - Returns `Some(Ok(Bytes))` if the payload is compressed and decompression succeeded.
/// - Returns `Some(Err(..))` if the payload is compressed but could not be decompressed.
/// - Returns `None` if the payload is not compressed.
pub fn try_decompress(&self) -> Option<Result<Bytes, io::Error>> {
match CompressionType::try_from(self.compression_type) {
Ok(supported_compression_type) => match supported_compression_type {
CompressionType::None => None,
// NOTE: Decompressors are unit structs, so there is no allocation here
CompressionType::Gzip => Some(GzipDecompressor::new().decompress(&self.payload)),
CompressionType::Zstd => Some(ZstdDecompressor::new().decompress(&self.payload)),
},
Err(unsupported_compression_type) => Some(Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unsupported compression type: {unsupported_compression_type}"),
))),
}
}
}

impl<Io: AsyncRead + AsyncWrite + Unpin> Stream for PublisherStream<Io> {
type Item = Result<TopicMessage, pubsub::Error>;

#[inline]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

// We set flush to false only when flush returns ready (i.e. the buffer is fully flushed)
if this.flush && this.conn.poll_flush_unpin(cx).is_ready() {
tracing::trace!("Flushed connection");
trace!("Flushed connection");
this.flush = false
}

if let Some(result) = ready!(this.conn.poll_next_unpin(cx)) {
return Poll::Ready(Some(result.map(|msg| {
let timestamp = msg.timestamp();
let compression_type = msg.compression_type();
let (topic, payload) = msg.into_parts();

// TODO: this will allocate. Can we just return the `Cow`?
let topic = String::from_utf8_lossy(&topic).to_string();
TopicMessage {
compression_type,
timestamp,
topic,
payload,
Expand Down
7 changes: 6 additions & 1 deletion msg-wire/src/compression/gzip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use bytes::Bytes;
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use std::io::{self, Read, Write};

use super::{Compressor, Decompressor};
use super::{CompressionType, Compressor, Decompressor};

/// A compressor that uses the gzip algorithm.
pub struct GzipCompressor {
Expand All @@ -17,6 +17,10 @@ impl GzipCompressor {
}

impl Compressor for GzipCompressor {
fn compression_type(&self) -> CompressionType {
CompressionType::Gzip
}

fn compress(&self, data: &[u8]) -> Result<Bytes, io::Error> {
// Optimistically allocate the compressed buffer to 1/4 of the original size.
let mut encoder = GzEncoder::new(
Expand All @@ -36,6 +40,7 @@ impl Compressor for GzipCompressor {
pub struct GzipDecompressor;

impl GzipDecompressor {
#[inline]
pub fn new() -> Self {
Self
}
Expand Down
25 changes: 25 additions & 0 deletions msg-wire/src/compression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,34 @@ mod zstd;
pub use gzip::*;
pub use zstd::*;

/// The possible compression type used for a message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CompressionType {
None = 0,
Gzip = 1,
Zstd = 2,
}

impl TryFrom<u8> for CompressionType {
type Error = u8;

fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(CompressionType::None),
1 => Ok(CompressionType::Gzip),
2 => Ok(CompressionType::Zstd),
_ => Err(value),
}
}
}

/// This trait is used to implement message-level compression algorithms for payloads.
/// On outgoing messages, the payload is compressed before being sent using the `compress` method.
pub trait Compressor: Send + Sync + Unpin + 'static {
/// Returns the compression type assigned to this compressor.
fn compression_type(&self) -> CompressionType;

/// Compresses a byte slice payload into a `Bytes` object.
fn compress(&self, data: &[u8]) -> Result<Bytes, io::Error>;
}
Expand Down
Loading
Loading