Skip to content

Commit

Permalink
Use a table-based Huffman decoder (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
fintelia authored Aug 5, 2024
1 parent 22f23ef commit d49ef2d
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 137 deletions.
259 changes: 135 additions & 124 deletions src/huffman.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,49 +17,45 @@ enum HuffmanTreeNode {
Empty,
}

/// Huffman tree
#[derive(Clone, Debug, Default)]
pub(crate) struct HuffmanTree {
tree: Vec<HuffmanTreeNode>,
max_nodes: usize,
num_nodes: usize,
/// Internal representation behind `HuffmanTree`.
#[derive(Clone, Debug)]
enum HuffmanTreeInner {
/// A tree holding exactly one symbol; reading from it yields that symbol
/// without consuming any bits from the stream.
Single(u16),
/// Table-driven decoder with a node-based tree as the fallback.
Tree {
/// Node tree walked by the slow path when the table lookup yields 0.
/// May be empty when every code fits in the table.
tree: Vec<HuffmanTreeNode>,
/// Lookup table indexed by the low bits of the bit buffer. A non-zero
/// entry is `(code length << 16) | symbol`; 0 means "not in the table".
table: Vec<u32>,
/// Mask applied to the peeked bits to form the table index
/// (the table length minus one).
table_mask: u16,
},
}

impl HuffmanTree {
fn is_full(&self) -> bool {
self.num_nodes == self.max_nodes
}

/// Turns a node from empty into a branch and assigns its children
fn assign_children(&mut self, node_index: usize) -> usize {
let offset_index = self.num_nodes - node_index;
self.tree[node_index] = HuffmanTreeNode::Branch(offset_index);
self.num_nodes += 2;
/// Huffman tree
///
/// Wraps `HuffmanTreeInner` so the concrete representation (a single symbol,
/// or a lookup table plus fallback tree) stays private to this module.
#[derive(Clone, Debug)]
pub(crate) struct HuffmanTree(HuffmanTreeInner);

offset_index
/// The default tree is a single-node tree for symbol 0.
impl Default for HuffmanTree {
fn default() -> Self {
// `Single` carries only the symbol itself, so the default needs no tables.
Self(HuffmanTreeInner::Single(0))
}
}

/// Init a huffman tree
fn init(num_leaves: usize) -> Result<HuffmanTree, DecodingError> {
if num_leaves == 0 {
return Err(DecodingError::HuffmanError);
}
impl HuffmanTree {
/// Builds a tree implicitly, just from code lengths
pub(crate) fn build_implicit(code_lengths: Vec<u16>) -> Result<HuffmanTree, DecodingError> {
let mut num_symbols = 0;
let mut root_symbol = 0;

let max_nodes = 2 * num_leaves - 1;
let tree = vec![HuffmanTreeNode::Empty; max_nodes];
let num_nodes = 1;
for (symbol, length) in code_lengths.iter().enumerate() {
if *length > 0 {
num_symbols += 1;
root_symbol = symbol.try_into().unwrap();
}
}

let tree = HuffmanTree {
tree,
max_nodes,
num_nodes,
if num_symbols == 0 {
return Err(DecodingError::HuffmanError);
} else if num_symbols == 1 {
return Ok(Self::build_single_node(root_symbol));
};

Ok(tree)
}

/// Converts code lengths to codes
fn code_lengths_to_codes(code_lengths: &[u16]) -> Result<Vec<Option<u16>>, DecodingError> {
let max_code_length = *code_lengths
.iter()
.reduce(|a, b| if a >= b { a } else { b })
Expand All @@ -86,129 +82,117 @@ impl HuffmanTree {

// Assign codes
let mut curr_code = 0;
let mut next_codes = [None; MAX_ALLOWED_CODE_LENGTH + 1];
let mut next_codes = [0; MAX_ALLOWED_CODE_LENGTH + 1];
for code_len in 1..=usize::from(max_code_length) {
curr_code = (curr_code + code_length_hist[code_len - 1]) << 1;
next_codes[code_len] = Some(curr_code);
next_codes[code_len] = curr_code;
}
let mut huff_codes = vec![None; code_lengths.len()];
let mut huff_codes = vec![0u16; code_lengths.len()];
for (symbol, &length) in code_lengths.iter().enumerate() {
let length = usize::from(length);
if length > 0 {
huff_codes[symbol] = next_codes[length];
if let Some(value) = next_codes[length].as_mut() {
*value += 1;
}
} else {
huff_codes[symbol] = None;
next_codes[length] += 1;
}
}

Ok(huff_codes)
}

/// Adds a symbol to a huffman tree
fn add_symbol(
&mut self,
symbol: u16,
code: u16,
code_length: u16,
) -> Result<(), DecodingError> {
let mut node_index = 0;
let code = usize::from(code);

for length in (0..code_length).rev() {
if node_index >= self.max_nodes {
return Err(DecodingError::HuffmanError);
}

let node = self.tree[node_index];

let offset = match node {
HuffmanTreeNode::Empty => {
if self.is_full() {
return Err(DecodingError::HuffmanError);
}
self.assign_children(node_index)
// Populate decoding table
let table_bits = max_code_length.min(10);
let table_size = (1 << table_bits) as usize;
let table_mask = table_size as u16 - 1;
let mut table = vec![0; table_size];
for (symbol, (&code, &length)) in huff_codes.iter().zip(code_lengths.iter()).enumerate() {
if length != 0 && length <= table_bits {
let mut j = (u16::reverse_bits(code) >> (16 - length)) as usize;
let entry = ((length as u32) << 16) | symbol as u32;
while j < table_size {
table[j] = entry;
j += 1 << length as usize;
}
HuffmanTreeNode::Leaf(_) => return Err(DecodingError::HuffmanError),
HuffmanTreeNode::Branch(offset) => offset,
};

node_index += offset + ((code >> length) & 1);
}

match self.tree[node_index] {
HuffmanTreeNode::Empty => self.tree[node_index] = HuffmanTreeNode::Leaf(symbol),
HuffmanTreeNode::Leaf(_) => return Err(DecodingError::HuffmanError),
HuffmanTreeNode::Branch(_offset) => return Err(DecodingError::HuffmanError),
}

Ok(())
}

/// Builds a tree implicitly, just from code lengths
pub(crate) fn build_implicit(code_lengths: Vec<u16>) -> Result<HuffmanTree, DecodingError> {
let mut num_symbols = 0;
let mut root_symbol = 0;

for (symbol, length) in code_lengths.iter().enumerate() {
if *length > 0 {
num_symbols += 1;
root_symbol = symbol.try_into().unwrap();
}
}

let mut tree = HuffmanTree::init(num_symbols)?;

if num_symbols == 1 {
tree.add_symbol(root_symbol, 0, 0)?;
} else {
let codes = HuffmanTree::code_lengths_to_codes(&code_lengths)?;
// If the longest code is larger than the table size, build a tree as a fallback.
let mut tree = Vec::new();
if max_code_length > table_bits {
tree = vec![HuffmanTreeNode::Empty; 2 * num_symbols - 1];

let mut num_nodes = 1;
for (symbol, &length) in code_lengths.iter().enumerate() {
if length > 0 && codes[symbol].is_some() {
tree.add_symbol(symbol.try_into().unwrap(), codes[symbol].unwrap(), length)?;
let code = huff_codes[symbol];
let code_length = length;
let symbol = symbol.try_into().unwrap();

if length > 0 {
let mut node_index = 0;
let code = usize::from(code);

for length in (0..code_length).rev() {
let node = tree[node_index];

let offset = match node {
HuffmanTreeNode::Empty => {
// Turns a node from empty into a branch and assigns its children
let offset_index = num_nodes - node_index;
tree[node_index] = HuffmanTreeNode::Branch(offset_index);
num_nodes += 2;
offset_index
}
HuffmanTreeNode::Leaf(_) => return Err(DecodingError::HuffmanError),
HuffmanTreeNode::Branch(offset) => offset,
};

node_index += offset + ((code >> length) & 1);
}

match tree[node_index] {
HuffmanTreeNode::Empty => tree[node_index] = HuffmanTreeNode::Leaf(symbol),
HuffmanTreeNode::Leaf(_) => return Err(DecodingError::HuffmanError),
HuffmanTreeNode::Branch(_offset) => {
return Err(DecodingError::HuffmanError)
}
}
}
}
}

Ok(tree)
Ok(Self(HuffmanTreeInner::Tree {
tree,
table,
table_mask,
}))
}

/// Builds a tree explicitly from lengths, codes and symbols
pub(crate) fn build_explicit(
code_lengths: Vec<u16>,
codes: Vec<u16>,
symbols: Vec<u16>,
) -> Result<HuffmanTree, DecodingError> {
let mut tree = HuffmanTree::init(symbols.len())?;

for i in 0..symbols.len() {
tree.add_symbol(symbols[i], codes[i], code_lengths[i])?;
}
/// Builds a tree that contains only `symbol`.
///
/// Reading from such a tree returns `symbol` without consuming any bits.
pub(crate) fn build_single_node(symbol: u16) -> HuffmanTree {
    let inner = HuffmanTreeInner::Single(symbol);
    Self(inner)
}

Ok(tree)
/// Builds a tree with exactly two symbols, each coded with a single bit.
///
/// A `0` bit decodes to `zero` and a `1` bit decodes to `one`.
pub(crate) fn build_two_node(zero: u16, one: u16) -> HuffmanTree {
    // Table entries are `(code length << 16) | symbol`; both codes are one bit long.
    let one_bit_entry = |symbol: u16| (1u32 << 16) | u32::from(symbol);

    // NOTE(review): this fallback tree starts with a leaf rather than a branch. It looks
    // unreachable because both table slots below are non-zero -- confirm before relying on it.
    let fallback = vec![
        HuffmanTreeNode::Leaf(zero),
        HuffmanTreeNode::Leaf(one),
        HuffmanTreeNode::Empty,
    ];

    let inner = HuffmanTreeInner::Tree {
        tree: fallback,
        table: vec![one_bit_entry(zero), one_bit_entry(one)],
        table_mask: 0x1,
    };
    Self(inner)
}

pub(crate) fn is_single_node(&self) -> bool {
self.num_nodes == 1
matches!(self.0, HuffmanTreeInner::Single(_))
}

/// Reads a symbol using the bitstream.
///
/// You must call `bit_reader.fill()` before calling this function or it may erroneously
/// detect the end of the stream and return a bitstream error.
pub(crate) fn read_symbol<R: Read>(
&self,
#[inline(never)]
fn read_symbol_slowpath<R: Read>(
tree: &[HuffmanTreeNode],
mut v: usize,
bit_reader: &mut BitReader<R>,
) -> Result<u16, DecodingError> {
let mut v = bit_reader.peek(15) as usize;
let mut depth = 0;

let mut index = 0;
loop {
match &self.tree[index] {
match &tree[index] {
HuffmanTreeNode::Branch(children_offset) => {
index += children_offset + (v & 1);
depth += 1;
Expand All @@ -222,4 +206,31 @@ impl HuffmanTree {
}
}
}

/// Decodes the next symbol from the bitstream.
///
/// Call `bit_reader.fill()` before calling this function; otherwise it may erroneously
/// detect the end of the stream and return a bitstream error.
///
/// A single-symbol tree returns its symbol without consuming any bits. Otherwise the low
/// bits of the buffer index the lookup table; a zero entry defers to the tree walk in
/// `read_symbol_slowpath`.
pub(crate) fn read_symbol<R: Read>(
    &self,
    bit_reader: &mut BitReader<R>,
) -> Result<u16, DecodingError> {
    let (tree, table, table_mask) = match &self.0 {
        HuffmanTreeInner::Single(symbol) => return Ok(*symbol),
        HuffmanTreeInner::Tree {
            tree,
            table,
            table_mask,
        } => (tree, table, *table_mask),
    };

    // Only the low 16 bits of the buffer are needed for the lookup.
    let bits = bit_reader.peek_full() as u16;
    match table[usize::from(bits & table_mask)] {
        // No table entry for this prefix: fall back to walking the tree.
        0 => Self::read_symbol_slowpath(tree, usize::from(bits), bit_reader),
        // Entry layout: code length in the high half, symbol in the low half.
        entry => {
            bit_reader.consume((entry >> 16) as u8)?;
            Ok(entry as u16)
        }
    }
}
}
29 changes: 16 additions & 13 deletions src/lossless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,24 +358,22 @@ impl<R: Read> LosslessDecoder<R> {
if simple {
let num_symbols = self.bit_reader.read_bits::<u8>(1)? + 1;

let mut code_lengths = vec![u16::from(num_symbols - 1)];
let mut codes = vec![0];
let mut symbols = Vec::new();

let is_first_8bits = self.bit_reader.read_bits::<u8>(1)?;
symbols.push(self.bit_reader.read_bits::<u16>(1 + 7 * is_first_8bits)?);

if num_symbols == 2 {
symbols.push(self.bit_reader.read_bits::<u16>(8)?);
code_lengths.push(1);
codes.push(1);
}
let zero_symbol = self.bit_reader.read_bits::<u16>(1 + 7 * is_first_8bits)?;

if symbols.iter().any(|&s| s > alphabet_size) {
if zero_symbol >= alphabet_size {
return Err(DecodingError::BitStreamError);
}

HuffmanTree::build_explicit(code_lengths, codes, symbols)
if num_symbols == 1 {
Ok(HuffmanTree::build_single_node(zero_symbol))
} else {
let one_symbol = self.bit_reader.read_bits::<u16>(8)?;
if one_symbol >= alphabet_size {
return Err(DecodingError::BitStreamError);
}
Ok(HuffmanTree::build_two_node(zero_symbol, one_symbol))
}
} else {
let mut code_length_code_lengths = vec![0; CODE_LENGTH_CODES];

Expand Down Expand Up @@ -751,6 +749,11 @@ impl<R: Read> BitReader<R> {
self.buffer & ((1 << num) - 1)
}

/// Peeks at the full buffer.
///
/// Returns the buffer as-is: nothing is masked off and nothing is consumed.
pub(crate) fn peek_full(&self) -> u64 {
self.buffer
}

/// Consumes `num` bits from the buffer returning an error if there are not enough bits.
pub(crate) fn consume(&mut self, num: u8) -> Result<(), DecodingError> {
if self.nbits < num {
Expand Down

0 comments on commit d49ef2d

Please sign in to comment.