From 791c74cc7d128420e8f12dec71991b3f6120eb76 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 23 Jan 2024 22:19:13 +0000 Subject: [PATCH] more sensible cmd line parsing --- Cargo.lock | 13 ++-- aicirt/src/bintokens.rs | 18 ++++-- cpp-rllm/Cargo.toml | 1 + cpp-rllm/src/cpp-rllm.rs | 19 +++++- rllm/src/driver.rs | 14 ++++- rllm/src/server/mod.rs | 133 ++++++++++++++++++++------------------- rllm/src/util.rs | 26 ++++++-- 7 files changed, 141 insertions(+), 83 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1b6e1d2e..fda08fba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -373,9 +373,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.5" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d664a92ecae85fd0a7392615844904654d1d5f5514837f471ddef4a057aba1b6" +checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" dependencies = [ "anstyle", "anstyle-parse", @@ -721,9 +721,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.4.12" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcfab8ba68f3668e89f6ff60f5b205cea56aa7b769451a59f34b8682f51c056d" +checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" dependencies = [ "clap_builder", "clap_derive", @@ -731,9 +731,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.12" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb7fb5e4e979aec3be7791562fcba452f94ad85e954da024396433e0e25a79e9" +checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" dependencies = [ "anstream", "anstyle", @@ -851,6 +851,7 @@ name = "cpp-rllm" version = "0.1.0" dependencies = [ "actix-web", + "clap", "llama_cpp_low", "rllm", ] diff --git a/aicirt/src/bintokens.rs b/aicirt/src/bintokens.rs index 711560fc..72223a83 100644 --- a/aicirt/src/bintokens.rs +++ b/aicirt/src/bintokens.rs @@ -88,11 +88,22 @@ pub fn tokenizers() -> Vec { ), tok!("falcon", "used by Falcon 7b, 40b, etc."), tok!("mpt", "MPT"), - tok!("phi", "Phi 1.5"), + tok!("phi", "Phi 1.5 and Phi 2"), tok!("gpt2", "GPT-2"), ] } +pub fn list_tokenizers() -> String { + format!( + "Available tokenizers for -t or --tokenizer:\n{}", + tokenizers() + .iter() + .map(|t| format!(" -t {:16} {}", t.name, t.description)) + .collect::>() + .join("\n") + ) +} + pub fn find_tokenizer(name: &str) -> Result { for mut t in tokenizers() { if t.name == name { @@ -102,10 +113,7 @@ pub fn find_tokenizer(name: &str) -> Result { } println!("unknown tokenizer: {}", name); - println!("available tokenizers:"); - for t in tokenizers() { - println!(" {:20} {}", t.name, t.description); - } + println!("{}", list_tokenizers()); return Err(anyhow!("unknown tokenizer: {}", name)); } diff --git a/cpp-rllm/Cargo.toml b/cpp-rllm/Cargo.toml index a9708a45..d9199600 100644 --- a/cpp-rllm/Cargo.toml +++ b/cpp-rllm/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] actix-web = "4.4.0" +clap = { version = "4.4.18", features = ["derive"] } llama_cpp_low = { path = "../llama-cpp-low" } rllm = { path = "../rllm", default-features = false, features = ["llamacpp"] } diff --git a/cpp-rllm/src/cpp-rllm.rs b/cpp-rllm/src/cpp-rllm.rs index 2538dddd..f56d9b51 100644 --- a/cpp-rllm/src/cpp-rllm.rs +++ b/cpp-rllm/src/cpp-rllm.rs @@ -1,4 +1,21 @@ +use clap::Parser; +use rllm::util::parse_with_settings; + +/// Serve LLMs with AICI over HTTP with llama.cpp backend. +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +pub struct CppArgs { + #[clap(flatten)] + pub args: rllm::server::RllmCliArgs, + + /// Name of .gguf file inside of the model folder/repo. + #[arg(long, help_heading = "Model")] + pub gguf: Option, +} + #[actix_web::main] async fn main() -> () { - rllm::server::server_main().await; + let mut args = parse_with_settings::(); + args.args.gguf = args.gguf; + rllm::server::server_main(args.args).await; } diff --git a/rllm/src/driver.rs b/rllm/src/driver.rs index 2538dddd..25d85c57 100644 --- a/rllm/src/driver.rs +++ b/rllm/src/driver.rs @@ -1,4 +1,16 @@ +use clap::Parser; +use rllm::util::parse_with_settings; + +/// Serve LLMs with AICI over HTTP with tch (torch) backend. +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +pub struct DriverArgs { + #[clap(flatten)] + pub args: rllm::server::RllmCliArgs, +} + #[actix_web::main] async fn main() -> () { - rllm::server::server_main().await; + let args = parse_with_settings::(); + rllm::server::server_main(args.args).await; } diff --git a/rllm/src/server/mod.rs b/rllm/src/server/mod.rs index 79f10949..1218341a 100644 --- a/rllm/src/server/mod.rs +++ b/rllm/src/server/mod.rs @@ -1,3 +1,10 @@ +use crate::{ + config::{ModelConfig, SamplingParams}, + iface::{kill_self, AiciRtIface, AsyncCmdChannel}, + seq::RequestOutput, + util::apply_settings, + AddRequest, DType, HashMap, LoaderArgs, RllmEngine, +}; use actix_web::{middleware::Logger, web, App, HttpServer}; use aici_abi::toktree::TokTrie; use aicirt::{ @@ -6,15 +13,8 @@ use aicirt::{ }; use anyhow::Result; use base64::Engine; -use clap::Parser; +use clap::Args; use openai::responses::APIError; -use crate::{ - config::{ModelConfig, SamplingParams}, - iface::{kill_self, AiciRtIface, AsyncCmdChannel}, - seq::RequestOutput, - util::apply_settings, - AddRequest, DType, HashMap, LoaderArgs, RllmEngine, -}; use std::{ fmt::Display, sync::{Arc, Mutex}, @@ -54,85 +54,85 @@ pub struct OpenAIServerData { pub stats: Arc>, } -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -pub struct Args { - /// Port to serve on (localhost:port) - #[arg(long, default_value_t = 8080)] - port: u16, - - /// Set verbose mode (print all requests) - #[arg(long, default_value_t = false)] - verbose: bool, +#[derive(Args, Debug)] +pub struct RllmCliArgs { + /// Set engine setting (see below or in --help for list) + #[arg(long, short, name = "NAME=VALUE")] + pub setting: Vec, - /// HuggingFace model name; can be also path starting with "./" - #[arg(short, long)] - model: String, + /// HuggingFace model name, URL or path starting with "./" + #[arg(short, long, help_heading = "Model")] + pub model: String, /// HuggingFace model revision; --model foo/bar@revision is also possible - #[arg(long)] - revision: Option, + #[arg(long, help_heading = "Model")] + pub revision: Option, /// The folder name that contains safetensor weights and json files /// (same structure as HuggingFace online) - #[arg(long)] - local_weights: Option, + #[arg(long, help_heading = "Model")] + pub local_weights: Option, + + /// Tokenizer to use (see below or in --help for list) + #[arg(short, long, default_value = "llama", help_heading = "Model")] + pub tokenizer: String, + + /// Specify which type to use in the model (bf16, f16, f32) + #[arg(long, default_value = "", help_heading = "Model")] + pub dtype: String, + + /// Port to serve on (localhost:port) + #[arg(long, default_value_t = 8080, help_heading = "Server")] + pub port: u16, - /// Name of .gguf file inside of the model folder/repo. - #[arg(long)] - gguf: Option, + /// Set verbose mode (print all requests) + #[arg(long, default_value_t = false, help_heading = "Server")] + pub verbose: bool, - /// Tokenizer to use; try --tokenizer list to see options - #[arg(short, long, default_value = "llama")] - tokenizer: String, + /// Enable daemon mode (log timestamps) + #[arg(long, default_value_t = false, help_heading = "Server")] + pub daemon: bool, /// Path to the aicirt binary. - #[arg(long)] - aicirt: String, + #[arg(long, help_heading = "AICI settings")] + pub aicirt: String, /// Size of JSON comm buffer in megabytes - #[arg(long, default_value = "128")] - json_size: usize, + #[arg(long, default_value = "128", help_heading = "AICI settings")] + pub json_size: usize, /// Size of binary comm buffer in megabytes - #[arg(long, default_value = "32")] - bin_size: usize, + #[arg(long, default_value = "32", help_heading = "AICI settings")] + pub bin_size: usize, /// How many milliseconds to spin-wait for a message over IPC and SHM. - #[arg(long, default_value = "200")] - busy_wait_time: u64, + #[arg(long, default_value = "200", help_heading = "AICI settings")] + pub busy_wait_time: u64, /// Shm/semaphore name prefix - #[arg(long, default_value = "/aici0-")] - shm_prefix: String, + #[arg(long, default_value = "/aici0-", help_heading = "AICI settings")] + pub shm_prefix: String, /// Enable nvprof profiling for given engine step - #[arg(long, default_value_t = 0)] - profile_step: usize, - - /// Specify which type to use in the model (bf16, f16, f32) - #[arg(long, default_value = "")] - dtype: String, + #[cfg(feature = "cuda")] + #[arg(long, default_value_t = 0, help_heading = "Development")] + pub profile_step: usize, /// Specify test-cases (expected/*/*.safetensors) - #[arg(long)] - test: Vec, + #[arg(long, help_heading = "Development")] + pub test: Vec, /// Specify warm-up request (expected/*/*.safetensors or "off") - #[arg(long, short)] - warmup: Option, + #[arg(long, short, help_heading = "Development")] + pub warmup: Option, /// Exit after processing warmup request - #[arg(long, default_value_t = false)] - warmup_only: bool, + #[arg(long, default_value_t = false, help_heading = "Development")] + pub warmup_only: bool, - /// Set engine setting; try '--setting help' to list them - #[arg(long, short, name = "NAME=VALUE")] - setting: Vec, - - /// Enable daemon mode (log timestamps) - #[arg(long, default_value_t = false)] - daemon: bool, + // these are copied from command-specific parsers + #[arg(skip)] + pub gguf: Option, } #[actix_web::get("/v1/aici_modules/tags")] @@ -344,12 +344,12 @@ fn inference_loop( } #[cfg(not(feature = "tch"))] -fn run_tests(_args: &Args, _loader_args: LoaderArgs) { +fn run_tests(_args: &RllmCliArgs, _loader_args: LoaderArgs) { panic!("tests not supported without tch feature") } #[cfg(feature = "tch")] -fn run_tests(args: &Args, loader_args: LoaderArgs) { +fn run_tests(args: &RllmCliArgs, loader_args: LoaderArgs) { let mut engine = RllmEngine::load(loader_args).expect("failed to load model"); let mut tests = args.test.clone(); @@ -384,7 +384,7 @@ fn run_tests(args: &Args, loader_args: LoaderArgs) { } fn spawn_inference_loop( - args: &Args, + args: &RllmCliArgs, loader_args: LoaderArgs, iface: AiciRtIface, stats: Arc>, @@ -394,7 +394,10 @@ fn spawn_inference_loop( let handle = handle_res.clone(); // prep for move + #[cfg(feature = "cuda")] let profile_step = args.profile_step; + #[cfg(not(feature = "cuda"))] + let profile_step = 0; let warmup = args.warmup.clone(); let warmup_only = args.warmup_only.clone(); @@ -452,9 +455,7 @@ fn url_decode(encoded_str: &str) -> String { } // #[actix_web::main] -pub async fn server_main() -> () { - let mut args = Args::parse(); - +pub async fn server_main(mut args: RllmCliArgs) -> () { aicirt::init_log(if args.daemon { aicirt::LogMode::Deamon } else { diff --git a/rllm/src/util.rs b/rllm/src/util.rs index a04ec475..99053e7d 100644 --- a/rllm/src/util.rs +++ b/rllm/src/util.rs @@ -1,5 +1,7 @@ use crate::HashMap; +use aicirt::bintokens::list_tokenizers; use anyhow::{bail, Result}; +use clap::{Args, Command, Parser}; use std::time::Instant; const SETTINGS: [(&'static str, &'static str, f64); 4] = [ @@ -16,9 +18,12 @@ lazy_static::lazy_static! { } pub fn all_settings() -> String { - SETTINGS - .map(|(k, d, v)| format!("{}: {} (default={})", k, d, v)) - .join("\n") + format!( + "Settings available via -s or --setting (with their default values):\n{all}\n", + all = SETTINGS + .map(|(k, d, v)| format!(" -s {:20} {}", format!("{}={}", k, v), d)) + .join("\n") + ) } pub fn set_setting(name: &str, val: f64) -> Result<()> { @@ -56,7 +61,7 @@ pub fn apply_settings(settings: &Vec) -> Result<()> { Ok(_) => {} Err(e) => { bail!( - "all settings:\n{all}\nfailed to set setting {s}: {e}", + "{all}\nfailed to set setting {s}: {e}", all = all_settings() ); } @@ -65,6 +70,19 @@ pub fn apply_settings(settings: &Vec) -> Result<()> { Ok(()) } +pub fn parse_with_settings() -> T +where + T: Parser + Args, +{ + let cli = + Command::new("CLI").after_help(format!("\n{}\n{}", all_settings(), list_tokenizers())); + let cli = T::augment_args(cli); + let matches = cli.get_matches(); + T::from_arg_matches(&matches) + .map_err(|err| err.exit()) + .unwrap() +} + pub fn limit_str(s: &str, max_len: usize) -> String { limit_bytes(s.as_bytes(), max_len) }