From 53dd8066d62c78d513239838459d9406c5ecd211 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 4 Dec 2024 12:55:13 +0100
Subject: [PATCH] Add tokenizers support.

---
 yomikomi/Cargo.toml      |  1 +
 yomikomi/src/error.rs    |  3 ++
 yomikomi/src/tokenize.rs | 60 +++++++++++++++++++++++++++++++++-------
 3 files changed, 54 insertions(+), 10 deletions(-)

diff --git a/yomikomi/Cargo.toml b/yomikomi/Cargo.toml
index cbb6860..53619f0 100644
--- a/yomikomi/Cargo.toml
+++ b/yomikomi/Cargo.toml
@@ -15,4 +15,5 @@ sentencepiece = "0.11.2"
 serde_json = "1.0.108"
 symphonia = { version = "0.5.3", features = ["all-codecs"] }
 thiserror = "1.0.50"
+tokenizers = "0.21.0"
 zstd = "0.13.0"
diff --git a/yomikomi/src/error.rs b/yomikomi/src/error.rs
index 2e8de76..caa03d5 100644
--- a/yomikomi/src/error.rs
+++ b/yomikomi/src/error.rs
@@ -46,6 +46,9 @@ pub enum Error {
     #[error(transparent)]
     Io(#[from] std::io::Error),
 
+    #[error(transparent)]
+    Tokenizers(#[from] tokenizers::tokenizer::Error),
+
     /// Arbitrary errors wrapping.
     #[error(transparent)]
     Wrapped(Box<dyn std::error::Error + Send + Sync>),
diff --git a/yomikomi/src/tokenize.rs b/yomikomi/src/tokenize.rs
index 62e313c..1ae1951 100644
--- a/yomikomi/src/tokenize.rs
+++ b/yomikomi/src/tokenize.rs
@@ -1,9 +1,43 @@
-use crate::{Array, Error, Result, Stream};
+use crate::{Array, Error as E, Result, Stream};
 use sentencepiece::SentencePieceProcessor;
 use std::sync::{Arc, Mutex};
+use tokenizers::tokenizer::Tokenizer;
+
+enum Processor {
+    Tokenizers { inner: Tokenizer, bos_id: Option<u32>, eos_id: Option<u32> },
+    SentencePiece(SentencePieceProcessor),
+}
+
+impl Processor {
+    fn bos_id(&self) -> Option<u32> {
+        match self {
+            Self::SentencePiece(p) => p.bos_id(),
+            Self::Tokenizers { inner: _, bos_id, eos_id: _ } => bos_id.as_ref().copied(),
+        }
+    }
+
+    fn eos_id(&self) -> Option<u32> {
+        match self {
+            Self::SentencePiece(p) => p.eos_id(),
+            Self::Tokenizers { inner: _, bos_id: _, eos_id } => eos_id.as_ref().copied(),
+        }
+    }
+
+    fn encode(&self, str: &str) -> Result<Vec<u32>> {
+        let tokens: Vec<_> = match self {
+            Self::SentencePiece(p) => {
+                p.encode(str).map_err(E::wrap)?.iter().map(|v| v.id).collect()
+            }
+            Self::Tokenizers { inner, bos_id: _, eos_id: _ } => {
+                inner.encode(str, false)?.get_ids().to_vec()
+            }
+        };
+        Ok(tokens)
+    }
+}
 
 pub struct Tokenize<T> {
-    spp: Arc<SentencePieceProcessor>,
+    processor: Arc<Processor>,
     input: T,
     in_key: String,
     out_key: String,
@@ -23,14 +57,20 @@ impl<T> Tokenize<T> {
         include_bos: bool,
         include_eos: bool,
     ) -> Result<Self> {
-        let spp = SentencePieceProcessor::open(path).map_err(Error::wrap)?;
-        let nl_id = match spp.encode("\n").map_err(Error::wrap)?.last() {
+        let path = path.as_ref();
+        let processor = if path.extension().map_or(false, |v| v == "json") {
+            let inner = Tokenizer::from_file(path)?;
+            Processor::Tokenizers { inner, bos_id: None, eos_id: None }
+        } else {
+            Processor::SentencePiece(SentencePieceProcessor::open(path).map_err(E::wrap)?)
+        };
+        let nl_id = match processor.encode("\n").map_err(E::wrap)?.last() {
             None => crate::bail!("no specific token id for newline"),
-            Some(p) => p.id,
+            Some(p) => *p,
         };
         let tokens_and_chars = if report_bpb { Some(Mutex::new((0, 0))) } else { None };
         Ok(Self {
-            spp: Arc::new(spp),
+            processor: Arc::new(processor),
             input,
             in_key,
             out_key,
@@ -62,7 +102,7 @@ impl Stream for Tokenize {
         let text = String::from_utf8_lossy(values);
         let mut all_tokens = Vec::new();
         if self.include_bos {
-            if let Some(bos_id) = self.spp.bos_id() {
+            if let Some(bos_id) = self.processor.bos_id() {
                 all_tokens.push(bos_id)
             }
         }
@@ -72,7 +112,7 @@
             if idx > 0 {
                 all_tokens.push(self.nl_id)
             }
-            let tokens = match self.spp.encode(text) {
+            let tokens = match self.processor.encode(text) {
                 Ok(tokens) => tokens,
                 Err(err) => {
                     eprintln!("tokenizer encode error {err:?}");
@@ -86,11 +126,11 @@
                 bpb = Some(tokens_and_chars.0 as f64 / tokens_and_chars.1 as f64 / f64::ln(2.))
             };
             for token in tokens {
-                all_tokens.push(token.id)
+                all_tokens.push(token)
             }
         }
         if self.include_eos {
-            if let Some(eos_id) = self.spp.eos_id() {
+            if let Some(eos_id) = self.processor.eos_id() {
                 all_tokens.push(eos_id)
             }
         }
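Note (not part of the patch): the snippet below is an illustrative, standalone sketch of the tokenizers code path that the new Processor::Tokenizers variant relies on, i.e. loading a Hugging Face tokenizer.json with Tokenizer::from_file and encoding with add_special_tokens = false, as Processor::encode does above. The file path and sample text are placeholders.

    use tokenizers::tokenizer::{Result, Tokenizer};

    fn main() -> Result<()> {
        // Placeholder path; any Hugging Face tokenizer.json would do.
        let tokenizer = Tokenizer::from_file("tokenizer.json")?;
        // `false` disables special tokens, matching `encode(str, false)` in the patch;
        // BOS/EOS are pushed separately by the Tokenize stream when requested.
        let encoding = tokenizer.encode("Hello world\n", false)?;
        let ids: Vec<u32> = encoding.get_ids().to_vec();
        println!("{ids:?}");
        Ok(())
    }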