feat: Allow HF_TOKEN env variable
Hugoch committed Sep 19, 2024
1 parent e5614fb commit 5ddfeab
Showing 6 changed files with 44 additions and 16 deletions.
13 changes: 11 additions & 2 deletions README.md
@@ -5,15 +5,20 @@ Benchmarks using constant arrival rate or constant virtual user count.



-![ui.png](assets%2Fui.png)
+![ui.png](assets/ui.png)

## Table of contents

<!-- TOC -->
* [Text Generation Inference benchmarking tool](#text-generation-inference-benchmarking-tool)
* [Table of contents](#table-of-contents)
* [TODO](#todo)
-* [Running a benchmark](#running-a-benchmark)
+* [Get started](#get-started)
+* [Run a benchmark](#run-a-benchmark)
+* [Configure your benchmark](#configure-your-benchmark)
+* [Benchmark mode](#benchmark-mode)
+* [Dataset configuration](#dataset-configuration)
+* [Prompt configuration](#prompt-configuration)
* [Development](#development)
* [Frequently Asked Questions](#frequently-asked-questions)
<!-- TOC -->
@@ -130,3 +135,7 @@ $ make build
* **Constant arrival rate** means that the rate of requests is fixed and the number of virtual users is adjusted to maintain that rate. Queries hit the server independently of response performance.

**Constant virtual user count** is a closed-loop model where the server's response time dictates the number of iterations. **Constant arrival rate** is an open-loop model, more representative of real-life workloads.
+
+* **What is the influence of CUDA graphs?**
+CUDA graphs are used to optimize GPU usage by minimizing the overhead of launching kernels. This can lead to better performance in some cases, but to worse performance in others.
+If your CUDA graph batch sizes are not evenly distributed, you may see a performance drop at some request rates: the batch size may fall into a larger captured CUDA graph batch size, leading to a loss of compute due to excessive padding.
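
To make the padding effect concrete, here is a minimal, self-contained sketch. It is illustrative only: the capture sizes and the `padded_size` helper are hypothetical, not part of this repository. A batch executes with the smallest captured CUDA graph batch size that can hold it, and every unused slot is padded compute.

```rust
/// Hypothetical helper: pick the smallest captured CUDA graph batch size
/// that fits `batch`. `capture_sizes` must be sorted ascending.
fn padded_size(capture_sizes: &[usize], batch: usize) -> Option<usize> {
    capture_sizes.iter().copied().find(|&s| s >= batch)
}

fn main() {
    // Hypothetical capture sizes (power-of-two buckets).
    let sizes = [1, 2, 4, 8, 16, 32];
    for batch in [3, 9, 17] {
        if let Some(s) = padded_size(&sizes, batch) {
            // e.g. a batch of 17 runs as 32: 15 slots of padded compute.
            println!("batch {batch} runs as {s}: {} slots padded", s - batch);
        }
    }
}
```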
4 changes: 2 additions & 2 deletions src/app.rs
@@ -138,8 +138,8 @@ pub async fn run_console(
_ = stop_receiver_signal.recv() => {}
}
});
-event_thread.await.unwrap();
-app_thread.await.unwrap();
+let _ = event_thread.await;
+let _ = app_thread.await;
}

impl App {
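The `unwrap()` calls above would propagate a panic from either spawned task into `run_console` itself; `let _ =` simply discards the `JoinError`. A standalone sketch of the same pattern, assuming a plain tokio runtime (the task body is made up for illustration):

```rust
#[tokio::main]
async fn main() {
    let worker = tokio::spawn(async {
        panic!("worker failed");
    });
    // Awaiting a JoinHandle yields Result<T, JoinError>; `let _ =`
    // drops the JoinError instead of unwrapping and re-panicking.
    let _ = worker.await;
    println!("still running after the worker panicked");
}
```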
4 changes: 3 additions & 1 deletion src/benchmark.rs
@@ -9,6 +9,8 @@ use crate::{executors, scheduler};
use crate::results::{BenchmarkReport, BenchmarkResults};
use crate::scheduler::{ExecutorType, SchedulerProgress};

+const THROUGHPUT_BUDGET: f64 = 1.2; // sweep up to 120% of max throughput

#[derive(Clone, Debug, strum_macros::Display, Serialize)]
pub enum BenchmarkKind {
Throughput,
@@ -297,7 +299,7 @@ impl Benchmark {
let mut rates = Vec::new();
let num_rates = self.config.num_rates;
for i in 1..=num_rates {
-rates.push(i as f64 * max_throughput / num_rates as f64);
+rates.push(i as f64 * max_throughput * THROUGHPUT_BUDGET / num_rates as f64);
}
for rate in rates {
self.run_rate(rate).await?;
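With `THROUGHPUT_BUDGET` applied, the sweep now deliberately overshoots the measured maximum. A worked example under assumed numbers (standalone; `max_throughput` and `num_rates` are made up here, and the constant mirrors the one added above):

```rust
const THROUGHPUT_BUDGET: f64 = 1.2; // sweep up to 120% of max throughput

fn main() {
    let max_throughput = 100.0; // assumed measured peak, in requests/s
    let num_rates = 5;
    let rates: Vec<f64> = (1..=num_rates)
        .map(|i| i as f64 * max_throughput * THROUGHPUT_BUDGET / num_rates as f64)
        .collect();
    // [24.0, 48.0, 72.0, 96.0, 120.0]: the last rate probes the server
    // past its measured peak to expose saturation behavior.
    println!("{rates:?}");
}
```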
5 changes: 3 additions & 2 deletions src/lib.rs
@@ -38,6 +38,7 @@ pub struct RunConfiguration {
pub decode_options: Option<TokenizeOptions>,
pub dataset: String,
pub dataset_file: String,
+pub hf_token: Option<String>,
}

pub async fn run(run_config: RunConfiguration,
@@ -113,8 +114,8 @@ pub async fn run(run_config: RunConfiguration,
timestamp: chrono::Utc::now(),
level: Level::Info,
}));
-let filepath = requests::ConversationTextRequestGenerator::download_dataset(run_config.dataset, run_config.dataset_file).expect("Can't download dataset");
-let requests = requests::ConversationTextRequestGenerator::load(filepath, run_config.tokenizer_name.clone(), run_config.prompt_options.clone(), run_config.decode_options.clone())?;
+let filepath = requests::ConversationTextRequestGenerator::download_dataset(run_config.dataset, run_config.dataset_file, run_config.hf_token.clone()).expect("Can't download dataset");
+let requests = requests::ConversationTextRequestGenerator::load(filepath, run_config.tokenizer_name, run_config.prompt_options, run_config.decode_options, run_config.hf_token)?;

let mut benchmark = benchmark::Benchmark::new(config.clone(), Box::new(backend), Arc::from(Mutex::from(requests)), tx.clone(), stop_sender.clone());
let mut stop_receiver = stop_sender.subscribe();
14 changes: 11 additions & 3 deletions src/main.rs
@@ -74,10 +74,10 @@ struct Args {
)]
decode_options: Option<TokenizeOptions>,
/// Hugging Face dataset to use for prompt generation
#[clap(default_value="hlarcher/share_gpt_small",long,env)]
#[clap(default_value = "hlarcher/share_gpt_small", long, env)]
dataset: String,
/// File to use in the Dataset
#[clap(default_value="share_gpt_filtered_small.json",long,env)]
#[clap(default_value = "share_gpt_filtered_small.json", long, env)]
dataset_file: String,
}

@@ -128,6 +128,13 @@ async fn main() {
});

let stop_sender_clone = stop_sender.clone();
+// get HF token
+let token_env_key = "HF_TOKEN".to_string();
+let cache = hf_hub::Cache::default();
+let hf_token = match std::env::var(token_env_key).ok() {
+    Some(token) => Some(token),
+    None => cache.token(),
+};
let run_config = RunConfiguration {
url: args.url.clone(),
tokenizer_name: args.tokenizer_name.clone(),
@@ -142,6 +149,7 @@
decode_options: args.decode_options.clone(),
dataset: args.dataset.clone(),
dataset_file: args.dataset_file.clone(),
+hf_token,
};
let main_thread = tokio::spawn(async move {
match run(run_config,
@@ -153,5 +161,5 @@
}
};
});
-main_thread.await.expect("Failed to run main thread");
+let _ = main_thread.await;
}
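
The token resolution added above checks the `HF_TOKEN` environment variable first and falls back to the token cached on disk (e.g. by `huggingface-cli login`) through `hf_hub::Cache`. The same order as a compact standalone sketch, assuming the `hf_hub` crate; the helper name is illustrative:

```rust
/// Resolve a Hugging Face token: env var first, cached login second.
fn resolve_hf_token() -> Option<String> {
    std::env::var("HF_TOKEN")
        .ok()
        .or_else(|| hf_hub::Cache::default().token())
}

fn main() {
    match resolve_hf_token() {
        Some(_) => println!("token found (env or cache)"),
        None => println!("no token: only public repos are reachable"),
    }
}
```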
20 changes: 14 additions & 6 deletions src/requests.rs
@@ -6,9 +6,9 @@ use tokio::sync::mpsc::Sender;
use reqwest_eventsource::{Error, Event, EventSource};
use log::{debug, error, info, trace};
use rand_distr::Distribution;
-use tokenizers::Tokenizer;
+use tokenizers::{FromPretrainedParameters, Tokenizer};
use futures_util::StreamExt;
-use hf_hub::api::sync::Api;
+use hf_hub::api::sync::{ApiBuilder};
use indicatif::{ProgressBar, ProgressStyle};
use rayon::iter::split;
use rayon::prelude::*;
@@ -221,8 +221,16 @@ impl TokenizeOptions {
}

impl ConversationTextRequestGenerator {
-pub fn load(filepath: PathBuf, tokenizer: String, prompt_tokenize_opts: Option<TokenizeOptions>, decode_tokenize_opts: Option<TokenizeOptions>) -> anyhow::Result<Self> {
-let tokenizer = Arc::new(Tokenizer::from_pretrained(tokenizer, None).expect("Unable to load tokenizer"));
+pub fn load(filepath: PathBuf, tokenizer: String, prompt_tokenize_opts: Option<TokenizeOptions>, decode_tokenize_opts: Option<TokenizeOptions>, hf_token: Option<String>) -> anyhow::Result<Self> {
+let mut params = FromPretrainedParameters::default();
+params.auth_token = hf_token;
+let tokenizer = match Tokenizer::from_pretrained(tokenizer, Some(params)) {
+Ok(tokenizer) => tokenizer,
+Err(e) => {
+return Err(anyhow::anyhow!("Error loading tokenizer: {e}"));
+}
+};
+let tokenizer = Arc::new(tokenizer);
// load json file
let input = std::fs::read_to_string(&filepath)?;
let data: Vec<ConversationEntry> = serde_json::from_str(&input).expect("Unable to parse input file. Check that it is valid JSON and matches the expected format.");
@@ -305,8 +313,8 @@ impl ConversationTextRequestGenerator {
})
}

-pub fn download_dataset(repo_name: String, filename: String) -> anyhow::Result<PathBuf> {
-let api = Api::new().unwrap();
+pub fn download_dataset(repo_name: String, filename: String, hf_token: Option<String>) -> anyhow::Result<PathBuf> {
+let api = ApiBuilder::new().with_token(hf_token).build()?;
let repo = api.dataset(repo_name);
let dataset = repo.get(&filename)?;
Ok(dataset)
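With `ApiBuilder::with_token`, the same download path works for gated or private datasets. A minimal standalone sketch of the call, assuming the `hf_hub` sync API; the repo and file names are the defaults from `src/main.rs`:

```rust
use hf_hub::api::sync::ApiBuilder;

fn main() -> anyhow::Result<()> {
    // None keeps the request anonymous; Some(token) authenticates it.
    let token = std::env::var("HF_TOKEN").ok();
    let api = ApiBuilder::new().with_token(token).build()?;
    let repo = api.dataset("hlarcher/share_gpt_small".to_string());
    let path = repo.get("share_gpt_filtered_small.json")?;
    println!("downloaded to {}", path.display());
    Ok(())
}
```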
