diff --git a/yomikomi-pyo3/src/lib.rs b/yomikomi-pyo3/src/lib.rs
index f3d740c..59d5f57 100644
--- a/yomikomi-pyo3/src/lib.rs
+++ b/yomikomi-pyo3/src/lib.rs
@@ -214,6 +214,8 @@ struct Tokenize {
     report_bpb: bool,
     include_bos: bool,
     include_eos: bool,
+    bos_id: Option<u32>,
+    eos_id: Option<u32>,
 }
 
 impl Iterable for Tokenize {
@@ -227,6 +229,8 @@ impl Iterable for Tokenize {
             self.report_bpb,
             self.include_bos,
             self.include_eos,
+            self.bos_id,
+            self.eos_id,
         )
         .map_err(w)?;
         Ok(StreamIter { stream: Box::new(stream) })
@@ -409,7 +413,8 @@ impl YkIterable {
 
     /// Loads a sentencepiece tokenizer, and use it to tokenize the field passed as an argument of
     /// this function.
-    #[pyo3(signature = (path, *, in_field="text".to_string(), out_field=None, report_bpb=true, include_bos=true, include_eos=false))]
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (path, *, in_field="text".to_string(), out_field=None, report_bpb=true, include_bos=true, include_eos=false, bos_id=None, eos_id=None))]
     fn tokenize(
         &self,
         path: std::path::PathBuf,
@@ -418,6 +423,8 @@ impl YkIterable {
         report_bpb: bool,
         include_bos: bool,
         include_eos: bool,
+        bos_id: Option<u32>,
+        eos_id: Option<u32>,
     ) -> PyResult<Self> {
         let out_field = out_field.unwrap_or_else(|| in_field.clone());
         let inner = Tokenize {
@@ -428,6 +435,8 @@ impl YkIterable {
             report_bpb,
             include_bos,
             include_eos,
+            bos_id,
+            eos_id,
         };
         Ok(Self { inner: Arc::new(inner) })
     }
diff --git a/yomikomi/src/tokenize.rs b/yomikomi/src/tokenize.rs
index 069392f..9d9ccc9 100644
--- a/yomikomi/src/tokenize.rs
+++ b/yomikomi/src/tokenize.rs
@@ -4,7 +4,7 @@ use std::sync::{Arc, Mutex};
 use tokenizers::tokenizer::Tokenizer;
 
 enum Processor {
-    Tokenizers { inner: Box<Tokenizer>, bos_id: Option<u32>, eos_id: Option<u32> },
+    Tokenizers(Box<Tokenizer>),
     SentencePiece(SentencePieceProcessor),
 }
 
@@ -12,14 +12,14 @@ impl Processor {
     fn bos_id(&self) -> Option<u32> {
         match self {
             Self::SentencePiece(p) => p.bos_id(),
-            Self::Tokenizers { inner: _, bos_id, eos_id: _ } => bos_id.as_ref().copied(),
+            Self::Tokenizers(_) => None,
         }
     }
 
     fn eos_id(&self) -> Option<u32> {
         match self {
             Self::SentencePiece(p) => p.eos_id(),
-            Self::Tokenizers { inner: _, bos_id: _, eos_id } => eos_id.as_ref().copied(),
+            Self::Tokenizers(_) => None,
         }
     }
 
@@ -28,9 +28,7 @@ impl Processor {
             Self::SentencePiece(p) => {
                 p.encode(str).map_err(E::wrap)?.iter().map(|v| v.id).collect()
             }
-            Self::Tokenizers { inner, bos_id: _, eos_id: _ } => {
-                inner.encode(str, false)?.get_ids().to_vec()
-            }
+            Self::Tokenizers(p) => p.encode(str, false)?.get_ids().to_vec(),
         };
         Ok(tokens)
     }
@@ -45,9 +43,12 @@ pub struct Tokenize<T> {
     tokens_and_chars: Option<Mutex<(usize, usize)>>,
     include_bos: bool,
     include_eos: bool,
+    bos_id: Option<u32>,
+    eos_id: Option<u32>,
 }
 
 impl<T> Tokenize<T> {
+    #[allow(clippy::too_many_arguments)]
     pub fn new<P: AsRef<std::path::Path>>(
         path: P,
         input: T,
@@ -56,11 +57,13 @@ impl<T> Tokenize<T> {
         report_bpb: bool,
         include_bos: bool,
         include_eos: bool,
+        bos_id: Option<u32>,
+        eos_id: Option<u32>,
     ) -> Result<Self> {
         let path = path.as_ref();
         let processor = if path.extension().map_or(false, |v| v == "json") {
             let inner = Box::new(Tokenizer::from_file(path)?);
-            Processor::Tokenizers { inner, bos_id: None, eos_id: None }
+            Processor::Tokenizers(inner)
         } else {
             Processor::SentencePiece(SentencePieceProcessor::open(path).map_err(E::wrap)?)
         };
@@ -78,6 +81,8 @@ impl<T> Tokenize<T> {
             tokens_and_chars,
             include_bos,
             include_eos,
+            bos_id,
+            eos_id,
         })
     }
 }
@@ -102,7 +107,8 @@ impl<T: Stream> Stream for Tokenize<T> {
         let text = String::from_utf8_lossy(values);
         let mut all_tokens = Vec::new();
         if self.include_bos {
-            if let Some(bos_id) = self.processor.bos_id() {
+            let bos_id = self.bos_id.or_else(|| self.processor.bos_id());
+            if let Some(bos_id) = bos_id {
                 all_tokens.push(bos_id)
             }
         }
@@ -130,7 +136,8 @@
             }
         }
         if self.include_eos {
-            if let Some(eos_id) = self.processor.eos_id() {
+            let eos_id = self.eos_id.or_else(|| self.processor.eos_id());
+            if let Some(eos_id) = eos_id {
                 all_tokens.push(eos_id)
             }
         }
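
For reference, a minimal sketch of the id-resolution order the new fields introduce, not part of the patch: an explicitly supplied bos_id/eos_id takes priority over whatever the underlying processor reports, which is what makes include_bos/include_eos usable with tokenizers JSON files now that Processor::bos_id/eos_id return None for that backend. The resolve_special_id helper below is hypothetical, named here only for illustration.

    // Hypothetical helper mirroring the `self.bos_id.or_else(|| self.processor.bos_id())`
    // lines in the patch: an explicit override wins, otherwise the processor's own id is
    // used, otherwise no special token is pushed at all.
    fn resolve_special_id(override_id: Option<u32>, processor_id: Option<u32>) -> Option<u32> {
        // The patch uses or_else because its fallback calls into the processor; with a
        // plain value already in hand, or() is the equivalent.
        override_id.or(processor_id)
    }

    fn main() {
        // SentencePiece model that knows its own BOS id, no override given.
        assert_eq!(resolve_special_id(None, Some(1)), Some(1));
        // tokenizers JSON file (reports None) with an explicit override supplied.
        assert_eq!(resolve_special_id(Some(2), None), Some(2));
        // When both are present, the override wins.
        assert_eq!(resolve_special_id(Some(3), Some(1)), Some(3));
        // Neither available: include_bos / include_eos become a no-op.
        assert_eq!(resolve_special_id(None, None), None);
    }

The same rule is written inline in the patch as self.bos_id.or_else(|| self.processor.bos_id()), where the lazy or_else avoids querying the processor when an override is already set.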