diff --git a/src/lib.rs b/src/lib.rs index b466edd1..66b068f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ // This check is new and seems buggy (possibly with PyO3 interaction) #![allow(clippy::borrow_deref_ref)] -use std::collections::HashSet; +use std::collections::{BTreeMap, BTreeSet, HashSet}; +use std::iter::successors; use std::num::NonZeroU64; use std::thread; @@ -15,7 +16,17 @@ use rustc_hash::FxHashMap as HashMap; type Rank = u32; +const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500; + fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { + if piece.len() < LARGE_ENCODER_CHARACTER_LIMIT { + _byte_pair_merge_small(ranks, piece) // Quadratic, but lightweight + } else { + _byte_pair_merge_large(ranks, piece) // Linearithmic, but heavy + } +} + +fn _byte_pair_merge_small(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { // This is a vector of (start, rank). // The rank is of the pair starting at position start. let mut parts = Vec::with_capacity(piece.len() + 1); @@ -73,6 +84,78 @@ fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, parts } +fn _byte_pair_merge_large(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { + let mut rank_indexes = BTreeMap::>::new(); + let mut index_rank = vec![Rank::MAX; piece.len() + 1]; + let mut index_prev = vec![usize::MAX; piece.len() + 1]; + let mut index_next = vec![usize::MAX; piece.len() + 1]; + + let get_rank = |start_idx: usize, end_idx: usize| -> Rank { + *piece.get(start_idx..end_idx) + .and_then(|p| ranks.get(p)) + .unwrap_or(&Rank::MAX) + }; + + let mut prev_node = None; + for i in 0..=piece.len() { + let rank = get_rank(i, i + 2); + index_rank[i] = rank; + if let Some(prev) = prev_node { + index_prev[i] = prev; + index_next[prev] = i; + } + prev_node = Some(i); + + rank_indexes.entry(rank).or_default().insert(i); + } + + while rank_indexes.len() > 1 { + let mut skip_next = false; + if let Some((_, nodes)) = rank_indexes.pop_first() { + for &min_node in &nodes { + if skip_next { + skip_next = false; + continue; + } + + let min_rank = index_rank[min_node]; + + let prev_node = index_prev[min_node]; + let next_node = index_next[min_node]; + let next_next_node = index_next[next_node]; + let next_next_next_node = index_next[next_next_node]; + + if prev_node != usize::MAX { + let new_rank = get_rank(prev_node, next_next_node); + if index_rank[prev_node] != new_rank { + rank_indexes.get_mut(&index_rank[prev_node]).unwrap().remove(&prev_node); + index_rank[prev_node] = new_rank; + rank_indexes.entry(new_rank).or_default().insert(prev_node); + } + } + + let new_rank = get_rank(min_node, next_next_next_node); + index_rank[min_node] = new_rank; + rank_indexes.entry(new_rank).or_default().insert(min_node); + + index_next[min_node] = next_next_node; + index_prev[next_next_node] = min_node; + + let next_node_rank = index_rank[next_node]; + if next_node_rank == min_rank { + skip_next = true; + } else if next_node_rank != Rank::MAX { + rank_indexes.get_mut(&next_node_rank).unwrap().remove(&next_node); + } + } + } + } + + successors(Some(0), |&n| index_next.get(n).filter(|&&x| x != usize::MAX).copied()) + .map(|n| (n, Rank::MAX)) + .collect() +} + pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, Rank>) -> Vec { assert!(piece.len() > 1); _byte_pair_merge(&ranks, &piece)