From ce707579e5a7c97ca58379f18a9d0e9e7ad9f961 Mon Sep 17 00:00:00 2001
From: Hugo Larcher
Date: Fri, 20 Sep 2024 11:48:10 +0200
Subject: [PATCH] feat: Results filename now includes date and model name. Fix
 quit in non-UI mode

---
 src/lib.rs      |  9 ++++++---
 src/requests.rs | 14 +++++++++++---
 src/writers.rs  |  9 +++++----
 3 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 298a05d..506927e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -114,8 +114,8 @@ pub async fn run(run_config: RunConfiguration,
         timestamp: chrono::Utc::now(),
         level: Level::Info,
     }));
-    let filepath = requests::ConversationTextRequestGenerator::download_dataset(run_config.dataset, run_config.dataset_file,run_config.hf_token.clone()).expect("Can't download dataset");
-    let requests = requests::ConversationTextRequestGenerator::load(filepath, run_config.tokenizer_name, run_config.prompt_options, run_config.decode_options, run_config.hf_token)?;
+    let filepath = requests::ConversationTextRequestGenerator::download_dataset(run_config.dataset, run_config.dataset_file, run_config.hf_token.clone()).expect("Can't download dataset");
+    let requests = requests::ConversationTextRequestGenerator::load(filepath, run_config.tokenizer_name.clone(), run_config.prompt_options, run_config.decode_options, run_config.hf_token)?;

     let mut benchmark = benchmark::Benchmark::new(config.clone(), Box::new(backend), Arc::from(Mutex::from(requests)), tx.clone(), stop_sender.clone());
     let mut stop_receiver = stop_sender.subscribe();
@@ -125,7 +125,7 @@ pub async fn run(run_config: RunConfiguration,
         Ok(results) => {
             info!("Throughput is {requests_throughput} req/s",requests_throughput = results.get_results()[0].successful_request_rate().unwrap());
             let report = benchmark.get_report();
-            let path = "results/".to_string();
+            let path = format!("results/{}_{}.json", chrono::Utc::now().format("%Y-%m-%d-%H-%M-%S"), run_config.tokenizer_name.replace("/", "_"));
             let writer=BenchmarkReportWriter::new(config.clone(), report)?;
             writer.json(&path).await?;
         },
@@ -140,6 +140,9 @@ pub async fn run(run_config: RunConfiguration,
         }
     }
     let _ = tx.send(Event::BenchmarkReportEnd);
+    if !run_config.interactive { // quit the app when not running in interactive (UI) mode
+        let _ = stop_sender.send(());
+    }
     ui_thread.await.unwrap();
     Ok(())
 }
diff --git a/src/requests.rs b/src/requests.rs
index 20c61d2..9f2e794 100644
--- a/src/requests.rs
+++ b/src/requests.rs
@@ -144,8 +144,16 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
                         es.close();
                         break;
                     }
-                    // deserialize message data FIXME: handle JSON errors
-                    let oai_response: OpenAITextGenerationResponse = serde_json::from_str(&message.data).unwrap();
+                    // deserialize message data, failing the request instead of panicking on invalid JSON
+                    let oai_response: OpenAITextGenerationResponse = match serde_json::from_str(&message.data) {
+                        Ok(response) => response,
+                        Err(e) => {
+                            error!("Error deserializing OpenAI API response: {e}");
+                            aggregated_response.fail();
+                            es.close();
+                            break;
+                        }
+                    };
                     let choices = oai_response.choices;
                     match choices[0].clone().finish_reason {
                         None => {
@@ -363,7 +371,7 @@ fn tokenize_prompt(prompt: String, tokenizer: Arc<Tokenizer>, num_tokens: Option<u64>
         }
         Some(num_tokens) => {
             if prompt_tokens.len() < num_tokens as usize {
-                return Err(anyhow::anyhow!("Prompt is too short to tokenize"));
+                return Err(anyhow::anyhow!("Prompt is too short to tokenize: {} < {}", prompt_tokens.len(), num_tokens));
             }
             // let's do a binary search to find the right number of tokens
             let mut low = 1;
diff --git a/src/writers.rs b/src/writers.rs
index 6b2f260..7da7303 100644
--- a/src/writers.rs
+++ b/src/writers.rs
@@ -116,7 +116,8 @@ impl BenchmarkReportWriter {
-        if !std::path::Path::new(&path).exists() {
-            fs::create_dir_all(&path).await?;
-        }
-        fs::write(format!("{}/results.json", path), report).await?;
+        // `path` is now the results file itself, so create its parent directory (e.g. "results/"),
+        // not a directory named after the file
+        if let Some(parent) = std::path::Path::new(&path).parent() {
+            fs::create_dir_all(parent).await?;
+        }
+        fs::write(path, report).await?;
         Ok(())
     }
 }
\ No newline at end of file
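-- 
Note (not part of the commit): a minimal sketch of how the new results path is
built, for reference; the tokenizer name below is hypothetical:

    // assuming run_config.tokenizer_name == "meta-llama/Llama-3.1-8B"
    let path = format!(
        "results/{}_{}.json",
        chrono::Utc::now().format("%Y-%m-%d-%H-%M-%S"),
        run_config.tokenizer_name.replace("/", "_"),
    );
    // yields e.g. "results/2024-09-20-11-48-10_meta-llama_Llama-3.1-8B.json"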