Skip to content

Commit

Permalink
fix tokenizer setting in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Feb 1, 2024
1 parent e76ebed commit 3726eb7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
10 changes: 8 additions & 2 deletions rllm-cuda/src/llm/phi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,14 @@ impl TModelInner for MixFormerSequentialForCausalLM {

// it should approximately match...
let tok_size = self.config.tok_vocab_size as i64;
assert!(r.size()[1] >= tok_size);
assert!(r.size()[1] < tok_size + 1000);
if r.size()[1] < tok_size || r.size()[1] > tok_size + 1000 {
panic!(
"unexpected logits size: {:?} ({}/{})",
r.size(),
tok_size,
self.config.vocab_size
);
}

let r = r.i((.., 0..tok_size));
batch_info.extract_positions(&r)
Expand Down
15 changes: 9 additions & 6 deletions rllm-cuda/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,13 +601,11 @@ pub async fn server_main(mut args: RllmCliArgs) -> () {
loader_args.dtype = dtype;
}

if args.test.len() > 0 {
run_tests(&args, loader_args);
return;
}

match &args.tokenizer {
Some(v) => loader_args.tokenizer = v.clone(),
Some(v) => {
log::info!("explicit tokenizer: {}", v);
loader_args.tokenizer = v.clone();
}
None => match guess_tokenizer(&loader_args.model_id) {
Some(v) => {
log::info!("guessed tokenizer: {}", v);
Expand All @@ -621,6 +619,11 @@ pub async fn server_main(mut args: RllmCliArgs) -> () {
},
}

if args.test.len() > 0 {
run_tests(&args, loader_args);
return;
}

let (tokenizer, tok_trie) =
RllmEngine::load_tokenizer(&mut loader_args).expect("failed to load tokenizer");

Expand Down

0 comments on commit 3726eb7

Please sign in to comment.