From 3726eb7ffb14b82721678c1df43a1ef0d1b0c450 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Thu, 1 Feb 2024 19:42:05 +0000 Subject: [PATCH] fix tokenizer setting in tests --- rllm-cuda/src/llm/phi.rs | 10 ++++++++-- rllm-cuda/src/server/mod.rs | 15 +++++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/rllm-cuda/src/llm/phi.rs b/rllm-cuda/src/llm/phi.rs index b7daa67e..f5042c48 100644 --- a/rllm-cuda/src/llm/phi.rs +++ b/rllm-cuda/src/llm/phi.rs @@ -216,8 +216,14 @@ impl TModelInner for MixFormerSequentialForCausalLM { // it should approximately match... let tok_size = self.config.tok_vocab_size as i64; - assert!(r.size()[1] >= tok_size); - assert!(r.size()[1] < tok_size + 1000); + if r.size()[1] < tok_size || r.size()[1] > tok_size + 1000 { + panic!( + "unexpected logits size: {:?} ({}/{})", + r.size(), + tok_size, + self.config.vocab_size + ); + } let r = r.i((.., 0..tok_size)); batch_info.extract_positions(&r) diff --git a/rllm-cuda/src/server/mod.rs b/rllm-cuda/src/server/mod.rs index 366ee724..6d8cbdca 100644 --- a/rllm-cuda/src/server/mod.rs +++ b/rllm-cuda/src/server/mod.rs @@ -601,13 +601,11 @@ pub async fn server_main(mut args: RllmCliArgs) -> () { loader_args.dtype = dtype; } - if args.test.len() > 0 { - run_tests(&args, loader_args); - return; - } - match &args.tokenizer { - Some(v) => loader_args.tokenizer = v.clone(), + Some(v) => { + log::info!("explicit tokenizer: {}", v); + loader_args.tokenizer = v.clone(); + } None => match guess_tokenizer(&loader_args.model_id) { Some(v) => { log::info!("guessed tokenizer: {}", v); @@ -621,6 +619,11 @@ pub async fn server_main(mut args: RllmCliArgs) -> () { }, } + if args.test.len() > 0 { + run_tests(&args, loader_args); + return; + } + let (tokenizer, tok_trie) = RllmEngine::load_tokenizer(&mut loader_args).expect("failed to load tokenizer");