From 5af8058ea743b2cb78793755a344a8be12773cc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Mon, 15 Jan 2024 22:03:26 +0100 Subject: [PATCH] Add a _byte_pair_merge_large for worst-case scenarios We're storing the ranks in a sorted tree of sorted (or linked) trees. Getting the minimum rank is logarithmic and each subsequent occurrence is constant time. To know the previous and next indexes (and the corresponding ranks), we're storing them in arrays (the keys are the indexes). We're updating each after finding the minimum via the tree. We're iterating duplicates without removing them one-by-one, but if they are neighbors, we're skipping them manually. --- src/lib.rs | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index b466edd1..66b068f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ // This check is new and seems buggy (possibly with PyO3 interaction) #![allow(clippy::borrow_deref_ref)] -use std::collections::HashSet; +use std::collections::{BTreeMap, BTreeSet, HashSet}; +use std::iter::successors; use std::num::NonZeroU64; use std::thread; @@ -15,7 +16,17 @@ use rustc_hash::FxHashMap as HashMap; type Rank = u32; +const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500; + fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { + if piece.len() < LARGE_ENCODER_CHARACTER_LIMIT { + _byte_pair_merge_small(ranks, piece) // Quadratic, but lightweight + } else { + _byte_pair_merge_large(ranks, piece) // Linearithmic, but heavy + } +} + +fn _byte_pair_merge_small(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { // This is a vector of (start, rank). // The rank is of the pair starting at position start. let mut parts = Vec::with_capacity(piece.len() + 1); @@ -73,6 +84,78 @@ fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, parts } +fn _byte_pair_merge_large(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { + let mut rank_indexes = BTreeMap::>::new(); + let mut index_rank = vec![Rank::MAX; piece.len() + 1]; + let mut index_prev = vec![usize::MAX; piece.len() + 1]; + let mut index_next = vec![usize::MAX; piece.len() + 1]; + + let get_rank = |start_idx: usize, end_idx: usize| -> Rank { + *piece.get(start_idx..end_idx) + .and_then(|p| ranks.get(p)) + .unwrap_or(&Rank::MAX) + }; + + let mut prev_node = None; + for i in 0..=piece.len() { + let rank = get_rank(i, i + 2); + index_rank[i] = rank; + if let Some(prev) = prev_node { + index_prev[i] = prev; + index_next[prev] = i; + } + prev_node = Some(i); + + rank_indexes.entry(rank).or_default().insert(i); + } + + while rank_indexes.len() > 1 { + let mut skip_next = false; + if let Some((_, nodes)) = rank_indexes.pop_first() { + for &min_node in &nodes { + if skip_next { + skip_next = false; + continue; + } + + let min_rank = index_rank[min_node]; + + let prev_node = index_prev[min_node]; + let next_node = index_next[min_node]; + let next_next_node = index_next[next_node]; + let next_next_next_node = index_next[next_next_node]; + + if prev_node != usize::MAX { + let new_rank = get_rank(prev_node, next_next_node); + if index_rank[prev_node] != new_rank { + rank_indexes.get_mut(&index_rank[prev_node]).unwrap().remove(&prev_node); + index_rank[prev_node] = new_rank; + rank_indexes.entry(new_rank).or_default().insert(prev_node); + } + } + + let new_rank = get_rank(min_node, next_next_next_node); + index_rank[min_node] = new_rank; + rank_indexes.entry(new_rank).or_default().insert(min_node); + + index_next[min_node] = next_next_node; + index_prev[next_next_node] = min_node; + + let next_node_rank = index_rank[next_node]; + if next_node_rank == min_rank { + skip_next = true; + } else if next_node_rank != Rank::MAX { + rank_indexes.get_mut(&next_node_rank).unwrap().remove(&next_node); + } + } + } + } + + successors(Some(0), |&n| index_next.get(n).filter(|&&x| x != usize::MAX).copied()) + .map(|n| (n, Rank::MAX)) + .collect() +} + pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, Rank>) -> Vec { assert!(piece.len() > 1); _byte_pair_merge(&ranks, &piece)