Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Pub/Sub auto-decompression #45

Merged
merged 9 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ rand = "0.8"
rustc-hash = "1"
flate2 = "1"
zstd = "0.13"
snap = "1"

[profile.dev]
opt-level = 1
Expand Down
46 changes: 37 additions & 9 deletions msg-socket/src/pub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use std::io;
use thiserror::Error;

mod driver;
use msg_wire::{compression::Compressor, pubsub};
use msg_wire::{
compression::{CompressionType, Compressor},
pubsub,
};
mod session;
mod socket;
mod stats;
Expand Down Expand Up @@ -43,6 +46,9 @@ pub struct PubOptions {
/// The maximum number of bytes that can be buffered in the session before being flushed.
/// This internally sets [`Framed::set_backpressure_boundary`](tokio_util::codec::Framed).
backpressure_boundary: usize,
/// Minimum payload size in bytes for compression to be used. If the payload is smaller than
/// this threshold, it will not be compressed.
min_compress_size: usize,
}

impl Default for PubOptions {
Expand All @@ -52,6 +58,7 @@ impl Default for PubOptions {
session_buffer_size: 1024,
flush_interval: Some(std::time::Duration::from_micros(50)),
backpressure_boundary: 8192,
min_compress_size: 8192,
}
}
}
Expand Down Expand Up @@ -83,12 +90,21 @@ impl PubOptions {
self.flush_interval = Some(flush_interval);
self
}

/// Sets the minimum payload size in bytes for compression to be used. If the payload is smaller than
/// this threshold, it will not be compressed.
pub fn min_compress_size(mut self, min_compress_size: usize) -> Self {
self.min_compress_size = min_compress_size;
self
}
}

