Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
valaises committed Jan 4, 2024
1 parent f953d1a commit d506475
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 23 deletions.
18 changes: 18 additions & 0 deletions src/restream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,24 @@ pub async fn scratchpad_interaction_stream(
let mut slowdown_scoped = SlowdownScoped::new(slowdown_arc.clone());
slowdown_scoped.be_nice_slow_down().await;
loop {
loop {
let value_maybe = scratch.response_spontaneous();
if let Ok(value) = value_maybe {
if value == json!(null) {
break;
}
let value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap());
info!("yield: {:?}", value_str);
yield Result::<_, String>::Ok(value_str);
} else {
let err_str = value_maybe.unwrap_err();
error!("response_spontaneous error: {}", err_str);
let value_str = format!("data: {}\n\n", serde_json::to_string(&json!({"detail": err_str})).unwrap());
yield Result::<_, String>::Ok(value_str);
}
break;
}

let event_source_maybe = if endpoint_style == "hf" {
forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming(
&mut save_url,
Expand Down
2 changes: 2 additions & 0 deletions src/scratchpad_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub trait ScratchpadAbstract: Send {
stop_toks: bool,
stop_length: bool,
) -> Result<(serde_json::Value, bool), String>;

fn response_spontaneous(&mut self) -> Result<serde_json::Value, String>;
}


Expand Down
12 changes: 9 additions & 3 deletions src/scratchpads/chat_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::scratchpad_abstract::HasTokenizerAndEot;
use crate::scratchpad_abstract::ScratchpadAbstract;
use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer;
use crate::scratchpads::chat_utils_limit_history::limit_messages_history;
use crate::scratchpads::chat_utils_rag::embed_vecdb_results;
use crate::scratchpads::chat_utils_rag::{embed_vecdb_results, HasVecdb, HasVecdbResults};
use crate::vecdb::structs::VecdbSearch;

const DEBUG: bool = true;
Expand All @@ -27,6 +27,7 @@ pub struct GenericChatScratchpad<T> {
pub keyword_asst: String,
pub default_system_message: String,
pub vecdb_search: Arc<AMutex<Option<T>>>,
pub has_vecdb_results: HasVecdbResults,
}

impl<T: Send + Sync + VecdbSearch> GenericChatScratchpad<T> {
Expand All @@ -44,7 +45,8 @@ impl<T: Send + Sync + VecdbSearch> GenericChatScratchpad<T> {
keyword_user: "".to_string(),
keyword_asst: "".to_string(),
default_system_message: "".to_string(),
vecdb_search
vecdb_search,
has_vecdb_results: HasVecdbResults::new(),
}
}
}
Expand Down Expand Up @@ -85,7 +87,7 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for GenericChatScratchpad<
sampling_parameters_to_patch: &mut SamplingParameters,
) -> Result<String, String> {
match *self.vecdb_search.lock().await {
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6).await,
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6, &mut self.has_vecdb_results).await,
None => {}
}

Expand Down Expand Up @@ -149,5 +151,9 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for GenericChatScratchpad<
) -> Result<(serde_json::Value, bool), String> {
self.dd.response_streaming(delta, stop_toks)
}

/// Returns whatever `has_vecdb_results.response_streaming()` currently reports,
/// so the caller can forward it without waiting for model output.
fn response_spontaneous(&mut self) -> Result<serde_json::Value, String> {
    self.has_vecdb_results.response_streaming()
}
}

12 changes: 9 additions & 3 deletions src/scratchpads/chat_llama2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::scratchpad_abstract::HasTokenizerAndEot;
use crate::scratchpad_abstract::ScratchpadAbstract;
use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer;
use crate::scratchpads::chat_utils_limit_history::limit_messages_history;
use crate::scratchpads::chat_utils_rag::embed_vecdb_results;
use crate::scratchpads::chat_utils_rag::{embed_vecdb_results, HasVecdb, HasVecdbResults};
use crate::vecdb::structs::VecdbSearch;

const DEBUG: bool = true;
Expand All @@ -26,6 +26,7 @@ pub struct ChatLlama2<T> {
pub keyword_slash_s: String,
pub default_system_message: String,
pub vecdb_search: Arc<AMutex<Option<T>>>,
pub has_vecdb_results: HasVecdbResults,
}


Expand All @@ -42,7 +43,8 @@ impl<T: Send + Sync + VecdbSearch> ChatLlama2<T> {
keyword_s: "<s>".to_string(),
keyword_slash_s: "</s>".to_string(),
default_system_message: "".to_string(),
vecdb_search
vecdb_search,
has_vecdb_results: HasVecdbResults::new(),
}
}
}
Expand Down Expand Up @@ -71,7 +73,7 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for ChatLlama2<T> {
sampling_parameters_to_patch: &mut SamplingParameters,
) -> Result<String, String> {
match *self.vecdb_search.lock().await {
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6).await,
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6, &mut self.has_vecdb_results).await,
None => {}
}
let limited_msgs: Vec<ChatMessage> = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?;
Expand Down Expand Up @@ -137,5 +139,9 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for ChatLlama2<T> {
) -> Result<(serde_json::Value, bool), String> {
self.dd.response_streaming(delta, stop_toks)
}

