Skip to content

Commit

Permalink
add roundtrip test for huffman coding that's fit to also be fuzzed
Browse files Browse the repository at this point in the history
  • Loading branch information
KillingSpark committed Oct 11, 2024
1 parent 8443da9 commit cc53b2a
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 4 deletions.
7 changes: 6 additions & 1 deletion src/encoding/bit_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,15 @@ impl BitWriter {
/// dumping
pub fn dump(self) -> Vec<u8> {
if self.bit_idx % 8 != 0 {
panic!("`dump` was called on a bit writer but an even number of bytes weren't written into the buffer")
panic!("`dump` was called on a bit writer but an even number of bytes weren't written into the buffer. Was: {}", self.bit_idx)
}
self.output
}

/// Returns how many bits are missing for an even byte
pub fn misaligned(&self) -> usize {
8 - (self.bit_idx % 8)
}
}

#[cfg(test)]
Expand Down
7 changes: 7 additions & 0 deletions src/huff0/huff0_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,4 +586,11 @@ impl HuffmanTable {

Ok(())
}

pub(super) fn from_weights(weights: Vec<u8>) -> Self {
let mut new = Self::new();
new.weights = weights;
new.build_table_from_weights().unwrap();
new
}
}
31 changes: 28 additions & 3 deletions src/huff0/huff0_encoder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use alloc::vec::Vec;
use core::cmp::Ordering;
use std::eprintln;

use crate::encoding::bit_writer::BitWriter;

Expand All @@ -10,7 +9,13 @@ pub struct HuffmanEncoder {
}

impl HuffmanEncoder {
fn encode(&mut self, data: &[u8]) {
pub fn new(table: HuffmanTable) -> Self {
Self {
table,
writer: BitWriter::new(),
}
}
pub fn encode(&mut self, data: &[u8]) {
for symbol in data {
let (code, mut num_bits) = self.table.codes[*symbol as usize];
while num_bits > 0 {
Expand All @@ -21,11 +26,29 @@ impl HuffmanEncoder {
}
}
}
fn dump(&mut self) -> Vec<u8> {
pub fn dump(&mut self) -> Vec<u8> {
let mut writer = BitWriter::new();
std::mem::swap(&mut self.writer, &mut writer);
let bits_to_fill = writer.misaligned();
if bits_to_fill == 0 {
writer.write_bits(&[(1u8 << 7)], 8);
} else {
writer.write_bits(&[(1u8 << (bits_to_fill - 1))], bits_to_fill);
}
writer.dump()
}
pub(super) fn weights(&self) -> Vec<u8> {
let max = self.table.codes.iter().map(|(_, nb)| nb).max().unwrap();
let mut weights = self
.table
.codes
.iter()
.copied()
.map(|(_, nb)| if nb == 0 { 0 } else { max - nb + 1 })
.collect::<Vec<u8>>();

weights
}
}

pub struct HuffmanTable {
Expand All @@ -51,6 +74,7 @@ impl HuffmanTable {
let mut weights = distribute_weights(counts.len() - zeros);
let limit = weights.len().ilog2() as usize + 2;
redistribute_weights(&mut weights, limit);

weights.reverse();
let mut counts_sorted = counts.iter().enumerate().collect::<Vec<_>>();
counts_sorted.sort_by(|(_, c1), (_, c2)| c1.cmp(c2));
Expand All @@ -64,6 +88,7 @@ impl HuffmanTable {
weights_distributed[idx] = weights.pop().unwrap();
}
}

Self::build_from_weights(&weights_distributed)
}

Expand Down
35 changes: 35 additions & 0 deletions src/huff0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,40 @@
/// used symbols get longer codes. Codes are prefix free, meaning no two codes
/// will start with the same sequence of bits.
mod huff0_decoder;
use std::vec::Vec;

pub use huff0_decoder::*;

use crate::decoding::bit_reader_reverse::BitReaderReversed;
mod huff0_encoder;

pub fn round_trip(data: &[u8]) {
let encoder_table = huff0_encoder::HuffmanTable::build_from_data(data);
let mut encoder = huff0_encoder::HuffmanEncoder::new(encoder_table);

encoder.encode(data);
let encoded = encoder.dump();
let decoder_table = HuffmanTable::from_weights(encoder.weights());
let mut decoder = HuffmanDecoder::new(&decoder_table);
let mut br = BitReaderReversed::new(&encoded);

for _ in 0..7 {
if br.get_bits(1) == 1 {
break;
}
}

decoder.init_state(&mut br);
let mut decoded = Vec::new();
while br.bits_remaining() > 0 {
let symbol = decoder.decode_symbol();
decoder.next_state(&mut br);
decoded.push(symbol);
}
assert_eq!(&decoded, data);
}

#[test]
fn roundtrip() {
round_trip(&[1, 1, 1, 2, 3]);
}

0 comments on commit cc53b2a

Please sign in to comment.