From e24240a9738c90e0118522489c093f8ae64ad32a Mon Sep 17 00:00:00 2001
From: Hugo Larcher
Date: Fri, 13 Sep 2024 13:14:33 +0200
Subject: [PATCH] feat: add prompt/decode sampling options

---
 src/benchmark.rs |   7 +-
 src/lib.rs       |  55 ++++++-------
 src/main.rs      |  92 ++++++++++++++------
 src/requests.rs  | 196 ++++++++++++++++++++++++++++++++---------------
 4 files changed, 237 insertions(+), 113 deletions(-)

diff --git a/src/benchmark.rs b/src/benchmark.rs
index 7b5df50..3567c61 100644
--- a/src/benchmark.rs
+++ b/src/benchmark.rs
@@ -6,7 +6,7 @@ use sysinfo::{CpuRefreshKind, MemoryRefreshKind, System};
 use tokio::fs;
 use tokio::sync::{broadcast, mpsc, Mutex};
 use tokio::sync::mpsc::{Receiver, Sender};
-use crate::requests::{TextGenerationBackend, TextRequestGenerator};
+use crate::requests::{TextGenerationBackend, TextRequestGenerator, TokenizeOptions};
 use crate::{executors, scheduler};
 use crate::results::{BenchmarkReport, BenchmarkResults};
 use crate::scheduler::{ExecutorType, SchedulerProgress};
@@ -67,9 +67,8 @@ pub struct BenchmarkConfig {
     pub warmup_duration: Duration,
     pub rate: Option<f64>,
     pub num_rates: u64,
-    pub prompt_length: u64,
-    pub prompt_variance: u64,
-    pub decode_length: u64,
+    pub prompt_options: Option<TokenizeOptions>,
+    pub decode_options: Option<TokenizeOptions>,
 }
 
 impl BenchmarkConfig {
diff --git a/src/lib.rs b/src/lib.rs
index 6b65d4d..c854e43 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -11,6 +11,7 @@ pub use crate::app::run_console;
 use crate::benchmark::{BenchmarkReportWriter, Event, MessageEvent};
 pub use crate::benchmark::{BenchmarkConfig, BenchmarkKind};
 use crate::requests::{OpenAITextGenerationBackend};
+pub use crate::requests::{TokenizeOptions};
 
 mod requests;
 mod executors;
@@ -22,45 +23,47 @@ mod app;
 mod event;
 mod flux;
 
-pub async fn run(url: String,
-                 tokenizer_name: String,
-                 prompt_length: u64,
-                 prompt_variance: u64,
-                 decode_length: u64,
-                 max_vus: u64,
-                 duration: std::time::Duration,
-                 rate: Option<f64>,
-                 num_rates: u64,
-                 benchmark_kind: String,
-                 prewarm_duration: std::time::Duration,
-                 interactive: bool,
+pub struct RunConfiguration {
+    pub url: String,
+    pub tokenizer_name: String,
+    pub max_vus: u64,
+    pub duration: std::time::Duration,
+    pub rate: Option<f64>,
+    pub num_rates: u64,
+    pub benchmark_kind: String,
+    pub warmup_duration: std::time::Duration,
+    pub interactive: bool,
+    pub prompt_options: Option<TokenizeOptions>,
+    pub decode_options: Option<TokenizeOptions>,
+}
+
+pub async fn run(run_config: RunConfiguration,
                  stop_sender: Sender<()>,
 ) -> anyhow::Result<()> {
     info!("Starting benchmark");
     // set process system limits
     sysinfo::set_open_files_limit(0);
     // let backend = OpenAITextGenerationBackend::new("".to_string(), "http://10.90.11.68:8000".to_string());
-    let backend = OpenAITextGenerationBackend::new("".to_string(), url, tokenizer_name.clone());
+    let backend = OpenAITextGenerationBackend::new("".to_string(), run_config.url.clone(), run_config.tokenizer_name.clone());
     let config = BenchmarkConfig {
-        max_vus,
-        duration,
-        benchmark_kind: match benchmark_kind.to_lowercase().as_str() {
+        max_vus: run_config.max_vus,
+        duration: run_config.duration,
+        benchmark_kind: match run_config.benchmark_kind.to_lowercase().as_str() {
            "throughput" => BenchmarkKind::Throughput,
            "sweep" => BenchmarkKind::Sweep,
            "rate" => BenchmarkKind::Rate,
            _ => BenchmarkKind::Sweep,
        },
-        warmup_duration: prewarm_duration,
-        rate,
-        num_rates,
-        prompt_length,
-        prompt_variance,
-        decode_length,
+        warmup_duration: run_config.warmup_duration,
+        rate: run_config.rate,
+        num_rates: run_config.num_rates,
+        prompt_options: run_config.prompt_options.clone(),
+        decode_options: run_config.decode_options.clone(),
     };
     config.validate()?;
     let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
-    if interactive {
+    if run_config.interactive {
         // send logs to file
         let target = Box::new(File::create("log.txt").expect("Can't create file"));
         env_logger::Builder::new()
@@ -90,7 +93,7 @@ pub async fn run(url: String,
             error!("Received stop signal, stopping benchmark");
         }
         _ = async{
-            if interactive {
+            if run_config.interactive {
                 run_console(config_clone, rx, stop_sender_clone).await;
             } else {
                 // consume the channel to avoid closed channel error
@@ -107,8 +110,8 @@ pub async fn run(url: String,
         timestamp: chrono::Utc::now(),
         level: Level::Info,
     }));
-    let filepath = requests::ShareGPTTextRequestGenerator::download_dataset("hlarcher/share_gpt_small".to_string(), "share_gpt_filtered_small.json".to_string()).expect("Can't download dataset");
-    let requests = requests::ShareGPTTextRequestGenerator::new(filepath, tokenizer_name, prompt_length, 1, prompt_length * 2, prompt_variance);
+    let filepath = requests::ConversationTextRequestGenerator::download_dataset("hlarcher/share_gpt_small".to_string(), "share_gpt_filtered_small.json".to_string()).expect("Can't download dataset");
+    let requests = requests::ConversationTextRequestGenerator::new(filepath, run_config.tokenizer_name.clone(), run_config.prompt_options.clone(), run_config.decode_options.clone());
     let mut benchmark = benchmark::Benchmark::new(config.clone(), Box::new(backend), Arc::from(Mutex::from(requests)), tx.clone(), stop_sender.clone());
     let mut stop_receiver = stop_sender.subscribe();
diff --git a/src/main.rs b/src/main.rs
index a5a61ba..474e60d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -4,7 +4,7 @@ use clap::error::ErrorKind::InvalidValue;
 use log::error;
 use reqwest::Url;
 use tokio::sync::broadcast;
-use text_generation_inference_benchmark::{run};
+use text_generation_inference_benchmark::{run, RunConfiguration, TokenizeOptions};
 
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -40,15 +40,39 @@ struct Args {
     /// Disable console UI
     #[clap(short, long, env)]
     no_console: bool,
-    /// Prompt token length
-    #[clap(default_value = "50", long, env)]
-    prompt_length: u64,
-    /// Variance of prompt token length following a normal distribution
-    #[clap(default_value = "10", long, env)]
-    prompt_variance: u64,
-    /// Decode token length (number of tokens to generate)
-    #[clap(default_value = "10", long, env)]
-    decode_length: u64,
+    /// Constraints on prompt length.
+    /// If not set, prompts are used as-is from the input dataset.
+    /// The number of prompt tokens is sampled from a normal distribution.
+    /// Specified as a comma-separated list of key=value pairs:
+    /// * num_tokens: target number of prompt tokens
+    /// * min_tokens: minimum number of prompt tokens
+    /// * max_tokens: maximum number of prompt tokens
+    /// * variance: variance in the number of prompt tokens
+    ///
+    /// Example: num_tokens=50,max_tokens=60,min_tokens=40,variance=10
+    #[clap(
+        default_value = "num_tokens=50,max_tokens=60,min_tokens=40,variance=10",
+        long,
+        env,
+        value_parser(parse_tokenizer_options)
+    )]
+    prompt_options: Option<TokenizeOptions>,
+    /// Constraints on the generated text.
+    /// The number of tokens to generate is sampled from a normal distribution.
+    /// Specified as a comma-separated list of key=value pairs:
+    /// * num_tokens: target number of generated tokens
+    /// * min_tokens: minimum number of generated tokens
+    /// * max_tokens: maximum number of generated tokens
+    /// * variance: variance in the number of generated tokens
+    ///
+    /// Example: num_tokens=50,max_tokens=60,min_tokens=40,variance=10
+    #[clap(
+        default_value = "num_tokens=50,max_tokens=60,min_tokens=40,variance=10",
+        long,
+        env,
+        value_parser(parse_tokenizer_options)
+    )]
+    decode_options: Option<TokenizeOptions>,
 }
 
 fn parse_duration(s: &str) -> Result<Duration, Error> {
@@ -62,6 +86,28 @@ fn parse_url(s: &str) -> Result<String, Error> {
     }
 }
 
+fn parse_tokenizer_options(s: &str) -> Result<TokenizeOptions, Error> {
+    let mut tokenizer_options = TokenizeOptions::new();
+    let items = s.split(",").collect::<Vec<&str>>();
+    for item in items.iter() {
+        let key_value = item.split("=").collect::<Vec<&str>>();
+        if key_value.len() != 2 {
+            return Err(Error::new(InvalidValue));
+        }
+        match key_value[0] {
+            "num_tokens" => tokenizer_options.num_tokens = key_value[1].parse::<u64>().map_err(|_| Error::new(InvalidValue))?,
+            "min_tokens" => tokenizer_options.min_tokens = key_value[1].parse::<u64>().map_err(|_| Error::new(InvalidValue))?,
+            "max_tokens" => tokenizer_options.max_tokens = key_value[1].parse::<u64>().map_err(|_| Error::new(InvalidValue))?,
+            "variance" => tokenizer_options.variance = key_value[1].parse::<u64>().map_err(|_| Error::new(InvalidValue))?,
+            _ => return Err(Error::new(InvalidValue)),
+        }
+    }
+    if tokenizer_options.num_tokens == 0 || tokenizer_options.min_tokens == 0 || tokenizer_options.max_tokens == 0 || tokenizer_options.min_tokens > tokenizer_options.max_tokens {
+        return Err(Error::new(InvalidValue));
+    }
+    Ok(tokenizer_options)
+}
+
 #[tokio::main]
 async fn main() {
     let args = Args::parse();
@@ -76,19 +122,21 @@ async fn main() {
     });
     let stop_sender_clone = stop_sender.clone();
+    let run_config = RunConfiguration {
+        url: args.url.clone(),
+        tokenizer_name: args.tokenizer_name.clone(),
+        max_vus: args.max_vus,
+        duration: args.duration,
+        rate: args.rate,
+        num_rates: args.num_rates,
+        benchmark_kind: args.benchmark_kind.clone(),
+        warmup_duration: args.warmup,
+        interactive: !args.no_console,
+        prompt_options: args.prompt_options.clone(),
+        decode_options: args.decode_options.clone(),
+    };
     let main_thread = tokio::spawn(async move {
-        match run(args.url,
-                  args.tokenizer_name,
-                  args.prompt_length,
-                  args.prompt_variance,
-                  args.decode_length,
-                  args.max_vus,
-                  args.duration,
-                  args.rate,
-                  args.num_rates,
-                  args.benchmark_kind,
-                  args.warmup,
-                  !args.no_console,
+        match run(run_config,
                   stop_sender_clone,
         ).await {
             Ok(_) => {}
diff --git a/src/requests.rs b/src/requests.rs
index 2f647cf..4bfa56e 100644
--- a/src/requests.rs
+++ b/src/requests.rs
@@ -14,11 +14,13 @@ use rayon::iter::split;
 use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
 
+
 #[derive(Debug, Clone)]
 pub struct TextGenerationRequest {
     pub prompt: String,
-    pub num_tokens: u64,
-    pub max_tokens: u64,
+    pub num_prompt_tokens: u64, // this includes the system prompt, if present
+    pub num_decode_tokens: Option<u64>,
+    pub system_prompt: Option<String>,
 }
 
 #[async_trait]
@@ -47,7 +49,7 @@ impl Clone for Box<dyn TextGenerationBackend + Send + Sync> {
 pub struct OpenAITextGenerationBackend {
     pub api_key: String,
     pub base_url: String,
-    pub model_name: String
+    pub model_name: String,
 }
 
 #[derive(Deserialize, Serialize, Clone, Debug)]
@@ -103,11 +105,11 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
                     "content": request.prompt
                 }
             ],
-            "max_tokens": request.max_tokens,
+            "max_tokens": request.num_decode_tokens,
             "stream": true,
         }));
         // start timer
-        aggregated_response.start(request.num_tokens);
+        aggregated_response.start(request.num_prompt_tokens);
         let mut es = EventSource::new(req).unwrap();
         let mut final_response = "".to_string();
         while let Some(event) = es.next().await {
@@ -135,7 +137,7 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
                             aggregated_response.add_tokens(1);
                             aggregated_response.stop();
                             let content = choices[0].clone().delta.unwrap().content;
-                            trace!("Generated text using OpenAI API | prompt: {prompt}, max tokens: {max_tokens}, response: {message}", prompt = request.prompt, max_tokens = request.max_tokens,message = &content);
+                            trace!("Generated text using OpenAI API | prompt: {prompt}, max tokens: {max_tokens:?}, response: {message}", prompt = request.prompt, max_tokens = request.num_decode_tokens, message = &content);
                         }
                     };
                 }
@@ -163,29 +165,48 @@ pub trait TextRequestGenerator: Sync {
 }
 
 #[derive(Clone)]
-pub struct ShareGPTTextRequestGenerator {
+pub struct ConversationTextRequestGenerator {
     pub requests: Vec<TextGenerationRequest>,
     current_index: Arc<AtomicI64>,
 }
 
 #[derive(Deserialize, Serialize, Clone)]
-pub struct ShareGPTConversation {
-    pub from: String,
-    pub value: String,
+pub struct Conversation {
+    pub role: String,
+    pub content: String,
 }
 
 #[derive(Deserialize, Serialize, Clone)]
-pub struct ShareGPTEntry {
+pub struct ConversationEntry {
     pub id: String,
-    pub conversations: Vec<ShareGPTConversation>,
+    pub conversations: Vec<Conversation>,
+}
+
+#[derive(Clone, Serialize, Debug)]
+pub struct TokenizeOptions {
+    pub num_tokens: u64,
+    pub min_tokens: u64,
+    pub max_tokens: u64,
+    pub variance: u64,
 }
 
-impl ShareGPTTextRequestGenerator {
-    pub fn new(filepath: PathBuf, tokenizer: String, prompt_tokens: u64, min_tokens: u64, max_tokens: u64, variance: u64) -> Self {
+impl TokenizeOptions {
+    pub fn new() -> Self {
+        Self {
+            num_tokens: 0,
+            min_tokens: 0,
+            max_tokens: 0,
+            variance: 0,
+        }
+    }
+}
+
+impl ConversationTextRequestGenerator {
+    pub fn new(filepath: PathBuf, tokenizer: String, prompt_tokenize_opts: Option<TokenizeOptions>, decode_tokenize_opts: Option<TokenizeOptions>) -> Self {
         let tokenizer = Arc::new(Tokenizer::from_pretrained(tokenizer, None).expect("Unable to load tokenizer"));
         // load json file
         let input = std::fs::read_to_string(&filepath).expect("Unable to read input file");
-        let data: Vec<ShareGPTEntry> = serde_json::from_str(&input).expect("Unable to parse input file");
+        let data: Vec<ConversationEntry> = serde_json::from_str(&input).expect("Unable to parse input file");
         // generate requests
         let requests: Arc<Mutex<Vec<TextGenerationRequest>>> = Arc::from(Mutex::from(Vec::new()));
         info!("Generating requests from {filepath}", filepath = filepath.display().to_string());
@@ -198,27 +219,61 @@ impl ShareGPTTextRequestGenerator {
             if entry.conversations.len() == 0 {
                 continue;
             }
-            let prompt = entry.conversations[0].value.clone();
-            // compute number of tokens to generate using a Gaussian distribution
-            let normal = rand_distr::Normal::new(prompt_tokens as f64, variance as f64).unwrap();
-            let mut num_tokens = normal.sample(&mut rand::thread_rng()) as u64;
-            if num_tokens < min_tokens {
-                num_tokens = min_tokens;
-            }
-            if num_tokens > max_tokens {
-                num_tokens = max_tokens;
-            }
-            let sampled_prompt = match tokenize_prompt(prompt, tokenizer.clone(), num_tokens) {
-                Ok(prompt) => prompt,
-                Err(e) => {
-                    debug!("Error tokenizing prompt: {e}");
-                    continue;
-                }
+            let system_prompt = entry.conversations.iter().find(|c| c.role == "system").map(|c| c.content.clone());
+            let system_prompt_tokens = match system_prompt {
+                Some(ref prompt) => {
+                    let (_, num_tokens) = match tokenize_prompt(prompt.clone(), tokenizer.clone(), None) {
+                        Ok((prompt, num_tokens)) => (prompt, num_tokens),
+                        Err(e) => {
+                            debug!("Error tokenizing system prompt: {e}");
+                            return;
+                        }
+                    };
+                    num_tokens
+                },
+                None => 0,
             };
-            requests.lock().unwrap().push(TextGenerationRequest {
-                prompt: sampled_prompt,
-                num_tokens,
-                max_tokens,
+            entry.conversations.iter().filter(|c| c.role == "user").for_each(|c| {
+                let prompt = c.content.clone();
+                let num_decode_tokens = decode_tokenize_opts.clone().map(|opts| sample_num_tokens(opts.num_tokens, opts.min_tokens, opts.max_tokens, opts.variance));
+                match &prompt_tokenize_opts {
+                    None => {
+                        let (_, num_tokens) = match tokenize_prompt(prompt.clone(), tokenizer.clone(), None) {
+                            Ok((prompt, num_tokens)) => (prompt, num_tokens),
+                            Err(e) => {
+                                debug!("Error tokenizing prompt: {e}");
+                                return;
+                            }
+                        };
+                        requests.lock().unwrap().push(TextGenerationRequest {
+                            prompt,
+                            num_prompt_tokens: num_tokens + system_prompt_tokens,
+                            num_decode_tokens,
+                            system_prompt: system_prompt.clone(),
+                        });
+                    }
+                    Some(options) => {
+                        let num_tokens = options.num_tokens;
+                        let min_tokens = options.min_tokens;
+                        let max_tokens = options.max_tokens;
+                        let variance = options.variance;
+                        // sample the target number of prompt tokens from a Gaussian distribution
+                        let num_tokens = sample_num_tokens(num_tokens, min_tokens, max_tokens, variance);
+                        let sampled_prompt = match tokenize_prompt(prompt.clone(), tokenizer.clone(), Some(num_tokens)) {
+                            Ok(prompt) => prompt,
+                            Err(e) => {
+                                debug!("Error tokenizing prompt: {e}");
+                                return;
+                            }
+                        };
+                        requests.lock().unwrap().push(TextGenerationRequest {
+                            prompt: sampled_prompt.0,
+                            num_prompt_tokens: num_tokens + system_prompt_tokens,
+                            num_decode_tokens,
+                            system_prompt: system_prompt.clone(),
+                        });
+                    }
+                }
             });
             // TODO: check that we have enough requests
         }
@@ -239,7 +294,19 @@ impl ShareGPTTextRequestGenerator {
     }
 }
 
-fn entry_splitter(gen: Vec<ShareGPTEntry>) -> (Vec<ShareGPTEntry>, Option<Vec<ShareGPTEntry>>) {
+fn sample_num_tokens(num_tokens: u64, min_tokens: u64, max_tokens: u64, variance: u64) -> u64 {
+    let normal = rand_distr::Normal::new(num_tokens as f64, variance as f64).unwrap();
+    let mut num_tokens = normal.sample(&mut rand::thread_rng()) as u64;
+    if num_tokens < min_tokens {
+        num_tokens = min_tokens;
+    }
+    if num_tokens > max_tokens {
+        num_tokens = max_tokens;
+    }
+    num_tokens
+}
+
+fn entry_splitter(gen: Vec<ConversationEntry>) -> (Vec<ConversationEntry>, Option<Vec<ConversationEntry>>) {
     if gen.len() <= 2 {
         return (gen, None);
     }
@@ -250,7 +317,7 @@ fn entry_splitter(gen: Vec<ShareGPTEntry>) -> (Vec<ShareGPTEntry>, Option<Vec<ShareGPTEntry>>) {
-impl TextRequestGenerator for ShareGPTTextRequestGenerator {
+impl TextRequestGenerator for ConversationTextRequestGenerator {
     fn generate_request(&mut self) -> TextGenerationRequest {
         let idx = self.current_index.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
         if idx >= (self.requests.len() - 1) as i64 {
@@ -261,33 +328,40 @@ impl TextRequestGenerator for ShareGPTTextRequestGenerator {
 }
 
 
-fn tokenize_prompt(prompt: String, tokenizer: Arc<Tokenizer>, num_tokens: u64) -> anyhow::Result<String> {
+fn tokenize_prompt(prompt: String, tokenizer: Arc<Tokenizer>, num_tokens: Option<u64>) -> anyhow::Result<(String, u64)> {
     let prompt_tokens = tokenizer.encode(prompt.clone(), false).map_err(|_| anyhow::anyhow!("Error tokenizing prompt"))?;
-    if prompt_tokens.len() < num_tokens as usize {
-        return Err(anyhow::anyhow!("Prompt is too short to tokenize"));
-    }
-    // let's do a binary search to find the right number of tokens
-    let mut low = 1;
-    let mut high = prompt.len() as u64;
-    let mut prompt_sub = String::new();
-    while low < high {
-        let mid = (low + high) / 2;
-        prompt_sub = prompt.chars().skip((low - 1) as usize).take(high as usize).collect::<String>();
-        let tokenized_len = match tokenizer.encode(prompt_sub.clone(), false) {
-            Ok(tokens) => tokens.len(),
-            Err(_) => {
-                return Err(anyhow::anyhow!("Error tokenizing prompt"));
+    match num_tokens {
+        None => {
+            Ok((prompt, prompt_tokens.len() as u64))
+        }
+        Some(num_tokens) => {
+            if prompt_tokens.len() < num_tokens as usize {
+                return Err(anyhow::anyhow!("Prompt is too short to tokenize"));
             }
-        };
-        if tokenized_len == num_tokens as usize {
-            return Ok(prompt_sub.to_string());
-        } else if tokenized_len > num_tokens as usize {
-            high = mid;
-        } else {
-            low = mid + 1;
+            // let's do a binary search to find the right number of tokens
+            let mut low = 1;
+            let mut high = prompt.len() as u64;
+            let mut prompt_sub = String::new();
+            while low < high {
+                let mid = (low + high) / 2;
+                prompt_sub = prompt.chars().skip((low - 1) as usize).take(high as usize).collect::<String>();
+                let tokenized_len = match tokenizer.encode(prompt_sub.clone(), false) {
+                    Ok(tokens) => tokens.len(),
+                    Err(_) => {
+                        return Err(anyhow::anyhow!("Error tokenizing prompt"));
+                    }
+                };
+                if tokenized_len == num_tokens as usize {
+                    return Ok((prompt_sub.to_string(), num_tokens));
+                } else if tokenized_len > num_tokens as usize {
+                    high = mid;
+                } else {
+                    low = mid + 1;
+                }
+            }
+            Ok((prompt_sub.to_string(), prompt_tokens.len() as u64))
         }
     }
-    Ok(prompt_sub.to_string())
 }
@@ -314,7 +388,7 @@ impl TextGenerationAggregatedResponse {
             failed: false,
         }
     }
-    fn start(&mut self, num_prompt_tokens:u64) {
+    fn start(&mut self, num_prompt_tokens: u64) {
         self.start_time = Some(std::time::Instant::now());
         self.last_received_token_time = std::time::Instant::now();
         self.num_prompt_tokens = num_prompt_tokens;
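
Note (not part of the patch): the key=value format documented on --prompt-options / --decode-options could be pinned down with a small unit test next to parse_tokenizer_options in src/main.rs. The sketch below only relies on behaviour visible in this diff (unknown keys, malformed pairs, zero values and min_tokens > max_tokens are rejected); the module and test names are illustrative.

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_well_formed_options() {
        // matches the documented example string
        let opts = parse_tokenizer_options("num_tokens=50,max_tokens=60,min_tokens=40,variance=10")
            .expect("a well-formed option string should parse");
        assert_eq!(opts.num_tokens, 50);
        assert_eq!(opts.min_tokens, 40);
        assert_eq!(opts.max_tokens, 60);
        assert_eq!(opts.variance, 10);
    }

    #[test]
    fn rejects_invalid_options() {
        // min_tokens greater than max_tokens fails the final validation
        assert!(parse_tokenizer_options("num_tokens=50,max_tokens=40,min_tokens=60,variance=10").is_err());
        // unknown keys and malformed pairs are rejected while parsing
        assert!(parse_tokenizer_options("foo=1").is_err());
        assert!(parse_tokenizer_options("num_tokens").is_err());
    }
}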