Skip to content

Commit

Permalink
Add tokenizer and convert model
Browse files Browse the repository at this point in the history
Signed-off-by: Aisuko <[email protected]>
  • Loading branch information
Aisuko committed Oct 31, 2023
1 parent 390d41d commit 4c7f5ca
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 2 deletions.
61 changes: 61 additions & 0 deletions backend/rust/backend-burn/src/pkg/convert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use crate::models::{Llama, LlamaConfig};

use burn::{
backend::tch::{TchBackend, TchDevice},
config::Config,
module::Module,
tensor::backend::Backend,
};

use burn::record::{BinFileRecorder, HalfPrecisionSettings, Recorder, RecorderError};

use super::Loader;

pub struct Convertion {}

impl Convertion {
pub fn convert_llama_dump_to_model<B: Backend>(
dump_path: &str,
model_name: &str,
device: &B::Device,
) -> Result<(), Box<dyn std::error::Error>> {
let (llama, llama_conifg): (Llama<B>, LlamaConfig) =
Loader::load_llama_dmp(dump_path, device)?;
Convertion::save_llama_model_file(llama, model_name)?;
llama_conifg.save(&format!("{model_name}.cfg"))?;
Ok(())
}

pub fn save_llama_model_file<B: Backend>(
llama: Llama<B>,
name: &str,
) -> Result<(), RecorderError> {
BinFileRecorder::<HalfPrecisionSettings>::new().record(llama.into_record(), name.into())
}
}

#[cfg(test)]
mod tests {

use super::*;

#[test]
fn test_convertion() {
type Backend = TchBackend<f32>;
let device = TchDevice::Mps;

// get home env
let home = std::env::var("HOME").unwrap();

let dump_path = &format!("{}/Downloads/workspace/llama/tokenizer.model", home);
let model_name = "llama2-7b-chat";

let option =
Convertion::convert_llama_dump_to_model::<Backend>(dump_path, model_name, &device);

match option {
Ok(_) => println!("ok"),
Err(e) => println!("error: {}", e),
}
}
}
2 changes: 0 additions & 2 deletions backend/rust/backend-burn/src/pkg/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
//! Adapted by Aisuko

use core::f32;
use npyz::NpyFile;
use num_traits::cast::ToPrimitive;
use std::io::Read;

use burn::{
nn,
Expand Down
4 changes: 4 additions & 0 deletions backend/rust/backend-burn/src/pkg/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
pub mod convert;
pub mod loader;
pub mod tokenizer;

pub use convert::*;
pub use loader::*;
pub use tokenizer::*;
91 changes: 91 additions & 0 deletions backend/rust/backend-burn/src/pkg/tokenizer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//! LLM tokenization tools crate.
//! Adapted from the https://github.com/Gadersd/llama2-burn/blob/main/src/token.rs
//! Adapted by Aisuko

use rust_tokenizers::{
error::TokenizerError,
tokenizer::{SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy},
vocab::Vocab,
};

use std::{result, vec};

const BOS_TOKEN_ID: i64 = 1;
const EOS_TOKEN_ID: i64 = 2;

pub type Result<T> = result::Result<T, TokenizerError>;

pub struct LlamaTokenizer {
spm: SentencePieceBpeTokenizer,
}

impl LlamaTokenizer {
pub fn new(tokenizer_path: &str) -> Result<Self> {
let lower_case = false;
SentencePieceBpeTokenizer::from_file(tokenizer_path, lower_case).map(|spm| Self { spm })
}

pub fn encode(&self, text: &str, inlcude_bos: bool, include_eos: bool) -> Vec<i64> {
let pre = if inlcude_bos {
vec![BOS_TOKEN_ID]
} else {
vec![]
};

let post = if include_eos {
vec![EOS_TOKEN_ID]
} else {
vec![]
};

let token_ids = self
.spm
.encode(
text,
None,
std::usize::MAX,
&TruncationStrategy::LongestFirst,
0,
)
.token_ids;

[pre, token_ids, post]
.into_iter()
.flat_map(|v| v.into_iter())
.collect()
}

pub fn decode(&self, tokens: &[i64], skip_special_tokens: bool) -> String {
let clean_spaces = false;
self.spm.decode(tokens, skip_special_tokens, clean_spaces)
}

pub fn vocab_size(&self, include_special_tokens: bool) -> usize {
let vocab = self.spm.vocab();
if include_special_tokens {
vocab.values().len() + vocab.special_values().len()
} else {
vocab.values().len()
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_tokenizer() {
let home = std::env::var("HOME").unwrap();
let tm_path = &format!("{}/Downloads/workspace/llama/tokenizer.model", home);
let tokenizer = LlamaTokenizer::new(tm_path).unwrap();
// tokenizer.vocab_size(fale) should be >0
assert!(tokenizer.vocab_size(false) > 0);

let test_str = "Hello, I am Llama2!";
let encoded = tokenizer.encode(test_str, true, true);
let decoded = tokenizer.decode(&encoded, false);

assert_eq!(test_str, decoded);
}
}

0 comments on commit 4c7f5ca

Please sign in to comment.