Skip to content

Commit

Permalink
guess tokenizer and aicirt
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Jan 23, 2024
1 parent 791c74c commit 5e3493a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
12 changes: 12 additions & 0 deletions aicirt/src/bintokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ pub fn list_tokenizers() -> String {
)
}

pub fn guess_tokenizer(model_name: &str) -> Option<String> {
let m = model_name.to_lowercase();
if m.contains("codellama-13b") {
Some("llama16".to_string())
} else {
tokenizers()
.iter()
.find(|t| m.contains(&t.name))
.map(|t| t.name.clone())
}
}

pub fn find_tokenizer(name: &str) -> Result<Tokenizer> {
for mut t in tokenizers() {
if t.name == name {
Expand Down
48 changes: 41 additions & 7 deletions rllm/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use actix_web::{middleware::Logger, web, App, HttpServer};
use aici_abi::toktree::TokTrie;
use aicirt::{
api::{AuthInfo, GetTagsResp, MkModuleReq, MkModuleResp, SetTagsReq},
bintokens::{guess_tokenizer, list_tokenizers},
set_max_priority,
};
use anyhow::Result;
use anyhow::{bail, Result};
use base64::Engine;
use clap::Args;
use openai::responses::APIError;
Expand Down Expand Up @@ -74,8 +75,8 @@ pub struct RllmCliArgs {
pub local_weights: Option<String>,

/// Tokenizer to use (see below or in --help for list)
#[arg(short, long, default_value = "llama", help_heading = "Model")]
pub tokenizer: String,
#[arg(short, long, help_heading = "Model")]
pub tokenizer: Option<String>,

/// Specify which type to use in the model (bf16, f16, f32)
#[arg(long, default_value = "", help_heading = "Model")]
Expand All @@ -95,7 +96,7 @@ pub struct RllmCliArgs {

/// Path to the aicirt binary.
#[arg(long, help_heading = "AICI settings")]
pub aicirt: String,
pub aicirt: Option<String>,

/// Size of JSON comm buffer in megabytes
#[arg(long, default_value = "128", help_heading = "AICI settings")]
Expand Down Expand Up @@ -454,6 +455,17 @@ fn url_decode(encoded_str: &str) -> String {
.to_string()
}

fn guess_aicirt() -> Result<String> {
let mut path = std::env::current_exe()?;
path.pop();
path.push("aicirt");
if path.to_str().is_some() && path.exists() {
Ok(path.to_str().unwrap().to_string())
} else {
bail!("can't find aicirt binary (tried {:?})", path)
}
}

// #[actix_web::main]
pub async fn server_main(mut args: RllmCliArgs) -> () {
aicirt::init_log(if args.daemon {
Expand Down Expand Up @@ -511,7 +523,6 @@ pub async fn server_main(mut args: RllmCliArgs) -> () {
loader_args.model_id = args.model.clone();
loader_args.revision = args.revision.clone();
loader_args.local_weights = args.local_weights.clone();
loader_args.tokenizer = args.tokenizer.clone();
loader_args.gguf = args.gguf.clone();
if dtype.is_some() {
loader_args.dtype = dtype;
Expand All @@ -522,6 +533,18 @@ pub async fn server_main(mut args: RllmCliArgs) -> () {
return;
}

match &args.tokenizer {
Some(v) => loader_args.tokenizer = v.clone(),
None => match guess_tokenizer(&loader_args.model_id) {
Some(v) => loader_args.tokenizer = v,
None => {
eprintln!("can't guess tokenizer from {}", loader_args.model_id);
eprintln!("{}", list_tokenizers());
std::process::exit(10);
}
},
}

let (tokenizer, tok_trie) =
RllmEngine::load_tokenizer(&mut loader_args).expect("failed to load tokenizer");

Expand All @@ -530,9 +553,20 @@ pub async fn server_main(mut args: RllmCliArgs) -> () {
let model_config =
RllmEngine::load_model_config(&mut loader_args).expect("failed to load model config");

let aicirt = match &args.aicirt {
Some(v) => v.clone(),
None => match guess_aicirt() {
Ok(v) => v,
Err(e) => {
eprintln!("can't find aicirt; specify with --aicirt=PATH\n{e}");
std::process::exit(10);
}
},
};

let rt_args = crate::iface::Args {
aicirt: args.aicirt.clone(),
tokenizer: args.tokenizer.clone(),
aicirt,
tokenizer: loader_args.tokenizer.clone(),
json_size: args.json_size,
bin_size: args.bin_size,
shm_prefix: args.shm_prefix.clone(),
Expand Down

0 comments on commit 5e3493a

Please sign in to comment.