From dc49578eec6dadc51170c8913410b6e01918cca1 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sat, 13 Jan 2024 22:59:05 +0100 Subject: [PATCH] debug more switching code completion models --- src/forward_to_openai_endpoint.rs | 5 ++++- src/global_context.rs | 3 ++- src/http/routers/v1/code_completion.rs | 19 ++++++++++++------- src/restream.rs | 10 ++++++++-- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/forward_to_openai_endpoint.rs b/src/forward_to_openai_endpoint.rs index b62e1cad5..aae1a50de 100644 --- a/src/forward_to_openai_endpoint.rs +++ b/src/forward_to_openai_endpoint.rs @@ -6,6 +6,7 @@ use reqwest_eventsource::EventSource; use serde_json::json; use crate::call_validation; use crate::call_validation::SamplingParameters; +use tracing::info; pub async fn forward_to_openai_style_endpoint( @@ -50,12 +51,14 @@ pub async fn forward_to_openai_style_endpoint( let response_txt = resp.text().await.map_err(|e| format!("reading from socket {}: {}", url, e) )?; - // info!("forward_to_openai_style_endpoint: {} {}\n{}", url, status_code, response_txt); // 400 "client error" is likely a json that we rather accept here, pick up error details as we analyse json fields at the level // higher, the most often 400 is no such model. if status_code != 200 && status_code != 400 { return Err(format!("{} status={} text {}", url, status_code, response_txt)); } + if status_code != 200 { + info!("forward_to_openai_style_endpoint: {} {}\n{}", url, status_code, response_txt); + } Ok(serde_json::from_str(&response_txt).unwrap()) } diff --git a/src/global_context.rs b/src/global_context.rs index cdd776419..4db7d588f 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -69,7 +69,8 @@ pub struct GlobalContext { pub lsp_backend_document_state: LSPBackendDocumentState, } -pub type SharedGlobalContext = Arc>; +pub type SharedGlobalContext = Arc>; // TODO: remove this type alias, confusing + const CAPS_RELOAD_BACKOFF: u64 = 60; // seconds const CAPS_BACKGROUND_RELOAD: u64 = 3600; // seconds diff --git a/src/http/routers/v1/code_completion.rs b/src/http/routers/v1/code_completion.rs index 1210945a2..1e572290e 100644 --- a/src/http/routers/v1/code_completion.rs +++ b/src/http/routers/v1/code_completion.rs @@ -1,5 +1,6 @@ use std::sync::Arc; use std::sync::RwLock as StdRwLock; +use tokio::sync::RwLock as ARwLock; use axum::Extension; use axum::response::Result; @@ -11,7 +12,7 @@ use crate::caps; use crate::caps::CodeAssistantCaps; use crate::completion_cache; use crate::custom_error::ScratchError; -use crate::global_context::SharedGlobalContext; +use crate::global_context::GlobalContext; use crate::scratchpads; async fn _lookup_code_completion_scratchpad( @@ -36,16 +37,20 @@ async fn _lookup_code_completion_scratchpad( } pub async fn handle_v1_code_completion( - global_context: SharedGlobalContext, + global_context: Arc>, code_completion_post: &mut CodeCompletionPost, ) -> Result, ScratchError> { let caps = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0).await?; - let (model_name, scratchpad_name, scratchpad_patch, n_ctx) = _lookup_code_completion_scratchpad( + let maybe = _lookup_code_completion_scratchpad( caps.clone(), &code_completion_post, - ).await.map_err(|e| { - ScratchError::new(StatusCode::BAD_REQUEST, format!("{}", e)) - })?; + ).await; + if maybe.is_err() { + // On error, this will also invalidate caps each 10 seconds, allows to overcome empty caps situation + let _ = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 10).await; + return Err(ScratchError::new(StatusCode::BAD_REQUEST, format!("{}", maybe.unwrap_err()))) + } + let (model_name, scratchpad_name, scratchpad_patch, n_ctx) = maybe.unwrap(); if code_completion_post.parameters.max_new_tokens == 0 { code_completion_post.parameters.max_new_tokens = 50; } @@ -102,7 +107,7 @@ pub async fn handle_v1_code_completion( } pub async fn handle_v1_code_completion_web( - Extension(global_context): Extension, + Extension(global_context): Extension>>, body_bytes: hyper::body::Bytes, ) -> Result, ScratchError> { let mut code_completion_post = serde_json::from_slice::(&body_bytes).map_err(|e| diff --git a/src/restream.rs b/src/restream.rs index cfa236acc..284f44e98 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -149,7 +149,7 @@ pub async fn scratchpad_interaction_not_stream( } else { return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, - format!("unrecognized response: {:?}", model_says)) + format!("unrecognized response (1): {:?}", model_says)) ); } @@ -365,8 +365,14 @@ fn _push_streaming_json_into_scratchpad( } value["model"] = json!(model_name.clone()); Ok(value) + } else if let Some(err) = json.get("error") { + Err(format!("{}", err)) + } else if let Some(msg) = json.get("human_readable_message") { + Err(format!("{}", msg)) + } else if let Some(msg) = json.get("detail") { + Err(format!("{}", msg)) } else { - Err(format!("unrecognized response: {:?}", json)) + Err(format!("unrecognized response (2): {:?}", json)) } }