/// Returns whatever `has_vecdb_results.response_streaming()` currently reports,
/// so the caller can forward it without waiting for model output.
fn response_spontaneous(&mut self) -> Result<serde_json::Value, String> {
    self.has_vecdb_results.response_streaming()
}
}

9 changes: 7 additions & 2 deletions src/scratchpads/chat_passthrough.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::info;
use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters};
use crate::scratchpad_abstract::ScratchpadAbstract;
use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes;
use crate::scratchpads::chat_utils_rag::embed_vecdb_results;
use crate::scratchpads::chat_utils_rag::{embed_vecdb_results, HasVecdb, HasVecdbResults};
use crate::vecdb::structs::VecdbSearch;

const DEBUG: bool = true;
Expand All @@ -19,6 +19,7 @@ pub struct ChatPassthrough<T> {
pub default_system_message: String,
pub limit_bytes: usize,
pub vecdb_search: Arc<AMutex<Option<T>>>,
pub has_vecdb_results: HasVecdbResults,
}

const DEFAULT_LIMIT_BYTES: usize = 4096*6;
Expand All @@ -33,6 +34,7 @@ impl<T: Send + Sync + VecdbSearch> ChatPassthrough<T> {
default_system_message: "".to_string(),
limit_bytes: DEFAULT_LIMIT_BYTES, // one token translates to 3 bytes (not unicode chars)
vecdb_search,
has_vecdb_results: HasVecdbResults::new(),
}
}
}
Expand All @@ -54,7 +56,7 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for ChatPassthrough<T> {
_sampling_parameters_to_patch: &mut SamplingParameters,
) -> Result<String, String> {
match *self.vecdb_search.lock().await {
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6).await,
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6, &mut self.has_vecdb_results).await,
None => {}
}
let limited_msgs: Vec<ChatMessage> = limit_messages_history_in_bytes(&self.post, self.limit_bytes, &self.default_system_message)?;
Expand Down Expand Up @@ -123,4 +125,7 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for ChatPassthrough<T> {
});
Ok((ans, finished))
}
/// Returns whatever `has_vecdb_results.response_streaming()` currently reports,
/// so the caller can forward it without waiting for model output.
fn response_spontaneous(&mut self) -> Result<serde_json::Value, String> {
    self.has_vecdb_results.response_streaming()
}
}
118 changes: 104 additions & 14 deletions src/scratchpads/chat_utils_rag.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
use async_trait::async_trait;
use serde_json::json;
use tracing::info;

use crate::call_validation::{ChatMessage, ChatPost, ContextFile};
use crate::vecdb::structs::{SearchResult, VecdbSearch};


pub async fn embed_vecdb_results<T>(
vecdb: &T,
post: &mut ChatPost,
limit_examples_cnt: usize,
has_vecdb: &mut dyn HasVecdb,
) where T: VecdbSearch {
let latest_msg_cont = &post.messages.last().unwrap().content;
let vdb_resp = vecdb.search(latest_msg_cont.clone(), limit_examples_cnt).await;
let vdb_cont = vecdb_resp_to_prompt(&vdb_resp);
if vdb_cont.is_ok() {
post.messages = [
&post.messages[..post.messages.len() - 1],
&[ChatMessage {
role: "context_file".to_string(),
content: vdb_cont.unwrap(),
}],
&post.messages[post.messages.len() - 1..],
].concat();
}

has_vecdb.add2messages(
vdb_resp,
&mut post.messages,
).await;
}

