From d50647588f47ad9040e107417489aab4dd13ddae Mon Sep 17 00:00:00 2001 From: Valeryi Date: Thu, 4 Jan 2024 16:31:57 +0000 Subject: [PATCH] init --- src/restream.rs | 18 +++ src/scratchpad_abstract.rs | 2 + src/scratchpads/chat_generic.rs | 12 +- src/scratchpads/chat_llama2.rs | 12 +- src/scratchpads/chat_passthrough.rs | 9 +- src/scratchpads/chat_utils_rag.rs | 118 +++++++++++++++--- src/scratchpads/completion_single_file_fim.rs | 3 + src/vecdb/vecdb.rs | 4 +- 8 files changed, 155 insertions(+), 23 deletions(-) diff --git a/src/restream.rs b/src/restream.rs index dae5dffff..c7053a9e5 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -188,6 +188,24 @@ pub async fn scratchpad_interaction_stream( let mut slowdown_scoped = SlowdownScoped::new(slowdown_arc.clone()); slowdown_scoped.be_nice_slow_down().await; loop { + loop { + let value_maybe = scratch.response_spontaneous(); + if let Ok(value) = value_maybe { + if value == json!(null) { + break; + } + let value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap()); + info!("yield: {:?}", value_str); + yield Result::<_, String>::Ok(value_str); + } else { + let err_str = value_maybe.unwrap_err(); + error!("response_spontaneous error: {}", err_str); + let value_str = format!("data: {}\n\n", serde_json::to_string(&json!({"detail": err_str})).unwrap()); + yield Result::<_, String>::Ok(value_str); + } + break; + } + let event_source_maybe = if endpoint_style == "hf" { forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming( &mut save_url, diff --git a/src/scratchpad_abstract.rs b/src/scratchpad_abstract.rs index 109c31d6f..9d2ce9e08 100644 --- a/src/scratchpad_abstract.rs +++ b/src/scratchpad_abstract.rs @@ -31,6 +31,8 @@ pub trait ScratchpadAbstract: Send { stop_toks: bool, stop_length: bool, ) -> Result<(serde_json::Value, bool), String>; + + fn response_spontaneous(&mut self) -> Result; } diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index 4a94dcc7a..31c3f74a6 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -11,7 +11,7 @@ use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::scratchpads::chat_utils_rag::embed_vecdb_results; +use crate::scratchpads::chat_utils_rag::{embed_vecdb_results, HasVecdb, HasVecdbResults}; use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; @@ -27,6 +27,7 @@ pub struct GenericChatScratchpad { pub keyword_asst: String, pub default_system_message: String, pub vecdb_search: Arc>>, + pub has_vecdb_results: HasVecdbResults, } impl GenericChatScratchpad { @@ -44,7 +45,8 @@ impl GenericChatScratchpad { keyword_user: "".to_string(), keyword_asst: "".to_string(), default_system_message: "".to_string(), - vecdb_search + vecdb_search, + has_vecdb_results: HasVecdbResults::new(), } } } @@ -85,7 +87,7 @@ impl ScratchpadAbstract for GenericChatScratchpad< sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { match *self.vecdb_search.lock().await { - Some(ref db) => embed_vecdb_results(db, &mut self.post, 6).await, + Some(ref db) => embed_vecdb_results(db, &mut self.post, 6, &mut self.has_vecdb_results).await, None => {} } @@ -149,5 +151,9 @@ impl ScratchpadAbstract for GenericChatScratchpad< ) -> Result<(serde_json::Value, bool), String> { self.dd.response_streaming(delta, stop_toks) } + + fn response_spontaneous(&mut self) -> Result { + return self.has_vecdb_results.response_streaming(); + } } diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index 8899a8998..36a06c512 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -11,7 +11,7 @@ use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::scratchpads::chat_utils_rag::embed_vecdb_results; +use crate::scratchpads::chat_utils_rag::{embed_vecdb_results, HasVecdb, HasVecdbResults}; use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; @@ -26,6 +26,7 @@ pub struct ChatLlama2 { pub keyword_slash_s: String, pub default_system_message: String, pub vecdb_search: Arc>>, + pub has_vecdb_results: HasVecdbResults, } @@ -42,7 +43,8 @@ impl ChatLlama2 { keyword_s: "".to_string(), keyword_slash_s: "".to_string(), default_system_message: "".to_string(), - vecdb_search + vecdb_search, + has_vecdb_results: HasVecdbResults::new(), } } } @@ -71,7 +73,7 @@ impl ScratchpadAbstract for ChatLlama2 { sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { match *self.vecdb_search.lock().await { - Some(ref db) => embed_vecdb_results(db, &mut self.post, 6).await, + Some(ref db) => embed_vecdb_results(db, &mut self.post, 6, &mut self.has_vecdb_results).await, None => {} } let limited_msgs: Vec = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?; @@ -137,5 +139,9 @@ impl ScratchpadAbstract for ChatLlama2 { ) -> Result<(serde_json::Value, bool), String> { self.dd.response_streaming(delta, stop_toks) } + + fn response_spontaneous(&mut self) -> Result { + return self.has_vecdb_results.response_streaming(); + } } diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index eaf51de54..c6b7af461 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -7,7 +7,7 @@ use tracing::info; use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes; -use crate::scratchpads::chat_utils_rag::embed_vecdb_results; +use crate::scratchpads::chat_utils_rag::{embed_vecdb_results, HasVecdb, HasVecdbResults}; use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; @@ -19,6 +19,7 @@ pub struct ChatPassthrough { pub default_system_message: String, pub limit_bytes: usize, pub vecdb_search: Arc>>, + pub has_vecdb_results: HasVecdbResults, } const DEFAULT_LIMIT_BYTES: usize = 4096*6; @@ -33,6 +34,7 @@ impl ChatPassthrough { default_system_message: "".to_string(), limit_bytes: DEFAULT_LIMIT_BYTES, // one token translates to 3 bytes (not unicode chars) vecdb_search, + has_vecdb_results: HasVecdbResults::new(), } } } @@ -54,7 +56,7 @@ impl ScratchpadAbstract for ChatPassthrough { _sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { match *self.vecdb_search.lock().await { - Some(ref db) => embed_vecdb_results(db, &mut self.post, 6).await, + Some(ref db) => embed_vecdb_results(db, &mut self.post, 6, &mut self.has_vecdb_results).await, None => {} } let limited_msgs: Vec = limit_messages_history_in_bytes(&self.post, self.limit_bytes, &self.default_system_message)?; @@ -123,4 +125,7 @@ impl ScratchpadAbstract for ChatPassthrough { }); Ok((ans, finished)) } + fn response_spontaneous(&mut self) -> Result { + return self.has_vecdb_results.response_streaming(); + } } diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs index bc658f896..7f8b79867 100644 --- a/src/scratchpads/chat_utils_rag.rs +++ b/src/scratchpads/chat_utils_rag.rs @@ -1,29 +1,29 @@ +use async_trait::async_trait; +use serde_json::json; +use tracing::info; + use crate::call_validation::{ChatMessage, ChatPost, ContextFile}; use crate::vecdb::structs::{SearchResult, VecdbSearch}; + pub async fn embed_vecdb_results( vecdb: &T, post: &mut ChatPost, limit_examples_cnt: usize, + has_vecdb: &mut dyn HasVecdb, ) where T: VecdbSearch { let latest_msg_cont = &post.messages.last().unwrap().content; let vdb_resp = vecdb.search(latest_msg_cont.clone(), limit_examples_cnt).await; - let vdb_cont = vecdb_resp_to_prompt(&vdb_resp); - if vdb_cont.is_ok() { - post.messages = [ - &post.messages[..post.messages.len() - 1], - &[ChatMessage { - role: "context_file".to_string(), - content: vdb_cont.unwrap(), - }], - &post.messages[post.messages.len() - 1..], - ].concat(); - } + + has_vecdb.add2messages( + vdb_resp, + &mut post.messages, + ).await; } -fn vecdb_resp_to_prompt( +fn vecdb_resp_to_json( resp: &Result -) -> serde_json::Result { +) -> serde_json::Result { let context_files: Vec = match resp { Ok(search_res) => { search_res.results.iter().map( @@ -35,5 +35,95 @@ fn vecdb_resp_to_prompt( } Err(_) => vec![] }; - serde_json::to_string(&context_files) + serde_json::to_value(&context_files) +} + +fn vecdb_resp_to_prompt( + resp_mb: &Result +) -> String { + let mut cont = "".to_string(); + + if resp_mb.is_err() { + info!("VECDB ERR"); + return cont + } + let resp = resp_mb.as_ref().unwrap(); + cont.push_str("CONTEXT:\n"); + + for res in resp.results.iter() { + cont.push_str("FILENAME:\n"); + cont.push_str(res.file_path.clone().to_str().unwrap_or_else( || "")); + cont.push_str("\nTEXT:"); + cont.push_str(res.window_text.clone().as_str()); + cont.push_str("\n"); + } + cont.push_str("\nRefer to the context to answer my next question.\n"); + info!("VECDB prompt:\n{}", cont); + cont +} + + +pub struct HasVecdbResults { + pub was_sent: bool, + pub in_json: serde_json::Value, +} + +impl HasVecdbResults { + pub fn new() -> Self { + HasVecdbResults { + was_sent: false, + in_json: json!(null) + } + } +} + +#[async_trait] +pub trait HasVecdb: Send { + async fn add2messages( + &mut self, + vdb_result_mb: Result, + messages: &mut Vec, + ); + fn response_streaming(&mut self) -> Result; +} + +#[async_trait] +impl HasVecdb for HasVecdbResults { + async fn add2messages( + &mut self, + result_mb: Result, + messages: &mut Vec, + ) { + // if messages.len() > 1 { + // return; + // } + + *messages = [ + &messages[..messages.len() -1], + &[ChatMessage { + role: "user".to_string(), + content: vecdb_resp_to_prompt(&result_mb), + }], + &messages[messages.len() -1..], + ].concat(); + + self.in_json = vecdb_resp_to_json(&result_mb).unwrap_or_else(|_| json!(null)); + } + + fn response_streaming(&mut self) -> Result { + if self.was_sent == true || self.in_json.is_null() { + return Ok(json!(null)); + } + self.was_sent = true; + return Ok(json!({ + "choices": [{ + "delta": { + "content": self.in_json.clone(), + "role": "context_file" + }, + "finish_reason": serde_json::Value::Null, + "index": 0 + }], + })); + } } diff --git a/src/scratchpads/completion_single_file_fim.rs b/src/scratchpads/completion_single_file_fim.rs index da81fbbd4..7ba20c21c 100644 --- a/src/scratchpads/completion_single_file_fim.rs +++ b/src/scratchpads/completion_single_file_fim.rs @@ -285,6 +285,9 @@ impl ScratchpadAbstract for SingleFileFIM { }); Ok((ans, finished)) } + fn response_spontaneous(&mut self) -> Result { + return Err("".to_string()); + } } fn get_context_near_cursor(text: &Rope, line_pos: usize, max_lines_count: usize) -> String { diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index b1a2e4e24..7df70cab3 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -24,6 +24,7 @@ pub struct VecDb { model_name: String, endpoint_template: String, + endpoint_embeddings_style: String, } @@ -176,6 +177,7 @@ impl VecDb { model_name, endpoint_template, + endpoint_embeddings_style, }) } @@ -206,7 +208,7 @@ impl VecDb { impl VecdbSearch for VecDb { async fn search(&self, query: String, top_n: usize) -> Result { let embedding_mb = get_embedding( - &self.cmdline.address_url, + &self.endpoint_embeddings_style, &self.model_name, &self.endpoint_template, query.clone(),