From 7d4ea71e2cdf02fc302c59eb27a49b2ff38d5709 Mon Sep 17 00:00:00 2001
From: Martin Habovstiak
Date: Fri, 31 Mar 2023 20:40:13 +0200
Subject: [PATCH] Replace internal buffer in decoder with `BufRead`

The `DecoderReader` used an internal buffer for reading, which came with
a number of disadvantages, such as:

* Needless copying from in-memory readers (slices)
* Double-buffering already-buffered readers
* The buffered bytes are lost when dropping or calling `into_inner`
* `BUF_SIZE` is not configurable
* `std::io::BufReader` has access to some unstable optimizations which
  this crate cannot use
* Reinvents the wheel; there already is `std::io::BufReader`

This change removes the internal buffer and requires `BufRead` instead.
Decoding is implemented on top of `fill_buf` for larger chunks, with a
fallback to a small, stack-allocated buffer for the tiny chunks that may
appear at buffer boundaries. To resolve borrowing problems,
`decode_to_buf` had to be removed; this also made it possible to decode
directly into the internal buffer of decoded bytes when the buffer to be
filled is small, rather than decoding into a temporary buffer and then
copying. This improves read performance by around 7-22% on my machine.
---
 examples/base64.rs        |   6 +-
 src/read/decoder.rs       | 190 +++++++++++++++-----------------------
 src/read/decoder_tests.rs |  22 ++++-
 3 files changed, 99 insertions(+), 119 deletions(-)

diff --git a/examples/base64.rs b/examples/base64.rs
index 0a214d2..aed6222 100644
--- a/examples/base64.rs
+++ b/examples/base64.rs
@@ -1,5 +1,5 @@
 use std::fs::File;
-use std::io::{self, Read};
+use std::io::{self, BufRead, BufReader};
 use std::path::PathBuf;
 use std::process;
 use std::str::FromStr;
@@ -48,7 +48,7 @@ struct Opt {
 fn main() {
     let opt = Opt::from_args();
     let stdin;
-    let mut input: Box<dyn Read> = match opt.file {
+    let mut input: Box<dyn BufRead> = match opt.file {
         None => {
             stdin = io::stdin();
             Box::new(stdin.lock())
         }
@@ -57,7 +57,7 @@ fn main() {
             stdin = io::stdin();
             Box::new(stdin.lock())
         }
-        Some(f) => Box::new(File::open(f).unwrap()),
+        Some(f) => Box::new(File::open(f).map(BufReader::new).unwrap()),
     };
 
     let alphabet = opt.alphabet.unwrap_or_default();
diff --git a/src/read/decoder.rs b/src/read/decoder.rs
index 4888c9c..3d641c9 100644
--- a/src/read/decoder.rs
+++ b/src/read/decoder.rs
@@ -1,9 +1,6 @@
 use crate::{engine::Engine, DecodeError};
 use std::{cmp, fmt, io};
 
-// This should be large, but it has to fit on the stack.
-pub(crate) const BUF_SIZE: usize = 1024;
-
 // 4 bytes of base64 data encode 3 bytes of raw data (modulo padding).
 const BASE64_CHUNK_SIZE: usize = 4;
 const DECODED_CHUNK_SIZE: usize = 3;
@@ -30,17 +27,11 @@ const DECODED_CHUNK_SIZE: usize = 3;
 /// assert_eq!(b"asdf", &result[..]);
 ///
 /// ```
-pub struct DecoderReader<'e, E: Engine, R: io::Read> {
+pub struct DecoderReader<'e, E: Engine, R: io::BufRead> {
     engine: &'e E,
     /// Where b64 data is read from
     inner: R,
 
-    // Holds b64 data read from the delegate reader.
-    b64_buffer: [u8; BUF_SIZE],
-    // The start of the pending buffered data in b64_buffer.
-    b64_offset: usize,
-    // The amount of buffered b64 data.
-    b64_len: usize,
     // Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a
     // decoded chunk in to, we have to be able to hang on to a few decoded bytes.
     // Technically we only need to hold 2 bytes but then we'd need a separate temporary buffer to
@@ -55,11 +46,9 @@ pub struct DecoderReader<'e, E: Engine, R: io::Read> {
     total_b64_decoded: usize,
 }
 
-impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> {
+impl<'e, E: Engine, R: io::BufRead> fmt::Debug for DecoderReader<'e, E, R> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("DecoderReader")
-            .field("b64_offset", &self.b64_offset)
-            .field("b64_len", &self.b64_len)
             .field("decoded_buffer", &self.decoded_buffer)
             .field("decoded_offset", &self.decoded_offset)
             .field("decoded_len", &self.decoded_len)
@@ -68,15 +57,12 @@ impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> {
     }
 }
 
-impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
+impl<'e, E: Engine, R: io::BufRead> DecoderReader<'e, E, R> {
     /// Create a new decoder that will read from the provided reader `r`.
     pub fn new(reader: R, engine: &'e E) -> Self {
         DecoderReader {
             engine,
             inner: reader,
-            b64_buffer: [0; BUF_SIZE],
-            b64_offset: 0,
-            b64_len: 0,
             decoded_buffer: [0; DECODED_CHUNK_SIZE],
             decoded_offset: 0,
             decoded_len: 0,
@@ -107,59 +93,6 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
         Ok(copy_len)
     }
 
-    /// Read into the remaining space in the buffer after the current contents.
-    /// Must only be called when there is space to read into in the buffer.
-    /// Returns the number of bytes read.
-    fn read_from_delegate(&mut self) -> io::Result<usize> {
-        debug_assert!(self.b64_offset + self.b64_len < BUF_SIZE);
-
-        let read = self
-            .inner
-            .read(&mut self.b64_buffer[self.b64_offset + self.b64_len..])?;
-        self.b64_len += read;
-
-        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
-
-        Ok(read)
-    }
-
-    /// Decode the requested number of bytes from the b64 buffer into the provided buffer. It's the
-    /// caller's responsibility to choose the number of b64 bytes to decode correctly.
-    ///
-    /// Returns a Result with the number of decoded bytes written to `buf`.
-    fn decode_to_buf(&mut self, num_bytes: usize, buf: &mut [u8]) -> io::Result<usize> {
-        debug_assert!(self.b64_len >= num_bytes);
-        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
-        debug_assert!(!buf.is_empty());
-
-        let decoded = self
-            .engine
-            .internal_decode(
-                &self.b64_buffer[self.b64_offset..self.b64_offset + num_bytes],
-                buf,
-                self.engine.internal_decoded_len_estimate(num_bytes),
-            )
-            .map_err(|e| match e {
-                DecodeError::InvalidByte(offset, byte) => {
-                    DecodeError::InvalidByte(self.total_b64_decoded + offset, byte)
-                }
-                DecodeError::InvalidLength => DecodeError::InvalidLength,
-                DecodeError::InvalidLastSymbol(offset, byte) => {
-                    DecodeError::InvalidLastSymbol(self.total_b64_decoded + offset, byte)
-                }
-                DecodeError::InvalidPadding => DecodeError::InvalidPadding,
-            })
-            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
-
-        self.total_b64_decoded += num_bytes;
-        self.b64_offset += num_bytes;
-        self.b64_len -= num_bytes;
-
-        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
-
-        Ok(decoded)
-    }
-
     /// Unwraps this `DecoderReader`, returning the base reader which it reads base64 encoded
     /// input from.
     ///
@@ -171,7 +104,23 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
     }
 }
 
-impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
+fn map_error_offset(total_b64_decoded: usize) -> impl FnOnce(DecodeError) -> io::Error {
+    move |error| {
+        let error = match error {
+            DecodeError::InvalidByte(offset, byte) => {
+                DecodeError::InvalidByte(total_b64_decoded + offset, byte)
+            }
+            DecodeError::InvalidLength => DecodeError::InvalidLength,
+            DecodeError::InvalidLastSymbol(offset, byte) => {
+                DecodeError::InvalidLastSymbol(total_b64_decoded + offset, byte)
+            }
+            DecodeError::InvalidPadding => DecodeError::InvalidPadding,
+        };
+        io::Error::new(io::ErrorKind::InvalidData, error)
+    }
+}
+
+impl<'e, E: Engine, R: io::BufRead> io::Read for DecoderReader<'e, E, R> {
     /// Decode input from the wrapped reader.
     ///
     /// Under non-error circumstances, this returns `Ok` with the value being the number of bytes
@@ -189,15 +138,6 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
             return Ok(0);
         }
 
-        // offset == BUF_SIZE when we copied it all last time
-        debug_assert!(self.b64_offset <= BUF_SIZE);
-        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
-        debug_assert!(if self.b64_offset == BUF_SIZE {
-            self.b64_len == 0
-        } else {
-            self.b64_len <= BUF_SIZE
-        });
-
         debug_assert!(if self.decoded_len == 0 {
             // can be = when we were able to copy the complete chunk
             self.decoded_offset <= DECODED_CHUNK_SIZE
@@ -215,54 +155,66 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
             // we have a few leftover decoded bytes; flush that rather than pull in more b64
             self.flush_decoded_buf(buf)
         } else {
-            let mut at_eof = false;
-            while self.b64_len < BASE64_CHUNK_SIZE {
-                // Work around lack of copy_within, which is only present in 1.37
-                // Copy any bytes we have to the start of the buffer.
-                // We know we have < 1 chunk, so we can use a tiny tmp buffer.
-                let mut memmove_buf = [0_u8; BASE64_CHUNK_SIZE];
-                memmove_buf[..self.b64_len].copy_from_slice(
-                    &self.b64_buffer[self.b64_offset..self.b64_offset + self.b64_len],
-                );
-                self.b64_buffer[0..self.b64_len].copy_from_slice(&memmove_buf[..self.b64_len]);
-                self.b64_offset = 0;
+            let mut b64_bytes = self.inner.fill_buf()?;
+
-                // then fill in more data
-                let read = self.read_from_delegate()?;
-                if read == 0 {
-                    // we never pass in an empty buf, so 0 => we've hit EOF
-                    at_eof = true;
-                    break;
-                }
-            }
-
-            if self.b64_len == 0 {
-                debug_assert!(at_eof);
-                // we must be at EOF, and we have no data left to decode
+            if b64_bytes.is_empty() {
                 return Ok(0);
             };
 
+            let mut b64_bytes_tmp;
+            let mut at_eof = false;
+            let mut short = false;
+            if b64_bytes.len() < BASE64_CHUNK_SIZE {
+                short = true;
+                // Read as much as we can, trying to have a full chunk.
+                b64_bytes_tmp = [0; BASE64_CHUNK_SIZE];
+                b64_bytes_tmp[..b64_bytes.len()].copy_from_slice(b64_bytes);
+                let mut pos = b64_bytes.len();
+                self.inner.consume(pos);
+                while pos < BASE64_CHUNK_SIZE {
+                    let bytes_read = match self.inner.read(&mut b64_bytes_tmp[pos..]) {
+                        Ok(len) => len,
+                        Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
+                        Err(error) => return Err(error),
+                    };
+                    if bytes_read == 0 {
+                        at_eof = true;
+                        break;
+                    }
+                    pos += bytes_read;
+                }
+                b64_bytes = &b64_bytes_tmp[..pos];
+            }
+
             debug_assert!(if at_eof {
                 // if we are at eof, we may not have a complete chunk
-                self.b64_len > 0
+                b64_bytes.len() > 0
             } else {
                 // otherwise, we must have at least one chunk
-                self.b64_len >= BASE64_CHUNK_SIZE
+                b64_bytes.len() >= BASE64_CHUNK_SIZE
             });
 
             debug_assert_eq!(0, self.decoded_len);
 
             if buf.len() < DECODED_CHUNK_SIZE {
                 // caller requested an annoyingly short read
-                // have to write to a tmp buf first to avoid double mutable borrow
-                let mut decoded_chunk = [0_u8; DECODED_CHUNK_SIZE];
                 // if we are at eof, could have less than BASE64_CHUNK_SIZE, in which case we have
                 // to assume that these last few tokens are, in fact, valid (i.e. must be 2-4 b64
                 // tokens, not 1, since 1 token can't decode to 1 byte).
-                let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE);
+                let to_decode = cmp::min(b64_bytes.len(), BASE64_CHUNK_SIZE);
+                debug_assert!(b64_bytes.len() > BASE64_CHUNK_SIZE || to_decode == b64_bytes.len());
 
-                let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?;
-                self.decoded_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]);
+                let decoded = self
+                    .engine
+                    .internal_decode(
+                        &b64_bytes[..to_decode],
+                        &mut self.decoded_buffer,
+                        self.engine.internal_decoded_len_estimate(to_decode),
+                    )
+                    .map_err(map_error_offset(self.total_b64_decoded))?;
+
+                self.total_b64_decoded += to_decode;
+                if !short { self.inner.consume(to_decode); }
 
                 self.decoded_offset = 0;
                 self.decoded_len = decoded;
@@ -270,6 +222,7 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
 
                 // can be less than 3 on last block due to padding
                 debug_assert!(decoded <= 3);
+
                 self.flush_decoded_buf(buf)
             } else {
                 let b64_bytes_that_can_decode_into_buf = (buf.len() / DECODED_CHUNK_SIZE)
@@ -278,17 +231,28 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
                 debug_assert!(b64_bytes_that_can_decode_into_buf >= BASE64_CHUNK_SIZE);
 
                 let b64_bytes_available_to_decode = if at_eof {
-                    self.b64_len
+                    b64_bytes.len()
                 } else {
                     // only use complete chunks
-                    self.b64_len - self.b64_len % 4
+                    b64_bytes.len() - b64_bytes.len() % 4
                 };
 
                 let actual_decode_len = cmp::min(
                     b64_bytes_that_can_decode_into_buf,
                     b64_bytes_available_to_decode,
                 );
-                self.decode_to_buf(actual_decode_len, buf)
+                let decoded = self
+                    .engine
+                    .internal_decode(
+                        &b64_bytes[..actual_decode_len],
+                        buf,
+                        self.engine.internal_decoded_len_estimate(actual_decode_len),
+                    )
+                    .map_err(map_error_offset(self.total_b64_decoded))?;
+
+                self.total_b64_decoded += actual_decode_len;
+                if !short { self.inner.consume(actual_decode_len); }
+                Ok(decoded)
             }
         }
     }
diff --git a/src/read/decoder_tests.rs b/src/read/decoder_tests.rs
index 65d58d8..7653dc0 100644
--- a/src/read/decoder_tests.rs
+++ b/src/read/decoder_tests.rs
@@ -6,13 +6,15 @@ use std::{
 
 use rand::{Rng as _, RngCore as _};
 
-use super::decoder::{DecoderReader, BUF_SIZE};
+use super::decoder::DecoderReader;
 use crate::{
     engine::{general_purpose::STANDARD, Engine, GeneralPurpose},
     tests::{random_alphabet, random_config, random_engine},
     DecodeError,
 };
 
+const BUF_SIZE: usize = 1024;
+
 #[test]
 fn simple() {
     let tests: &[(&[u8], &[u8])] = &[
@@ -113,7 +115,6 @@ fn handles_short_read_from_delegate() {
     };
 
     let mut decoder = DecoderReader::new(&mut short_reader, &engine);
-
     let decoded_len = decoder.read_to_end(&mut decoded).unwrap();
     assert_eq!(size, decoded_len);
     assert_eq!(&bytes[..], &decoded[..]);
@@ -341,6 +342,21 @@ impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R, N> {
         // avoid 0 since it means EOF for non-empty buffers
         let effective_len = cmp::min(self.rng.gen_range(1..20), buf.len());
 
-        self.delegate.read(&mut buf[..effective_len])
+        self.delegate.read(&mut buf[..effective_len])
+    }
+}
+
+impl<'a, 'b, R: io::BufRead, N: rand::Rng> io::BufRead for RandomShortRead<'a, 'b, R, N> {
+    fn fill_buf(&mut self) -> Result<&[u8], io::Error> {
+        self.delegate.fill_buf().map(|buf| {
+            // avoid 0 since it means EOF for non-empty buffers
+            let effective_len = cmp::min(self.rng.gen_range(1..20), buf.len());
+
+            &buf[..effective_len]
+        })
+    }
+
+    fn consume(&mut self, amount: usize) {
+        self.delegate.consume(amount)
     }
 }
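
Note for reviewers, not part of the patch: the sketch below shows caller-side code against the changed API, assuming this patch is applied; the file path and the choice of the `STANDARD` engine are only illustrative. Since `DecoderReader` now requires `io::BufRead`, an unbuffered reader such as a `File` gets wrapped in `std::io::BufReader`, mirroring the `examples/base64.rs` change above; in-memory slices and a locked stdin already implement `BufRead` and need no wrapping.

    use std::fs::File;
    use std::io::{BufReader, Read};

    use base64::engine::general_purpose::STANDARD;
    use base64::read::DecoderReader;

    fn main() -> std::io::Result<()> {
        // Wrap the unbuffered `File` in `BufReader` so it satisfies the new
        // `R: io::BufRead` bound on `DecoderReader`.
        let file = BufReader::new(File::open("input.b64")?); // illustrative path
        let mut decoder = DecoderReader::new(file, &STANDARD);

        // `io::Read` is still implemented, so existing read calls keep working;
        // invalid base64 surfaces as an `io::ErrorKind::InvalidData` error.
        let mut decoded = Vec::new();
        let n = decoder.read_to_end(&mut decoded)?;
        println!("decoded {} bytes", n);
        Ok(())
    }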