fn vecdb_resp_to_prompt(
fn vecdb_resp_to_json(
resp: &Result<SearchResult, String>
) -> serde_json::Result<String> {
) -> serde_json::Result<serde_json::Value> {
let context_files: Vec<ContextFile> = match resp {
Ok(search_res) => {
search_res.results.iter().map(
Expand All @@ -35,5 +35,95 @@ fn vecdb_resp_to_prompt(
}
Err(_) => vec![]
};
serde_json::to_string(&context_files)
serde_json::to_value(&context_files)
}

/// Renders a vecdb search outcome as a plain-text prompt fragment.
///
/// On `Err` this logs `VECDB ERR` and returns an empty string. On `Ok` the text
/// starts with `CONTEXT:`, contains one `FILENAME:` / `TEXT:` entry per search
/// result, and ends with an instruction to refer to the context.
fn vecdb_resp_to_prompt(
    resp_mb: &Result<SearchResult, String>
) -> String {
    let resp = match resp_mb {
        Ok(resp) => resp,
        Err(_) => {
            info!("VECDB ERR");
            return String::new();
        }
    };

    let mut cont = String::from("CONTEXT:\n");
    for res in resp.results.iter() {
        cont.push_str("FILENAME:\n");
        // A path that is not valid UTF-8 is rendered as an empty file name.
        cont.push_str(res.file_path.to_str().unwrap_or(""));
        cont.push_str("\nTEXT:");
        // Borrow the stored text directly; no per-result copy is needed to append it.
        cont.push_str(res.window_text.as_str());
        cont.push('\n');
    }
    cont.push_str("\nRefer to the context to answer my next question.\n");
    info!("VECDB prompt:\n{}", cont);
    cont
}


/// Holds vecdb search results in JSON form together with a flag telling whether
/// they have already been handed out by `response_streaming`.
pub struct HasVecdbResults {
    /// Becomes `true` once `response_streaming` has returned the results payload.
    pub was_sent: bool,
    /// JSON form of the search results; starts as `null` and is filled by `add2messages`.
    pub in_json: serde_json::Value,
}

impl HasVecdbResults {
pub fn new() -> Self {
HasVecdbResults {
was_sent: false,
in_json: json!(null)
}
}
}

/// Interface for a component that receives vecdb search results, adds them to a
/// chat message list, and can later report them as a JSON value.
#[async_trait]
pub trait HasVecdb: Send {
    /// Consumes the outcome of a vecdb search (`vdb_result_mb`) and may modify
    /// `messages` to include it.
    async fn add2messages(
        &mut self,
        vdb_result_mb: Result<SearchResult, String>,
        messages: &mut Vec<ChatMessage>,
    );
    /// Returns a JSON value for the caller to emit, or an error string.
    /// The implementation in this file uses JSON `null` to mean "nothing to emit".
    fn response_streaming(&mut self) -> Result<serde_json::Value, String>;
}

#[async_trait]
impl HasVecdb for HasVecdbResults {
    /// Inserts the search results, rendered as a prompt, as a `user` message placed
    /// directly before the last message, and stores their JSON form in `in_json`.
    ///
    /// If `messages` is empty the context message simply becomes the only message
    /// (the earlier `len() - 1` arithmetic would panic in that case).
    ///
    /// NOTE(review): a search error still inserts a `user` message with empty
    /// content, because the prompt for an `Err` is an empty string — confirm intended.
    /// NOTE(review): a guard that skipped injection when the chat already had more
    /// than one message was previously disabled here; context is added on every call.
    async fn add2messages(
        &mut self,
        result_mb: Result<SearchResult, String>,
        messages: &mut Vec<ChatMessage>,
    ) {
        let context_msg = ChatMessage {
            role: "user".to_string(),
            content: vecdb_resp_to_prompt(&result_mb),
        };
        // Insert in place instead of rebuilding (and cloning) the whole list;
        // `saturating_sub` maps an empty list to index 0 rather than underflowing.
        let insert_at = messages.len().saturating_sub(1);
        messages.insert(insert_at, context_msg);

        self.in_json = vecdb_resp_to_json(&result_mb).unwrap_or_else(|_| json!(null));
    }

    /// Returns the stored results as a single chat-style delta with role
    /// `context_file`, exactly once.
    ///
    /// Yields `Ok(null)` when the payload was already returned or when no results
    /// have been stored yet.
    fn response_streaming(&mut self) -> Result<serde_json::Value, String> {
        if self.was_sent || self.in_json.is_null() {
            return Ok(json!(null));
        }
        self.was_sent = true;
        Ok(json!({
            "choices": [{
                "delta": {
                    "content": self.in_json.clone(),
                    "role": "context_file"
                },
                "finish_reason": serde_json::Value::Null,
                "index": 0
            }],
        }))
    }
}
3 changes: 3 additions & 0 deletions src/scratchpads/completion_single_file_fim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for SingleFileFIM<T> {
});
Ok((ans, finished))
}
/// Code completion has nothing to send ahead of the model output.
///
/// Returns JSON `null`, which the streaming loop treats as "nothing to emit".
/// Returning `Err` here would make that loop log an error and forward a
/// `{"detail": ...}` event to the client on every completion request.
fn response_spontaneous(&mut self) -> Result<serde_json::Value, String> {
    Ok(serde_json::Value::Null)
}
}

fn get_context_near_cursor(text: &Rope, line_pos: usize, max_lines_count: usize) -> String {
Expand Down
4 changes: 3 additions & 1 deletion src/vecdb/vecdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct VecDb {

model_name: String,
endpoint_template: String,
endpoint_embeddings_style: String,
}


Expand Down Expand Up @@ -176,6 +177,7 @@ impl VecDb {

model_name,
endpoint_template,
endpoint_embeddings_style,
})
}

Expand Down Expand Up @@ -206,7 +208,7 @@ impl VecDb {
impl VecdbSearch for VecDb {
async fn search(&self, query: String, top_n: usize) -> Result<SearchResult, String> {
let embedding_mb = get_embedding(
&self.cmdline.address_url,
&self.endpoint_embeddings_style,
&self.model_name,
&self.endpoint_template,
query.clone(),
Expand Down

0 comments on commit d506475

Please sign in to comment.