feat: add prompt/decode sampling options

Hugoch committed Sep 13, 2024
1 parent 959dc0d commit e24240a

Showing 4 changed files with 237 additions and 113 deletions.
7 changes: 3 additions & 4 deletions src/benchmark.rs
@@ -6,7 +6,7 @@ use sysinfo::{CpuRefreshKind, MemoryRefreshKind, System};
 use tokio::fs;
 use tokio::sync::{broadcast, mpsc, Mutex};
 use tokio::sync::mpsc::{Receiver, Sender};
-use crate::requests::{TextGenerationBackend, TextRequestGenerator};
+use crate::requests::{TextGenerationBackend, TextRequestGenerator, TokenizeOptions};
 use crate::{executors, scheduler};
 use crate::results::{BenchmarkReport, BenchmarkResults};
 use crate::scheduler::{ExecutorType, SchedulerProgress};
@@ -67,9 +67,8 @@ pub struct BenchmarkConfig {
     pub warmup_duration: Duration,
     pub rate: Option<f64>,
     pub num_rates: u64,
-    pub prompt_length: u64,
-    pub prompt_variance: u64,
-    pub decode_length: u64,
+    pub prompt_options: Option<TokenizeOptions>,
+    pub decode_options: Option<TokenizeOptions>,
 }

 impl BenchmarkConfig {
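The TokenizeOptions type used above lives in src/requests.rs, the fourth changed file, whose diff did not load on this page. Its real definition is therefore not visible here, but the fields read by parse_tokenizer_options in src/main.rs below (num_tokens, min_tokens, max_tokens, variance, all handled as u64) and the TokenizeOptions::new() call suggest a shape roughly like this hypothetical reconstruction:

// Hypothetical reconstruction; the real definition is in src/requests.rs,
// whose diff did not load above.
#[derive(Clone, Debug)]
pub struct TokenizeOptions {
    pub num_tokens: u64, // target (mean) token count
    pub min_tokens: u64, // lower bound on the sampled count
    pub max_tokens: u64, // upper bound on the sampled count
    pub variance: u64,   // variance of the normal distribution
}

impl TokenizeOptions {
    pub fn new() -> Self {
        // Zero defaults are consistent with parse_tokenizer_options below,
        // which rejects options whose num_tokens, min_tokens or max_tokens
        // are still 0 after parsing.
        Self { num_tokens: 0, min_tokens: 0, max_tokens: 0, variance: 0 }
    }
}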
55 changes: 29 additions & 26 deletions src/lib.rs
@@ -11,6 +11,7 @@ pub use crate::app::run_console;
 use crate::benchmark::{BenchmarkReportWriter, Event, MessageEvent};
 pub use crate::benchmark::{BenchmarkConfig, BenchmarkKind};
 use crate::requests::{OpenAITextGenerationBackend};
+pub use crate::requests::{TokenizeOptions};

 mod requests;
 mod executors;
@@ -22,45 +23,47 @@ mod app;
 mod event;
 mod flux;

-pub async fn run(url: String,
-                 tokenizer_name: String,
-                 prompt_length: u64,
-                 prompt_variance: u64,
-                 decode_length: u64,
-                 max_vus: u64,
-                 duration: std::time::Duration,
-                 rate: Option<f64>,
-                 num_rates: u64,
-                 benchmark_kind: String,
-                 prewarm_duration: std::time::Duration,
-                 interactive: bool,
+pub struct RunConfiguration {
+    pub url: String,
+    pub tokenizer_name: String,
+    pub max_vus: u64,
+    pub duration: std::time::Duration,
+    pub rate: Option<f64>,
+    pub num_rates: u64,
+    pub benchmark_kind: String,
+    pub warmup_duration: std::time::Duration,
+    pub interactive: bool,
+    pub prompt_options: Option<TokenizeOptions>,
+    pub decode_options: Option<TokenizeOptions>,
+}
+
+pub async fn run(run_config: RunConfiguration,
                  stop_sender: Sender<()>,
 ) -> anyhow::Result<()> {
     info!("Starting benchmark");
     // set process system limits
     sysinfo::set_open_files_limit(0);
     // let backend = OpenAITextGenerationBackend::new("".to_string(), "http://10.90.11.68:8000".to_string());
-    let backend = OpenAITextGenerationBackend::new("".to_string(), url, tokenizer_name.clone());
+    let backend = OpenAITextGenerationBackend::new("".to_string(), run_config.url.clone(), run_config.tokenizer_name.clone());

     let config = BenchmarkConfig {
-        max_vus,
-        duration,
-        benchmark_kind: match benchmark_kind.to_lowercase().as_str() {
+        max_vus: run_config.max_vus,
+        duration: run_config.duration,
+        benchmark_kind: match run_config.benchmark_kind.to_lowercase().as_str() {
             "throughput" => BenchmarkKind::Throughput,
             "sweep" => BenchmarkKind::Sweep,
             "rate" => BenchmarkKind::Rate,
             _ => BenchmarkKind::Sweep,
         },
-        warmup_duration: prewarm_duration,
-        rate,
-        num_rates,
-        prompt_length,
-        prompt_variance,
-        decode_length,
+        warmup_duration: run_config.warmup_duration,
+        rate: run_config.rate,
+        num_rates: run_config.num_rates,
+        prompt_options: run_config.prompt_options.clone(),
+        decode_options: run_config.decode_options.clone(),
     };
     config.validate()?;
     let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
-    if interactive {
+    if run_config.interactive {
         // send logs to file
         let target = Box::new(File::create("log.txt").expect("Can't create file"));
         env_logger::Builder::new()
@@ -90,7 +93,7 @@ pub async fn run(url: String,
             error!("Received stop signal, stopping benchmark");
         }
         _ = async{
-            if interactive {
+            if run_config.interactive {
                 run_console(config_clone, rx, stop_sender_clone).await;
             } else {
                 // consume the channel to avoid closed channel error
@@ -107,8 +110,8 @@ pub async fn run(url: String,
         timestamp: chrono::Utc::now(),
         level: Level::Info,
     }));
-    let filepath = requests::ShareGPTTextRequestGenerator::download_dataset("hlarcher/share_gpt_small".to_string(), "share_gpt_filtered_small.json".to_string()).expect("Can't download dataset");
-    let requests = requests::ShareGPTTextRequestGenerator::new(filepath, tokenizer_name, prompt_length, 1, prompt_length * 2, prompt_variance);
+    let filepath = requests::ConversationTextRequestGenerator::download_dataset("hlarcher/share_gpt_small".to_string(), "share_gpt_filtered_small.json".to_string()).expect("Can't download dataset");
+    let requests = requests::ConversationTextRequestGenerator::new(filepath, run_config.tokenizer_name.clone(), run_config.prompt_options.clone(), run_config.decode_options.clone());

     let mut benchmark = benchmark::Benchmark::new(config.clone(), Box::new(backend), Arc::from(Mutex::from(requests)), tx.clone(), stop_sender.clone());
     let mut stop_receiver = stop_sender.subscribe();
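The option docs in src/main.rs below say that token counts are sampled from a normal distribution and bounded by min_tokens/max_tokens. The sampling itself happens inside ConversationTextRequestGenerator in src/requests.rs, which did not load here, so the following is only an illustration of that logic (the function name and the rand/rand_distr dependency are assumptions, not the commit's code):

use rand_distr::{Distribution, Normal};

// Illustration only: draw a token count from N(num_tokens, variance) and
// clamp it to [min_tokens, max_tokens].
fn sample_num_tokens(opts: &TokenizeOptions) -> u64 {
    // Normal::new takes a standard deviation, so take the square root of the variance.
    let std_dev = (opts.variance as f64).sqrt();
    let normal = Normal::new(opts.num_tokens as f64, std_dev).expect("invalid std dev");
    let sampled = normal.sample(&mut rand::thread_rng()).round() as i64;
    sampled.clamp(opts.min_tokens as i64, opts.max_tokens as i64) as u64
}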
92 changes: 70 additions & 22 deletions src/main.rs
@@ -4,7 +4,7 @@ use clap::error::ErrorKind::InvalidValue;
 use log::error;
 use reqwest::Url;
 use tokio::sync::broadcast;
-use text_generation_inference_benchmark::{run};
+use text_generation_inference_benchmark::{run, RunConfiguration, TokenizeOptions};

 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -40,15 +40,39 @@ struct Args {
     /// Disable console UI
     #[clap(short, long, env)]
     no_console: bool,
-    /// Prompt token length
-    #[clap(default_value = "50", long, env)]
-    prompt_length: u64,
-    /// Variance of prompt token length following a normal distribution
-    #[clap(default_value = "10", long, env)]
-    prompt_variance: u64,
-    /// Decode token length (number of tokens to generate)
-    #[clap(default_value = "10", long, env)]
-    decode_length: u64,
+    /// Constraints for prompt length.
+    /// No value means the prompt is used as defined in the input dataset.
+    /// We sample the number of prompt tokens from a normal distribution.
+    /// Specified as a comma-separated list of key=value pairs.
+    /// * num_tokens: target number of prompt tokens
+    /// * min_tokens: minimum number of prompt tokens
+    /// * max_tokens: maximum number of prompt tokens
+    /// * variance: variance in the number of prompt tokens
+    ///
+    /// Example: num_tokens=50,max_tokens=60,min_tokens=40,variance=10
+    #[clap(
+        default_value = "num_tokens=50,max_tokens=60,min_tokens=40,variance=10",
+        long,
+        env,
+        value_parser(parse_tokenizer_options)
+    )]
+    prompt_options: Option<TokenizeOptions>,
+    /// Constraints for the generated text.
+    /// We sample the number of tokens to generate from a normal distribution.
+    /// Specified as a comma-separated list of key=value pairs.
+    /// * num_tokens: target number of generated tokens
+    /// * min_tokens: minimum number of generated tokens
+    /// * max_tokens: maximum number of generated tokens
+    /// * variance: variance in the number of generated tokens
+    ///
+    /// Example: num_tokens=50,max_tokens=60,min_tokens=40,variance=10
+    #[clap(
+        default_value = "num_tokens=50,max_tokens=60,min_tokens=40,variance=10",
+        long,
+        env,
+        value_parser(parse_tokenizer_options)
+    )]
+    decode_options: Option<TokenizeOptions>,
 }

 fn parse_duration(s: &str) -> Result<Duration, Error> {
Expand All @@ -62,6 +86,28 @@ fn parse_url(s: &str) -> Result<String, Error> {
}
}

fn parse_tokenizer_options(s: &str) -> Result<TokenizeOptions, Error> {
let mut tokenizer_options = TokenizeOptions::new();
let items = s.split(",").collect::<Vec<&str>>();
for item in items.iter() {
let key_value = item.split("=").collect::<Vec<&str>>();
if key_value.len() != 2 {
return Err(Error::new(InvalidValue));
}
match key_value[0] {
"num_tokens" => tokenizer_options.num_tokens = key_value[1].parse::<u64>().unwrap(),
"min_tokens" => tokenizer_options.min_tokens = key_value[1].parse::<u64>().unwrap(),
"max_tokens" => tokenizer_options.max_tokens = key_value[1].parse::<u64>().unwrap(),
"variance" => tokenizer_options.variance = key_value[1].parse::<u64>().unwrap(),
_ => return Err(Error::new(InvalidValue)),
}
};
if tokenizer_options.num_tokens == 0 || tokenizer_options.min_tokens == 0 || tokenizer_options.max_tokens == 0 || tokenizer_options.min_tokens > tokenizer_options.max_tokens {
return Err(Error::new(InvalidValue));
}
Ok(tokenizer_options)
}

#[tokio::main]
async fn main() {
let args = Args::parse();
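Note that parse_tokenizer_options parses numeric values with unwrap(), so a malformed number such as num_tokens=abc panics rather than surfacing InvalidValue; only missing key=value pairs, unknown keys, zero values, and min_tokens > max_tokens are reported as parse errors. Assuming clap's default kebab-case naming, the new fields are exposed as --prompt-options and --decode-options. A quick sanity check of the parser (a hypothetical test, not part of this commit) could look like:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_valid_tokenizer_options() {
        // Keys may appear in any order; all four are set here.
        let opts = parse_tokenizer_options("num_tokens=50,max_tokens=60,min_tokens=40,variance=10").unwrap();
        assert_eq!(opts.num_tokens, 50);
        assert_eq!(opts.min_tokens, 40);
        assert_eq!(opts.max_tokens, 60);
        assert_eq!(opts.variance, 10);
    }

    #[test]
    fn rejects_min_greater_than_max() {
        assert!(parse_tokenizer_options("num_tokens=50,max_tokens=40,min_tokens=60,variance=10").is_err());
    }
}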
Expand All @@ -76,19 +122,21 @@ async fn main() {
});

let stop_sender_clone = stop_sender.clone();
let run_config = RunConfiguration {
url: args.url.clone(),
tokenizer_name: args.tokenizer_name.clone(),
max_vus: args.max_vus,
duration: args.duration,
rate: args.rate,
num_rates: args.num_rates,
benchmark_kind: args.benchmark_kind.clone(),
warmup_duration: args.warmup,
interactive: !args.no_console,
prompt_options: args.prompt_options.clone(),
decode_options: args.decode_options.clone(),
};
let main_thread = tokio::spawn(async move {
match run(args.url,
args.tokenizer_name,
args.prompt_length,
args.prompt_variance,
args.decode_length,
args.max_vus,
args.duration,
args.rate,
args.num_rates,
args.benchmark_kind,
args.warmup,
!args.no_console,
match run(run_config,
stop_sender_clone,
).await {
Ok(_) => {}
Expand Down
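Because run() now takes a single RunConfiguration plus a broadcast stop channel, the benchmark is easier to drive from another crate. A minimal, hypothetical embedding under those assumptions (the URL, tokenizer name, and the anyhow dependency are placeholders, not from this commit):

use text_generation_inference_benchmark::{run, RunConfiguration, TokenizeOptions};
use tokio::sync::broadcast;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Stop channel: sending () asks the benchmark to shut down early.
    let (stop_sender, _stop_receiver) = broadcast::channel(1);
    let run_config = RunConfiguration {
        url: "http://localhost:8000".to_string(),
        tokenizer_name: "gpt2".to_string(),
        max_vus: 32,
        duration: std::time::Duration::from_secs(60),
        rate: None,
        num_rates: 10,
        benchmark_kind: "sweep".to_string(),
        warmup_duration: std::time::Duration::from_secs(10),
        interactive: false,
        prompt_options: None, // None: use dataset prompts as-is
        decode_options: None,
    };
    run(run_config, stop_sender).await
}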
(The diff for the fourth changed file did not load.)