/// A message received from a publisher.
/// Includes the source, topic, and payload.
#[derive(Debug, Clone)]
pub struct PubMessage {
/// The compression type used for the message payload.
compression_type: CompressionType,
/// The topic of the message.
topic: String,
/// The message payload.
Expand All @@ -98,7 +114,13 @@ pub struct PubMessage {
#[allow(unused)]
impl PubMessage {
pub fn new(topic: String, payload: Bytes) -> Self {
Self { topic, payload }
Self {
// Initialize the compression type to None.
// The actual compression type will be set in the `compress` method.
compression_type: CompressionType::None,
topic,
payload,
}
}

#[inline]
Expand All @@ -118,12 +140,18 @@ impl PubMessage {

#[inline]
pub fn into_wire(self, seq: u32) -> pubsub::Message {
pubsub::Message::new(seq, Bytes::from(self.topic), self.payload)
pubsub::Message::new(
seq,
Bytes::from(self.topic),
self.payload,
self.compression_type as u8,
)
}

/// Compresses the payload in place with `compressor` and records which algorithm was used.
///
/// # Errors
///
/// Returns the compressor's `io::Error` if compression fails; in that case neither the
/// payload nor the stored compression type is modified.
#[inline]
pub fn compress(&mut self, compressor: &dyn Compressor) -> Result<(), io::Error> {
    compressor.compress(&self.payload).map(|compressed| {
        // Only reached on success: swap in the compressed bytes, then tag the message
        // so that `into_wire` can advertise the compression type.
        self.payload = compressed;
        self.compression_type = compressor.compression_type();
    })
}
Expand All @@ -141,7 +169,7 @@ mod tests {

use futures::StreamExt;
use msg_transport::{Tcp, TcpOptions};
use msg_wire::compression::{GzipCompressor, GzipDecompressor};
use msg_wire::compression::GzipCompressor;

use crate::SubSocket;

Expand Down Expand Up @@ -216,16 +244,16 @@ mod tests {
async fn pubsub_many_compressed() {
let _ = tracing_subscriber::fmt::try_init();

let mut pub_socket = PubSocket::new(Tcp::new()).with_compressor(GzipCompressor::new(6));
let mut pub_socket =
PubSocket::with_options(Tcp::new(), PubOptions::default().min_compress_size(0))
.with_compressor(GzipCompressor::new(6));
let mut sub1 = SubSocket::new(Tcp::new_with_options(
TcpOptions::default().with_blocking_connect(),
))
.with_decompressor(GzipDecompressor::new());
));

let mut sub2 = SubSocket::new(Tcp::new_with_options(
TcpOptions::default().with_blocking_connect(),
))
.with_decompressor(GzipDecompressor::new());
));

pub_socket.bind("0.0.0.0:0").await.unwrap();
let addr = pub_socket.local_addr().unwrap();
Expand Down
24 changes: 13 additions & 11 deletions msg-socket/src/pub/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,19 @@ impl<T: ServerTransport> PubSocket<T> {
let mut msg = PubMessage::new(topic, message);

// We compress here since that way we only have to do it once.
if let Some(ref compressor) = self.compressor {
let len_before = msg.payload().len();

// For relatively small messages, this takes <100us
msg.compress(compressor.as_ref())?;

debug!(
"Compressed message from {} to {} bytes",
len_before,
msg.payload().len(),
);
// Compression is only done if the message is larger than the
// configured minimum payload size.
let len_before = msg.payload().len();
if len_before > self.options.min_compress_size {
if let Some(ref compressor) = self.compressor {
msg.compress(compressor.as_ref())?;

debug!(
"Compressed message from {} to {} bytes",
len_before,
msg.payload().len(),
);
}
}

// Broadcast the message directly to all active sessions.
Expand Down
22 changes: 7 additions & 15 deletions msg-socket/src/sub/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use super::{
};
use msg_common::unix_micros;
use msg_transport::ClientTransport;
use msg_wire::compression::Decompressor;
use msg_wire::pubsub;

type ConnectionResult<Io, E> = Result<(SocketAddr, Io), E>;
Expand All @@ -35,8 +34,6 @@ pub(crate) struct SubDriver<T: ClientTransport> {
pub(super) to_socket: mpsc::Sender<PubMessage>,
/// A joinset of authentication tasks.
pub(super) connection_tasks: JoinSet<ConnectionResult<T::Io, T::Error>>,
/// Optional payload decompressor.
pub(super) decompressor: Option<Arc<dyn Decompressor>>,
/// The set of subscribed topics.
pub(super) subscribed_topics: HashSet<String>,
/// All active publisher sessions for this subscriber socket.
Expand All @@ -59,18 +56,19 @@ where
if let Poll::Ready(Some((addr, result))) = this.publishers.poll_next_unpin(cx) {
match result {
Ok(mut msg) => {
if let Some(ref compressor) = this.decompressor {
let Ok(decompressed) = compressor.decompress(&msg.payload) else {
match msg.try_decompress() {
None => { /* No decompression necessary */ }
Some(Ok(decompressed)) => msg.payload = decompressed,
Some(Err(e)) => {
error!(
topic = msg.topic.as_str(),
"Failed to decompress message payload"
"Failed to decompress message payload: {:?}", e
);

continue;
};

msg.payload = decompressed;
}
}

this.on_message(PubMessage::new(addr, msg.topic, msg.payload));
}
Err(e) => {
Expand Down Expand Up @@ -109,12 +107,6 @@ impl<T> SubDriver<T>
where
T: ClientTransport + Send + Sync + 'static,
{
/// Sets the payload decompressor for the socket. This decompressor will be used to decompress all incoming
/// messages from the publishers.
pub fn set_decompressor<C: Decompressor>(&mut self, decompressor: C) {
self.decompressor = Some(Arc::new(decompressor));
}

fn on_command(&mut self, cmd: Command) {
debug!("Received command: {:?}", cmd);
match cmd {
Expand Down
13 changes: 0 additions & 13 deletions msg-socket/src/sub/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use tokio::{sync::mpsc, task::JoinSet};
use tokio_stream::StreamMap;

use msg_transport::ClientTransport;
use msg_wire::compression::Decompressor;

use super::{
Command, PubMessage, SocketState, SocketStats, SubDriver, SubError, SubOptions,
Expand Down Expand Up @@ -53,7 +52,6 @@ where
transport: Arc::new(transport),
from_socket,
to_socket,
decompressor: None,
connection_tasks: JoinSet::new(),
publishers: StreamMap::with_capacity(24),
subscribed_topics: HashSet::with_capacity(32),
Expand All @@ -70,17 +68,6 @@ where
}
}

/// Sets the payload decompressor for the socket. This decompressor will be used to decompress
/// all incoming messages from the publishers.
pub fn with_decompressor<C: Decompressor>(mut self, decompressor: C) -> Self {
self.driver
.as_mut()
.expect("Driver has been spawned already, cannot set compressor")
.set_decompressor(decompressor);

self
}

/// Asynchronously connects to the endpoint.
pub async fn connect(&mut self, endpoint: &str) -> Result<(), SubError> {
self.ensure_active_driver();
Expand Down
39 changes: 36 additions & 3 deletions msg-socket/src/sub/stream.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
use bytes::Bytes;
use futures::{SinkExt, Stream, StreamExt};
use std::{
io,
pin::Pin,
task::{ready, Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use tracing::debug;
use tracing::{debug, trace};

use super::SubError;
use msg_wire::pubsub;
use msg_wire::{
compression::{
CompressionType, Decompressor, GzipDecompressor, SnappyDecompressor, ZstdDecompressor,
},
pubsub,
};

/// Wraps a framed connection to a publisher and exposes all the PUBSUB specific methods.
pub(super) struct PublisherStream<Io> {
Expand Down Expand Up @@ -49,30 +55,57 @@ impl<Io: AsyncRead + AsyncWrite + Unpin> PublisherStream<Io> {

/// A single message received from a publisher, as yielded by `PublisherStream`.
///
/// The payload is stored exactly as received; use `try_decompress` to obtain the
/// decompressed bytes when `compression_type` indicates a compressed payload.
pub(super) struct TopicMessage {
/// Timestamp reported by the wire message's `timestamp()` accessor.
/// NOTE(review): unit is not visible here — confirm against `msg_wire::pubsub`.
pub timestamp: u64,
/// Raw compression type byte from the wire message. It is decoded with
/// `CompressionType::try_from`, so it may hold a value this build does not support.
pub compression_type: u8,
/// Topic of the message, lossily decoded from the wire bytes as UTF-8.
pub topic: String,
/// Message payload as received (still compressed if `compression_type` says so).
pub payload: Bytes,
}

impl TopicMessage {
    /// Decompresses the payload according to the message's `compression_type` byte.
    ///
    /// The outcome is one of:
    /// - `None`: the message is flagged as uncompressed, so the payload can be used as-is.
    /// - `Some(Ok(bytes))`: the payload was compressed and `bytes` is the decompressed data.
    /// - `Some(Err(e))`: the payload could not be decompressed, or the compression type
    ///   byte is not one this build understands (`io::ErrorKind::InvalidData`).
    pub fn try_decompress(&self) -> Option<Result<Bytes, io::Error>> {
        // Guard: bail out early on a compression type byte we cannot decode.
        let kind = match CompressionType::try_from(self.compression_type) {
            Ok(kind) => kind,
            Err(unsupported_compression_type) => {
                return Some(Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("unsupported compression type: {unsupported_compression_type}"),
                )));
            }
        };

        let raw: &[u8] = &self.payload;

        // NOTE: Decompressors are unit structs, so constructing them here does not allocate.
        match kind {
            CompressionType::None => None,
            CompressionType::Gzip => Some(GzipDecompressor.decompress(raw)),
            CompressionType::Zstd => Some(ZstdDecompressor.decompress(raw)),
            CompressionType::Snappy => Some(SnappyDecompressor.decompress(raw)),
        }
    }
}

impl<Io: AsyncRead + AsyncWrite + Unpin> Stream for PublisherStream<Io> {
type Item = Result<TopicMessage, pubsub::Error>;

#[inline]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

// We set flush to false only when flush returns ready (i.e. the buffer is fully flushed)
if this.flush && this.conn.poll_flush_unpin(cx).is_ready() {
tracing::trace!("Flushed connection");
trace!("Flushed connection");
this.flush = false
}

if let Some(result) = ready!(this.conn.poll_next_unpin(cx)) {
return Poll::Ready(Some(result.map(|msg| {
let timestamp = msg.timestamp();
let compression_type = msg.compression_type();
let (topic, payload) = msg.into_parts();

// TODO: this will allocate. Can we just return the `Cow`?
let topic = String::from_utf8_lossy(&topic).to_string();
TopicMessage {
compression_type,
timestamp,
topic,
payload,
Expand Down
1 change: 1 addition & 0 deletions msg-wire/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ tokio-util.workspace = true
tracing.workspace = true
flate2.workspace = true
zstd.workspace = true
snap.workspace = true
12 changes: 5 additions & 7 deletions msg-wire/src/compression/gzip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use bytes::Bytes;
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use std::io::{self, Read, Write};

use super::{Compressor, Decompressor};
use super::{CompressionType, Compressor, Decompressor};

/// A compressor that uses the gzip algorithm.
pub struct GzipCompressor {
Expand All @@ -17,6 +17,10 @@ impl GzipCompressor {
}

impl Compressor for GzipCompressor {
/// Reports that payloads produced by this compressor are gzip-encoded.
fn compression_type(&self) -> CompressionType {
CompressionType::Gzip
}

fn compress(&self, data: &[u8]) -> Result<Bytes, io::Error> {
// Optimistically allocate the compressed buffer to 1/4 of the original size.
let mut encoder = GzEncoder::new(
Expand All @@ -35,12 +39,6 @@ impl Compressor for GzipCompressor {
#[derive(Debug, Default)]
pub struct GzipDecompressor;

impl GzipDecompressor {
pub fn new() -> Self {
Self
}
}

impl Decompressor for GzipDecompressor {
fn decompress(&self, data: &[u8]) -> Result<Bytes, io::Error> {
let mut decoder = GzDecoder::new(data);
Expand Down
Loading
Loading