Skip to content

Commit

Permalink
pass correct stuff to predibase-reporter (#635)
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh authored Oct 8, 2024
1 parent 0c1cec2 commit f1ef0ee
Showing 1 changed file with 43 additions and 10 deletions.
53 changes: 43 additions & 10 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ async fn compat_generate(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
info: Extension<Info>,
request_logger_sender: Extension<Arc<mpsc::Sender<(i64, String, String)>>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
req_headers: HeaderMap,
req: Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
Expand Down Expand Up @@ -140,7 +142,9 @@ async fn completions_v1(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
info: Extension<Info>,
request_logger_sender: Extension<Arc<mpsc::Sender<(i64, String, String)>>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
req_headers: HeaderMap,
req: Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
Expand Down Expand Up @@ -223,7 +227,9 @@ async fn chat_completions_v1(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
info: Extension<Info>,
request_logger_sender: Extension<Arc<mpsc::Sender<(i64, String, String)>>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
req_headers: HeaderMap,
req: Json<ChatCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
Expand Down Expand Up @@ -473,7 +479,9 @@ seed,
async fn generate(
infer: Extension<Infer>,
info: Extension<Info>,
request_logger_sender: Extension<Arc<mpsc::Sender<(i64, String, String)>>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
req_headers: HeaderMap,
mut req: Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
Expand All @@ -489,6 +497,8 @@ async fn generate(
add_prompt = Some(req.0.inputs.clone());
}

let inputs = req.0.inputs.clone();

let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
let (adapter_source, adapter_parameters) = extract_adapter_params(
req.0.parameters.adapter_id.clone(),
Expand Down Expand Up @@ -662,6 +672,9 @@ async fn generate(
let _ = request_logger_sender
.send((
total_tokens as i64,
adapter_id_string,
inputs,
response.generated_text.text.clone(),
api_token.unwrap_or("".to_string()),
info.model_id.clone(),
))
Expand Down Expand Up @@ -722,7 +735,9 @@ seed,
async fn generate_stream(
infer: Extension<Infer>,
info: Extension<Info>,
request_logger_sender: Extension<Arc<mpsc::Sender<(i64, String, String)>>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
req_headers: HeaderMap,
req: Json<GenerateRequest>,
) -> (
Expand All @@ -745,7 +760,9 @@ async fn generate_stream(
async fn generate_stream_with_callback(
infer: Extension<Infer>,
info: Extension<Info>,
request_logger_sender: Extension<Arc<mpsc::Sender<(i64, String, String)>>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
req_headers: HeaderMap,
mut req: Json<GenerateRequest>,
callback: impl Fn(StreamResponse) -> Event,
Expand Down Expand Up @@ -813,6 +830,7 @@ async fn generate_stream_with_callback(
if req.0.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.0.inputs.clone());
}
let inputs = req.0.inputs.clone();
let details = req.0.parameters.details;

let best_of = req.0.parameters.best_of.unwrap_or(1);
Expand Down Expand Up @@ -914,7 +932,15 @@ async fn generate_stream_with_callback(

let total_tokens = generated_text.generated_tokens + prefill_tokens_length;
if info.request_logger_url.is_some() {
let _ = request_logger_sender.send((total_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await;
let _ = request_logger_sender.send((
total_tokens as i64,
adapter_id_string,
inputs,
output_text.clone(),
api_token.unwrap_or("".to_string()),
info.model_id.clone(),
))
.await;
}

let stream_token = StreamResponse {
Expand Down Expand Up @@ -991,7 +1017,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {

async fn request_logger(
request_logger_url: Option<String>,
mut rx: mpsc::Receiver<(i64, String, String)>,
mut rx: mpsc::Receiver<(i64, String, String, String, String, String)>,
) {
if request_logger_url.is_none() {
tracing::info!("REQUEST_LOGGER_URL not set, request logging is disabled");
Expand All @@ -1005,11 +1031,18 @@ async fn request_logger(
let client = ClientBuilder::new(reqwest::Client::new())
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
while let Some((tokens, api_token, model_id)) = rx.recv().await {
while let Some((tokens, adapter_id, input, output, api_token, model_id)) = rx.recv().await {
// POST the request record (tokens, adapter_id, input, output, api_token, model_id) as JSON to the request logger endpoint (`url_string`)
let res = client
.post(&url_string)
.json(&json!({"tokens": tokens, "api_token": api_token, "model_id": model_id}))
.json(&json!({
"tokens": tokens,
"adapter_id": adapter_id,
"input": input,
"output": output,
"api_token": api_token,
"model_id": model_id
}))
.send()
.await;

Expand Down

0 comments on commit f1ef0ee

Please sign in to comment.