diff --git a/include/tokenizers_c.h b/include/tokenizers_c.h
index 4276ee8..90b62f1 100644
--- a/include/tokenizers_c.h
+++ b/include/tokenizers_c.h
@@ -17,8 +17,8 @@ extern "C" {
 typedef void* TokenizerHandle;
 
 typedef struct {
-    int* token_ids;
-    size_t len;
+  int* token_ids;
+  size_t len;
 } TokenizerEncodeResult;
 
 TokenizerHandle tokenizers_new_from_str(const char* json, size_t len);
@@ -28,10 +28,17 @@ TokenizerHandle byte_level_bpe_tokenizers_new_from_str(const char* vocab, size_t vocab_len,
                                                         const char* added_tokens, size_t added_tokens_len);
 
-void tokenizers_encode(TokenizerHandle handle, const char* data, size_t len, int add_special_token, TokenizerEncodeResult* result);
+void tokenizers_encode(TokenizerHandle handle, const char* data, size_t len, int add_special_token,
+                       TokenizerEncodeResult* result);
 
-void tokenizers_encode_batch(TokenizerHandle handle, const char** data, size_t* len, size_t num_seqs,
-                             int add_special_token, TokenizerEncodeResult* results);
+void tokenizers_encode_batch(TokenizerHandle handle, const char** data, size_t* len,
+                             size_t num_seqs, int add_special_token,
+                             TokenizerEncodeResult* results);
+
+void tokenizers_encode_batch_with_mask(TokenizerHandle handle, const char** data, size_t* len,
+                                       size_t num_seqs, int add_special_token,
+                                       TokenizerEncodeResult* results,
+                                       TokenizerEncodeResult* masks);
 
 void tokenizers_free_encode_results(TokenizerEncodeResult* results, size_t num_seqs);
diff --git a/include/tokenizers_cpp.h b/include/tokenizers_cpp.h
index d37aa57..72a261d 100644
--- a/include/tokenizers_cpp.h
+++ b/include/tokenizers_cpp.h
@@ -6,10 +6,11 @@
 #ifndef TOKENIZERS_CPP_H_
 #define TOKENIZERS_CPP_H_
 
+#include <tokenizers_c.h>
+
 #include <memory>
 #include <string>
 #include <vector>
-
 namespace tokenizers {
@@ -57,13 +58,14 @@ class Tokenizer {
   virtual size_t GetVocabSize() = 0;
 
   /*!
-   * \brief Convert the given id to its corresponding token if it exists. If not, return an
-   * empty string.
+   * \brief Convert the given id to its corresponding token if it exists. If
+   * not, return an empty string.
    */
   virtual std::string IdToToken(int32_t token_id) = 0;
 
   /*!
-   * \brief Convert the given token to its corresponding id if it exists. If not, return -1.
+   * \brief Convert the given token to its corresponding id if it exists. If
+   * not, return -1.
    */
   virtual int32_t TokenToId(const std::string& token) = 0;
@@ -106,5 +108,65 @@ class Tokenizer {
   static std::unique_ptr<Tokenizer> FromBlobRWKVWorld(const std::string& model_blob);
 };
 
+class HFTokenizer : public Tokenizer {
+ public:
+  explicit HFTokenizer(TokenizerHandle handle);
+
+  HFTokenizer(const HFTokenizer&) = delete;
+  HFTokenizer(HFTokenizer&& other);
+
+  ~HFTokenizer();
+
+  // use i32 to be consistent with sentencepiece
+  std::vector<int32_t> Encode(const std::string& text, bool add_special_tokens);
+
+  // use i32 to be consistent with sentencepiece
+  std::vector<int32_t> Encode(const std::string& text) final;
+
+  // version specific to HFTokenizer, which adds the special tokens flag
+  std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts,
+                                                bool add_special_tokens);
+
+  std::tuple<std::vector<std::vector<int32_t>>, std::vector<std::vector<int32_t>>>
+  EncodeBatchWithMask(const std::vector<std::string>& texts, bool add_special_tokens);
+
+  std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts) final;
+
+  // use i32 to be consistent with sentencepiece
+  std::string Decode(const std::vector<int32_t>& ids, bool skip_special_tokens);
+
+  std::string Decode(const std::vector<int32_t>& ids) final;
+
+  size_t GetVocabSize() final;
+
+  std::string IdToToken(int32_t id) final;
+
+  int32_t TokenToId(const std::string& token) final;
+
+  /*!
+   * \brief Create HF tokenizer from a single in-memory json blob.
+   *
+   * \param json_blob The json blob.
+   * \return The created tokenizer.
+   */
+  static std::unique_ptr<HFTokenizer> FromBlobJSON(const std::string& json_blob);
+
+  /*!
+   * \brief Create BPE tokenizer.
+   *
+   * \param vocab_blob The blob that contains vocabs.
+   * \param merges_blob The blob that contains the merges.
+   * \param added_tokens The added tokens.
+   * \return The created tokenizer.
+   */
+  static std::unique_ptr<HFTokenizer> FromBlobByteLevelBPE(const std::string& vocab_blob,
+                                                           const std::string& merges_blob,
+                                                           const std::string& added_tokens = "");
+
+ private:
+  // internal handle
+  TokenizerHandle handle_{nullptr};
+};
+
 }  // namespace tokenizers
 #endif  // TOKENIZERS_CPP_H_
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 96b0ea8..a94f0d3 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -91,12 +91,24 @@ impl TokenizerWrapper {
         return encoded.get_ids().to_vec();
     }
 
-    pub fn encode_batch(&mut self, texts: Vec<&str>, add_special_tokens: bool) -> Vec<Vec<u32>> {
-        let results = self.tokenizer.encode_batch(texts, add_special_tokens).unwrap()
-            .into_iter()
-            .map(|encoded| encoded.get_ids().to_vec())
+    pub fn encode_batch_with_mask(
+        &mut self,
+        texts: Vec<&str>,
+        add_special_tokens: bool,
+    ) -> (Vec<Vec<u32>>, Vec<Vec<u32>>) {
+        let encoded = self
+            .tokenizer
+            .encode_batch(texts, add_special_tokens)
+            .unwrap();
+        let tokens = encoded
+            .iter()
+            .map(|e| e.get_ids().to_vec())
             .collect::<Vec<Vec<u32>>>();
-        return results;
+        let attention_mask = encoded
+            .iter()
+            .map(|e| e.get_attention_mask().to_vec())
+            .collect::<Vec<Vec<u32>>>();
+        return (tokens, attention_mask);
     }
 
     pub fn decode(&mut self, ids: &[u32], skip_special_tokens: bool) {
@@ -170,10 +182,49 @@ extern "C" fn tokenizers_encode_batch(
     unsafe {
         let input_data = (0..num_seqs)
             .map(|i| {
-                std::str::from_utf8(std::slice::from_raw_parts(*input_cstr.offset(i as isize), *input_len.offset(i as isize))).unwrap()
+                std::str::from_utf8(std::slice::from_raw_parts(
+                    *input_cstr.offset(i as isize),
+                    *input_len.offset(i as isize),
+                ))
+                .unwrap()
+            })
+            .collect::<Vec<&str>>();
+        let (encoded_batch, _encoded_masks) =
+            (*handle).encode_batch_with_mask(input_data, add_special_tokens != 0);
+        for (i, encoded) in encoded_batch.into_iter().enumerate() {
+            let len = encoded.len();
+            let result = TokenizerEncodeResult {
+                token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
+                len: len,
+            };
+            *out_result.offset(i as isize) = result;
+        }
+    }
+}
+
+#[no_mangle]
+extern "C" fn tokenizers_encode_batch_with_mask(
+    handle: *mut TokenizerWrapper,
+    input_cstr: *const *const u8,
+    input_len: *const usize,
+    num_seqs: usize,
+    add_special_tokens: i32,
+    out_result: *mut TokenizerEncodeResult,
+    out_mask: *mut TokenizerEncodeResult,
+) {
+    unsafe {
+        let input_data = (0..num_seqs)
+            .map(|i| {
+                std::str::from_utf8(std::slice::from_raw_parts(
+                    *input_cstr.offset(i as isize),
+                    *input_len.offset(i as isize),
+                ))
+                .unwrap()
             })
             .collect::<Vec<&str>>();
-        let encoded_batch = (*handle).encode_batch(input_data, add_special_tokens != 0);
+        let (encoded_batch, encoded_mask) =
+            (*handle).encode_batch_with_mask(input_data, add_special_tokens != 0);
+
         for (i, encoded) in encoded_batch.into_iter().enumerate() {
             let len = encoded.len();
             let result = TokenizerEncodeResult {
@@ -182,6 +233,14 @@ extern "C" fn tokenizers_encode_batch(
             };
             *out_result.offset(i as isize) = result;
         }
+        for (i, encoded) in encoded_mask.into_iter().enumerate() {
+            let len = encoded.len();
+            let result = TokenizerEncodeResult {
+                token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
+                len: len,
+            };
+            *out_mask.offset(i as isize) = result;
+        }
     }
 }
@@ -190,7 +249,10 @@ extern "C" fn tokenizers_free_encode_results(results: *mut TokenizerEncodeResult, num_seqs: usize) {
     unsafe {
         let slice = std::slice::from_raw_parts_mut(results, num_seqs);
         for result in &mut *slice {
-            drop(Box::from_raw(std::slice::from_raw_parts_mut(result.token_ids, result.len)));
+            drop(Box::from_raw(std::slice::from_raw_parts_mut(
+                result.token_ids,
+                result.len,
+            )));
         }
     }
 }
diff --git a/src/huggingface_tokenizer.cc b/src/huggingface_tokenizer.cc
index 6cbe0d8..0353665 100644
--- a/src/huggingface_tokenizer.cc
+++ b/src/huggingface_tokenizer.cc
@@ -13,110 +13,156 @@ namespace tokenizers {
 /*!
  * \brief A simple c++ header of tokenizer via C API.
  */
-class HFTokenizer : public Tokenizer {
- public:
-  explicit HFTokenizer(TokenizerHandle handle) : handle_(handle) {
-#ifdef COMPILE_WASM_RUNTIME
-    setenv("TOKENIZERS_PARALLELISM", "false", true);
-#endif
-  }
 
-  HFTokenizer(const HFTokenizer&) = delete;
-  HFTokenizer(HFTokenizer&& other) { std::swap(other.handle_, handle_); }
+/*
+These are the methods for the HFTokenizer class.
+*/
 
-  ~HFTokenizer() {
-    if (handle_ != nullptr) {
-      tokenizers_free(handle_);
-    }
-  }
+HFTokenizer::HFTokenizer(TokenizerHandle handle) : handle_(handle) {
+#ifdef COMPILE_WASM_RUNTIME
+  setenv("TOKENIZERS_PARALLELISM", "false", true);
+#endif
+}
 
-  // use i32 to be consistent with sentencepiece
-  std::vector<int32_t> Encode(const std::string& text, bool add_special_tokens) {
-    TokenizerEncodeResult result;
-    tokenizers_encode(handle_, text.data(), text.length(), static_cast<int>(add_special_tokens),
-                      &result);
-    std::vector<int32_t> ret(result.token_ids, result.token_ids + result.len);
-    tokenizers_free_encode_results(&result, 1);
-    return ret;
-  }
+// HFTokenizer::HFTokenizer(const HFTokenizer&) = delete;
+HFTokenizer::HFTokenizer(HFTokenizer&& other) { std::swap(other.handle_, handle_); }
 
-  // use i32 to be consistent with sentencepiece
-  std::vector<int32_t> Encode(const std::string& text) final { return Encode(text, false); }
-
-  std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts,
-                                                bool add_special_tokens) {
-    std::vector<const char*> texts_raw;
-    std::vector<size_t> seq_lens;
-    size_t num_seqs = texts.size();
-    texts_raw.reserve(num_seqs);
-    seq_lens.reserve(num_seqs);
-    for (const auto& text : texts) {
-      texts_raw.push_back(text.data());
-      seq_lens.push_back(text.length());
-    }
-    std::vector<TokenizerEncodeResult> results(num_seqs);
-    tokenizers_encode_batch(handle_, texts_raw.data(), seq_lens.data(), texts.size(),
-                            static_cast<int>(add_special_tokens), results.data());
-    std::vector<std::vector<int32_t>> ret;
-    ret.reserve(texts.size());
-    for (size_t i = 0; i < texts.size(); ++i) {
-      ret.push_back(
-          std::vector<int32_t>(results[i].token_ids, results[i].token_ids + results[i].len));
-    }
-    tokenizers_free_encode_results(results.data(), texts.size());
-    return ret;
+HFTokenizer::~HFTokenizer() {
+  if (handle_ != nullptr) {
+    tokenizers_free(handle_);
   }
+}
 
-  std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts) final {
-    return EncodeBatch(texts, false);
+// use i32 to be consistent with sentencepiece
+std::vector<int32_t> HFTokenizer::Encode(const std::string& text, bool add_special_tokens) {
+  TokenizerEncodeResult result;
+  tokenizers_encode(handle_, text.data(), text.length(), static_cast<int>(add_special_tokens),
+                    &result);
+  std::vector<int32_t> ret(result.token_ids, result.token_ids + result.len);
+  tokenizers_free_encode_results(&result, 1);
+  return ret;
+}
+
+// use i32 to be consistent with sentencepiece
+std::vector<int32_t> HFTokenizer::Encode(const std::string& text) { return Encode(text, false); }
+
+std::vector<std::vector<int32_t>> HFTokenizer::EncodeBatch(const std::vector<std::string>& texts,
+                                                           bool add_special_tokens) {
+  std::vector<const char*> texts_raw;
+  std::vector<size_t> seq_lens;
+  size_t num_seqs = texts.size();
+  texts_raw.reserve(num_seqs);
+  seq_lens.reserve(num_seqs);
+  for (const auto& text : texts) {
+    texts_raw.push_back(text.data());
+    seq_lens.push_back(text.length());
   }
+  std::vector<TokenizerEncodeResult> results(num_seqs);
+  tokenizers_encode_batch(handle_, texts_raw.data(), seq_lens.data(), texts.size(),
+                          static_cast<int>(add_special_tokens), results.data());
+  std::vector<std::vector<int32_t>> ret;
+  ret.reserve(texts.size());
+  for (size_t i = 0; i < texts.size(); ++i) {
+    ret.push_back(
+        std::vector<int32_t>(results[i].token_ids, results[i].token_ids + results[i].len));
+  }
+  tokenizers_free_encode_results(results.data(), texts.size());
+  return ret;
+}
 
-  // use i32 to be consistent with sentencepiece
-  std::string Decode(const std::vector<int32_t>& ids, bool skip_special_tokens) {
-    tokenizers_decode(handle_, reinterpret_cast<const uint32_t*>(ids.data()), ids.size(),
-                      static_cast<int>(skip_special_tokens));
-    const char* data;
-    size_t len;
-    tokenizers_get_decode_str(handle_, &data, &len);
-    return std::string(data, len);
+std::tuple<std::vector<std::vector<int32_t>>, std::vector<std::vector<int32_t>>>
+HFTokenizer::EncodeBatchWithMask(const std::vector<std::string>& texts, bool add_special_tokens) {
+  std::vector<const char*> texts_raw;
+  std::vector<size_t> seq_lens;
+  size_t num_seqs = texts.size();
+  texts_raw.reserve(num_seqs);
+  seq_lens.reserve(num_seqs);
+  for (const auto& text : texts) {
+    texts_raw.push_back(text.data());
+    seq_lens.push_back(text.length());
   }
+  std::vector<TokenizerEncodeResult> results(num_seqs);
+  std::vector<TokenizerEncodeResult> masks(num_seqs);
+  tokenizers_encode_batch_with_mask(handle_, texts_raw.data(), seq_lens.data(), texts.size(),
+                                    static_cast<int>(add_special_tokens), results.data(),
+                                    masks.data());
+  // process the tokens:
+  std::vector<std::vector<int32_t>> ret_tokens;
+  ret_tokens.reserve(texts.size());
+  for (size_t i = 0; i < texts.size(); ++i) {
+    ret_tokens.push_back(
+        std::vector<int32_t>(results[i].token_ids, results[i].token_ids + results[i].len));
+  }
+  tokenizers_free_encode_results(results.data(), texts.size());
+  // process the masks:
+  std::vector<std::vector<int32_t>> ret_masks;
+  ret_masks.reserve(texts.size());
+  for (size_t i = 0; i < texts.size(); ++i) {
+    ret_masks.push_back(
+        std::vector<int32_t>(masks[i].token_ids, masks[i].token_ids + masks[i].len));
+  }
+  tokenizers_free_encode_results(masks.data(), texts.size());
+  return std::make_tuple(ret_tokens, ret_masks);
+}
 
-  std::string Decode(const std::vector<int32_t>& ids) final { return Decode(ids, false); }
+std::vector<std::vector<int32_t>> HFTokenizer::EncodeBatch(const std::vector<std::string>& texts) {
+  return EncodeBatch(texts, false);
+}
 
-  size_t GetVocabSize() final {
-    size_t size;
-    tokenizers_get_vocab_size(handle_, &size);
-    assert(size > 0);
-    return size;
-  }
+// use i32 to be consistent with sentencepiece
+std::string HFTokenizer::Decode(const std::vector<int32_t>& ids, bool skip_special_tokens) {
+  tokenizers_decode(handle_, reinterpret_cast<const uint32_t*>(ids.data()), ids.size(),
+                    static_cast<int>(skip_special_tokens));
+  const char* data;
+  size_t len;
+  tokenizers_get_decode_str(handle_, &data, &len);
+  return std::string(data, len);
+}
 
-  std::string IdToToken(int32_t id) final {
-    const char* data;
-    size_t len;
-    tokenizers_id_to_token(handle_, static_cast<uint32_t>(id), &data, &len);
-    return std::string(data, len);
-  }
+std::string HFTokenizer::Decode(const std::vector<int32_t>& ids) { return Decode(ids, false); }
 
-  int32_t TokenToId(const std::string& token) final {
-    int32_t id;
-    tokenizers_token_to_id(handle_, token.data(), token.length(), &id);
-    return id;
-  }
+size_t HFTokenizer::GetVocabSize() {
+  size_t size;
+  tokenizers_get_vocab_size(handle_, &size);
+  assert(size > 0);
+  return size;
+}
 
- private:
-  // internal handle
-  TokenizerHandle handle_{nullptr};
-};
+std::string HFTokenizer::IdToToken(int32_t id) {
+  const char* data;
+  size_t len;
+  tokenizers_id_to_token(handle_, static_cast<uint32_t>(id), &data, &len);
+  return std::string(data, len);
+}
 
-std::unique_ptr<Tokenizer> Tokenizer::FromBlobJSON(const std::string& json) {
+int32_t HFTokenizer::TokenToId(const std::string& token) {
+  int32_t id;
+  tokenizers_token_to_id(handle_, token.data(), token.length(), &id);
+  return id;
+}
+
+// Factory functions: the HFTokenizer-specific factories come first; the
+// Tokenizer base-class factories below forward to them.
+
+std::unique_ptr<HFTokenizer> HFTokenizer::FromBlobJSON(const std::string& json) {
   return std::make_unique<HFTokenizer>(tokenizers_new_from_str(json.data(), json.length()));
 }
 
-std::unique_ptr<Tokenizer> Tokenizer::FromBlobByteLevelBPE(const std::string& vocab,
-                                                           const std::string& merges,
-                                                           const std::string& added_tokens) {
+std::unique_ptr<Tokenizer> Tokenizer::FromBlobJSON(const std::string& json) {
+  return HFTokenizer::FromBlobJSON(json);
+}
+
+std::unique_ptr<HFTokenizer> HFTokenizer::FromBlobByteLevelBPE(const std::string& vocab,
+                                                               const std::string& merges,
+                                                               const std::string& added_tokens) {
   return std::make_unique<HFTokenizer>(byte_level_bpe_tokenizers_new_from_str(
       vocab.data(), vocab.length(), merges.data(), merges.length(), added_tokens.data(),
       added_tokens.length()));
 }
+
+std::unique_ptr<Tokenizer> Tokenizer::FromBlobByteLevelBPE(const std::string& vocab,
+                                                           const std::string& merges,
+                                                           const std::string& added_tokens) {
+  return HFTokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens);
+}
+
 }  // namespace tokenizers
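
For reference, below is a minimal sketch of how the EncodeBatchWithMask API introduced by this patch might be called from application code. It is illustrative only and not part of the patch; the tokenizer.json path and the sample texts are placeholders.

// Illustrative usage of HFTokenizer::EncodeBatchWithMask (not part of the patch).
#include <tokenizers_cpp.h>

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  // Read a HuggingFace tokenizer.json blob from disk; the path is a placeholder.
  std::ifstream fs("tokenizer.json");
  std::stringstream blob;
  blob << fs.rdbuf();

  auto tok = tokenizers::HFTokenizer::FromBlobJSON(blob.str());

  std::vector<std::string> texts = {"hello world", "a longer example sentence"};
  // Token ids and attention masks come back as parallel vectors, one per input text.
  auto [ids, masks] = tok->EncodeBatchWithMask(texts, /*add_special_tokens=*/true);

  for (size_t i = 0; i < ids.size(); ++i) {
    std::cout << "seq " << i << ": " << ids[i].size() << " tokens, "
              << masks[i].size() << " mask values\n";
  }
  return 0;
}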