From f1ef0ee66bda906329878043789fe757cfc8ccd3 Mon Sep 17 00:00:00 2001 From: Magdy Saleh <17618143+magdyksaleh@users.noreply.github.com> Date: Tue, 8 Oct 2024 15:09:00 -0400 Subject: [PATCH] pass correct stuff to predibase-reporter (#635) --- router/src/server.rs | 53 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index c2bbf9acb..7a04a6d6f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -75,7 +75,9 @@ async fn compat_generate( default_return_full_text: Extension, infer: Extension, info: Extension, - request_logger_sender: Extension>>, + request_logger_sender: Extension< + Arc>, + >, req_headers: HeaderMap, req: Json, ) -> Result)> { @@ -140,7 +142,9 @@ async fn completions_v1( default_return_full_text: Extension, infer: Extension, info: Extension, - request_logger_sender: Extension>>, + request_logger_sender: Extension< + Arc>, + >, req_headers: HeaderMap, req: Json, ) -> Result)> { @@ -223,7 +227,9 @@ async fn chat_completions_v1( default_return_full_text: Extension, infer: Extension, info: Extension, - request_logger_sender: Extension>>, + request_logger_sender: Extension< + Arc>, + >, req_headers: HeaderMap, req: Json, ) -> Result)> { @@ -473,7 +479,9 @@ seed, async fn generate( infer: Extension, info: Extension, - request_logger_sender: Extension>>, + request_logger_sender: Extension< + Arc>, + >, req_headers: HeaderMap, mut req: Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { @@ -489,6 +497,8 @@ async fn generate( add_prompt = Some(req.0.inputs.clone()); } + let inputs = req.0.inputs.clone(); + let details = req.0.parameters.details || req.0.parameters.decoder_input_details; let (adapter_source, adapter_parameters) = extract_adapter_params( req.0.parameters.adapter_id.clone(), @@ -662,6 +672,9 @@ async fn generate( let _ = request_logger_sender .send(( total_tokens as i64, + adapter_id_string, + inputs, + response.generated_text.text.clone(), api_token.unwrap_or("".to_string()), info.model_id.clone(), )) @@ -722,7 +735,9 @@ seed, async fn generate_stream( infer: Extension, info: Extension, - request_logger_sender: Extension>>, + request_logger_sender: Extension< + Arc>, + >, req_headers: HeaderMap, req: Json, ) -> ( @@ -745,7 +760,9 @@ async fn generate_stream( async fn generate_stream_with_callback( infer: Extension, info: Extension, - request_logger_sender: Extension>>, + request_logger_sender: Extension< + Arc>, + >, req_headers: HeaderMap, mut req: Json, callback: impl Fn(StreamResponse) -> Event, @@ -813,6 +830,7 @@ async fn generate_stream_with_callback( if req.0.parameters.return_full_text.unwrap_or(false) { add_prompt = Some(req.0.inputs.clone()); } + let inputs = req.0.inputs.clone(); let details = req.0.parameters.details; let best_of = req.0.parameters.best_of.unwrap_or(1); @@ -914,7 +932,15 @@ async fn generate_stream_with_callback( let total_tokens = generated_text.generated_tokens + prefill_tokens_length; if info.request_logger_url.is_some() { - let _ = request_logger_sender.send((total_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await; + let _ = request_logger_sender.send(( + total_tokens as i64, + adapter_id_string, + inputs, + output_text.clone(), + api_token.unwrap_or("".to_string()), + info.model_id.clone(), + )) + .await; } let stream_token = StreamResponse { @@ -991,7 +1017,7 @@ async fn metrics(prom_handle: Extension) -> String { async fn request_logger( request_logger_url: Option, - mut rx: mpsc::Receiver<(i64, String, String)>, + mut rx: mpsc::Receiver<(i64, String, String, String, String, String)>, ) { if request_logger_url.is_none() { tracing::info!("REQUEST_LOGGER_URL not set, request logging is disabled"); @@ -1005,11 +1031,18 @@ async fn request_logger( let client = ClientBuilder::new(reqwest::Client::new()) .with(RetryTransientMiddleware::new_with_policy(retry_policy)) .build(); - while let Some((tokens, api_token, model_id)) = rx.recv().await { + while let Some((tokens, adapter_id, input, output, api_token, model_id)) = rx.recv().await { // Make a request out to localhost:8899 with the tokens, api_token, and model_id let res = client .post(&url_string) - .json(&json!({"tokens": tokens, "api_token": api_token, "model_id": model_id})) + .json(&json!({ + "tokens": tokens, + "adapter_id": adapter_id, + "input": input, + "output": output, + "api_token": api_token, + "model_id": model_id + })) .send() .await;