From 5ddfeabfe75dbc23ef04964d6cc0ffde4e1762b8 Mon Sep 17 00:00:00 2001
From: Hugo Larcher
Date: Thu, 19 Sep 2024 11:57:10 +0200
Subject: [PATCH] feat: Allow HF_TOKEN env variable

---
 README.md        | 13 +++++++++++--
 src/app.rs       |  4 ++--
 src/benchmark.rs |  4 +++-
 src/lib.rs       |  5 +++--
 src/main.rs      | 14 +++++++++++---
 src/requests.rs  | 20 ++++++++++++++------
 6 files changed, 44 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index b857ca2..8da638b 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
 
 Benchmarks using constant arrival rate or constant virtual user count.
 
-![ui.png](assets%2Fui.png)
+![ui.png](assets/ui.png)
 
 ## Table of contents
@@ -13,7 +13,12 @@ Benchmarks using constant arrival rate or constant virtual user count.
 * [Text Generation Inference benchmarking tool](#text-generation-inference-benchmarking-tool)
 * [Table of contents](#table-of-contents)
 * [TODO](#todo)
-* [Running a benchmark](#running-a-benchmark)
+* [Get started](#get-started)
+  * [Run a benchmark](#run-a-benchmark)
+  * [Configure your benchmark](#configure-your-benchmark)
+    * [Benchmark mode](#benchmark-mode)
+    * [Dataset configuration](#dataset-configuration)
+    * [Prompt configuration](#prompt-configuration)
 * [Development](#development)
 * [Frequently Asked Questions](#frequently-asked-questions)
@@ -130,3 +135,7 @@ $ make build
 
 * **Constant arrival rate** means that the rate of requests is fixed and the number of virtual users is adjusted to maintain that rate. Queries hit the server independently of response performance. **Constant virtual user count** is a closed-loop model where the server's response time dictates the number of iterations. **Constant arrival rate** is an open-loop model more representative of real-life workloads.
+
+* **What is the influence of CUDA graphs?**
+CUDA graphs are used to optimize GPU usage by minimizing the overhead of launching kernels. This can lead to better performance in some cases, but can also lead to worse performance in others.
+If your CUDA graphs are not evenly distributed, you may see a performance drop at some request rates: the batch size may fall into a larger CUDA graph batch size, leading to a loss of compute due to excessive padding.
\ No newline at end of file
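The CUDA graph note added to the FAQ above can be made concrete with a small standalone sketch. The graph sizes below are hypothetical, not the sizes TGI actually captures: when the running batch size falls between two captured sizes, it is padded up to the next one, and the padded slots are wasted compute.

```rust
// Illustration only: hypothetical CUDA graph batch sizes, not TGI's actual configuration.
const CUDA_GRAPH_SIZES: &[usize] = &[1, 2, 4, 8, 16, 32, 64, 128];

/// Fraction of compute lost to padding when `batch_size` is rounded up to the
/// next captured CUDA graph batch size.
fn padding_waste(batch_size: usize) -> f64 {
    let graph_size = CUDA_GRAPH_SIZES
        .iter()
        .copied()
        .find(|&s| s >= batch_size)
        .unwrap_or(batch_size); // past the largest graph, assume no padding
    (graph_size - batch_size) as f64 / graph_size as f64
}

fn main() {
    // A batch of 9 is padded to 16 (~44% waste); a batch of 65 is padded to 128 (~49% waste).
    for batch_size in [8, 9, 33, 65] {
        println!("batch {batch_size:3} -> {:.1}% padding waste", 100.0 * padding_waste(batch_size));
    }
}
```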
diff --git a/src/app.rs b/src/app.rs
index 3356bd4..e238803 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -138,8 +138,8 @@ pub async fn run_console(
             _ = stop_receiver_signal.recv() => {}
         }
     });
-    event_thread.await.unwrap();
-    app_thread.await.unwrap();
+    let _ = event_thread.await;
+    let _ = app_thread.await;
 }
 
 impl App {
diff --git a/src/benchmark.rs b/src/benchmark.rs
index a5e3b19..93e0134 100644
--- a/src/benchmark.rs
+++ b/src/benchmark.rs
@@ -9,6 +9,8 @@ use crate::{executors, scheduler};
 use crate::results::{BenchmarkReport, BenchmarkResults};
 use crate::scheduler::{ExecutorType, SchedulerProgress};
 
+const THROUGHPUT_BUDGET: f64 = 1.2; // sweep up to 120% of max throughput
+
 #[derive(Clone, Debug, strum_macros::Display, Serialize)]
 pub enum BenchmarkKind {
     Throughput,
@@ -297,7 +299,7 @@ impl Benchmark {
         let mut rates = Vec::new();
         let num_rates = self.config.num_rates;
         for i in 1..=num_rates {
-            rates.push(i as f64 * max_throughput / num_rates as f64);
+            rates.push(i as f64 * max_throughput * THROUGHPUT_BUDGET / num_rates as f64);
         }
         for rate in rates {
             self.run_rate(rate).await?;
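To see what the new `THROUGHPUT_BUDGET` constant changes, here is a standalone sketch of the rate sweep (not the actual `Benchmark` method; `num_rates` is typed as `u64` here for simplicity): the highest tested arrival rate is now 120% of the measured maximum throughput, so the sweep deliberately pushes past the saturation point.

```rust
const THROUGHPUT_BUDGET: f64 = 1.2; // sweep up to 120% of max throughput

// Standalone sketch of the sweep in `benchmark.rs`.
fn sweep_rates(max_throughput: f64, num_rates: u64) -> Vec<f64> {
    (1..=num_rates)
        .map(|i| i as f64 * max_throughput * THROUGHPUT_BUDGET / num_rates as f64)
        .collect()
}

fn main() {
    // A measured max throughput of 10 req/s swept over 5 rates gives roughly
    // [2.4, 4.8, 7.2, 9.6, 12.0] req/s.
    println!("{:?}", sweep_rates(10.0, 5));
}
```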
diff --git a/src/lib.rs b/src/lib.rs
index a32b34d..298a05d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -38,6 +38,7 @@ pub struct RunConfiguration {
     pub decode_options: Option<TokenizeOptions>,
     pub dataset: String,
     pub dataset_file: String,
+    pub hf_token: Option<String>,
 }
 
 pub async fn run(run_config: RunConfiguration,
@@ -113,8 +114,8 @@ pub async fn run(run_config: RunConfiguration,
         timestamp: chrono::Utc::now(),
         level: Level::Info,
     }));
-    let filepath = requests::ConversationTextRequestGenerator::download_dataset(run_config.dataset, run_config.dataset_file).expect("Can't download dataset");
-    let requests = requests::ConversationTextRequestGenerator::load(filepath, run_config.tokenizer_name.clone(), run_config.prompt_options.clone(), run_config.decode_options.clone())?;
+    let filepath = requests::ConversationTextRequestGenerator::download_dataset(run_config.dataset, run_config.dataset_file, run_config.hf_token.clone()).expect("Can't download dataset");
+    let requests = requests::ConversationTextRequestGenerator::load(filepath, run_config.tokenizer_name, run_config.prompt_options, run_config.decode_options, run_config.hf_token)?;
     let mut benchmark = benchmark::Benchmark::new(config.clone(), Box::new(backend), Arc::from(Mutex::from(requests)), tx.clone(), stop_sender.clone());
 
     let mut stop_receiver = stop_sender.subscribe();
diff --git a/src/main.rs b/src/main.rs
index a7b78cb..f4247c1 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -74,10 +74,10 @@ struct Args {
     )]
     decode_options: Option<TokenizeOptions>,
     /// Hugging Face dataset to use for prompt generation
-    #[clap(default_value="hlarcher/share_gpt_small",long,env)]
+    #[clap(default_value = "hlarcher/share_gpt_small", long, env)]
     dataset: String,
     /// File to use in the Dataset
-    #[clap(default_value="share_gpt_filtered_small.json",long,env)]
+    #[clap(default_value = "share_gpt_filtered_small.json", long, env)]
     dataset_file: String,
 }
 
@@ -128,6 +128,13 @@ async fn main() {
     });
     let stop_sender_clone = stop_sender.clone();
 
+    // get HF token
+    let token_env_key = "HF_TOKEN".to_string();
+    let cache = hf_hub::Cache::default();
+    let hf_token = match std::env::var(token_env_key).ok() {
+        Some(token) => Some(token),
+        None => cache.token(),
+    };
     let run_config = RunConfiguration {
         url: args.url.clone(),
         tokenizer_name: args.tokenizer_name.clone(),
@@ -142,6 +149,7 @@ async fn main() {
         decode_options: args.decode_options.clone(),
         dataset: args.dataset.clone(),
         dataset_file: args.dataset_file.clone(),
+        hf_token,
     };
     let main_thread = tokio::spawn(async move {
         match run(run_config,
@@ -153,5 +161,5 @@
             }
         };
     });
-    main_thread.await.expect("Failed to run main thread");
+    let _ = main_thread.await;
 }
diff --git a/src/requests.rs b/src/requests.rs
index cce5227..20c61d2 100644
--- a/src/requests.rs
+++ b/src/requests.rs
@@ -6,9 +6,9 @@ use tokio::sync::mpsc::Sender;
 use reqwest_eventsource::{Error, Event, EventSource};
 use log::{debug, error, info, trace};
 use rand_distr::Distribution;
-use tokenizers::Tokenizer;
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 use futures_util::StreamExt;
-use hf_hub::api::sync::Api;
+use hf_hub::api::sync::{ApiBuilder};
 use indicatif::{ProgressBar, ProgressStyle};
 use rayon::iter::split;
 use rayon::prelude::*;
@@ -221,8 +221,16 @@ impl TokenizeOptions {
 }
 
 impl ConversationTextRequestGenerator {
-    pub fn load(filepath: PathBuf, tokenizer: String, prompt_tokenize_opts: Option<TokenizeOptions>, decode_tokenize_opts: Option<TokenizeOptions>) -> anyhow::Result<ConversationTextRequestGenerator> {
-        let tokenizer = Arc::new(Tokenizer::from_pretrained(tokenizer, None).expect("Unable to load tokenizer"));
+    pub fn load(filepath: PathBuf, tokenizer: String, prompt_tokenize_opts: Option<TokenizeOptions>, decode_tokenize_opts: Option<TokenizeOptions>, hf_token: Option<String>) -> anyhow::Result<ConversationTextRequestGenerator> {
+        let mut params = FromPretrainedParameters::default();
+        params.auth_token = hf_token;
+        let tokenizer = match Tokenizer::from_pretrained(tokenizer, Some(params)) {
+            Ok(tokenizer) => tokenizer,
+            Err(e) => {
+                return Err(anyhow::anyhow!("Error loading tokenizer: {e}"));
+            }
+        };
+        let tokenizer = Arc::new(tokenizer);
         // load json file
         let input = std::fs::read_to_string(&filepath)?;
         let data: Vec = serde_json::from_str(&input).expect("Unable to parse input file. Check that it is valid JSON and matches the expected format.");
@@ -305,8 +313,8 @@ impl ConversationTextRequestGenerator {
         })
     }
 
-    pub fn download_dataset(repo_name: String, filename: String) -> anyhow::Result<PathBuf> {
-        let api = Api::new().unwrap();
+    pub fn download_dataset(repo_name: String, filename: String, hf_token: Option<String>) -> anyhow::Result<PathBuf> {
+        let api = ApiBuilder::new().with_token(hf_token).build()?;
         let repo = api.dataset(repo_name);
         let dataset = repo.get(&filename)?;
         Ok(dataset)
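Taken together, the patch resolves the token once in `main.rs` (the `HF_TOKEN` environment variable first, then the token cached by `huggingface-cli login`) and threads it through `RunConfiguration` to both the dataset download and the tokenizer load. Below is a condensed, standalone sketch of that plumbing with simplified error handling; the `"gpt2"` tokenizer id is only a placeholder, the real tool takes the tokenizer name from its CLI arguments.

```rust
use hf_hub::api::sync::ApiBuilder;
use tokenizers::{FromPretrainedParameters, Tokenizer};

/// HF_TOKEN env var first, then the token cached by `huggingface-cli login`.
fn resolve_hf_token() -> Option<String> {
    std::env::var("HF_TOKEN")
        .ok()
        .or_else(|| hf_hub::Cache::default().token())
}

fn main() -> anyhow::Result<()> {
    let hf_token = resolve_hf_token();

    // Authenticated dataset download; gated or private datasets work when a token is present.
    let api = ApiBuilder::new().with_token(hf_token.clone()).build()?;
    let file = api
        .dataset("hlarcher/share_gpt_small".to_string())
        .get("share_gpt_filtered_small.json")?;
    println!("dataset file: {}", file.display());

    // Authenticated tokenizer download ("gpt2" is a placeholder id).
    let mut params = FromPretrainedParameters::default();
    params.auth_token = hf_token;
    let tokenizer = Tokenizer::from_pretrained("gpt2", Some(params))
        .map_err(|e| anyhow::anyhow!("Error loading tokenizer: {e}"))?;
    println!("vocab size: {}", tokenizer.get_vocab_size(true));
    Ok(())
}
```

Either exporting `HF_TOKEN` or a prior `huggingface-cli login` is enough; when neither is available the token resolves to `None` and public repositories continue to work as before.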