Skip to content

Commit

Permalink
feat: Results filename now has date and model name. Fix quit in case …
Browse files Browse the repository at this point in the history
…of no UI mode
  • Loading branch information
Hugoch committed Sep 20, 2024
1 parent 5ddfeab commit ce70757
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
9 changes: 6 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ pub async fn run(run_config: RunConfiguration,
timestamp: chrono::Utc::now(),
level: Level::Info,
}));
let filepath = requests::ConversationTextRequestGenerator::download_dataset(run_config.dataset, run_config.dataset_file,run_config.hf_token.clone()).expect("Can't download dataset");
let requests = requests::ConversationTextRequestGenerator::load(filepath, run_config.tokenizer_name, run_config.prompt_options, run_config.decode_options, run_config.hf_token)?;
let filepath = requests::ConversationTextRequestGenerator::download_dataset(run_config.dataset, run_config.dataset_file, run_config.hf_token.clone()).expect("Can't download dataset");
let requests = requests::ConversationTextRequestGenerator::load(filepath, run_config.tokenizer_name.clone(), run_config.prompt_options, run_config.decode_options, run_config.hf_token)?;

let mut benchmark = benchmark::Benchmark::new(config.clone(), Box::new(backend), Arc::from(Mutex::from(requests)), tx.clone(), stop_sender.clone());
let mut stop_receiver = stop_sender.subscribe();
Expand All @@ -125,7 +125,7 @@ pub async fn run(run_config: RunConfiguration,
Ok(results) => {
info!("Throughput is {requests_throughput} req/s",requests_throughput = results.get_results()[0].successful_request_rate().unwrap());
let report = benchmark.get_report();
let path = "results/".to_string();
let path = format!("results/{}_{}.json", chrono::Utc::now().format("%Y-%m-%d-%H-%M-%S"),run_config.tokenizer_name.replace("/","_"));
let writer=BenchmarkReportWriter::new(config.clone(), report)?;
writer.json(&path).await?;
},
Expand All @@ -140,6 +140,9 @@ pub async fn run(run_config: RunConfiguration,
}
}
let _ = tx.send(Event::BenchmarkReportEnd);
if !run_config.interactive { // quit app if not interactive
let _ = stop_sender.send(());
}
ui_thread.await.unwrap();
Ok(())
}
14 changes: 11 additions & 3 deletions src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,16 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
es.close();
break;
}
// deserialize message data FIXME: handle JSON errors
let oai_response: OpenAITextGenerationResponse = serde_json::from_str(&message.data).unwrap();
// deserialize message data
let oai_response: OpenAITextGenerationResponse = match serde_json::from_str(&message.data) {
Ok(response) => response,
Err(e) => {
error!("Error deserializing OpenAI API response: {e}", e = e);
aggregated_response.fail();
es.close();
break;
}
};
let choices = oai_response.choices;
match choices[0].clone().finish_reason {
None => {
Expand Down Expand Up @@ -363,7 +371,7 @@ fn tokenize_prompt(prompt: String, tokenizer: Arc<Tokenizer>, num_tokens: Option
}
Some(num_tokens) => {
if prompt_tokens.len() < num_tokens as usize {
return Err(anyhow::anyhow!("Prompt is too short to tokenize"));
return Err(anyhow::anyhow!(format!("Prompt is too short to tokenize: {}<{}", prompt_tokens.len(), num_tokens)));
}
// let's do a binary search to find the right number of tokens
let mut low = 1;
Expand Down
2 changes: 1 addition & 1 deletion src/writers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl BenchmarkReportWriter {
if !std::path::Path::new(&path).exists() {
fs::create_dir_all(&path).await?;
}
fs::write(format!("{}/results.json", path), report).await?;
fs::write(path, report).await?;
Ok(())
}
}

0 comments on commit ce70757

Please sign in to comment.