Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Context File Message Support for Chats #59

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/restream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,24 @@ pub async fn scratchpad_interaction_stream(
let mut slowdown_scoped = SlowdownScoped::new(slowdown_arc.clone());
slowdown_scoped.be_nice_slow_down().await;
loop {
loop {
let value_maybe = scratch.response_spontaneous();
if let Ok(value) = value_maybe {
if value == json!(null) {
break;
}
let value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap());
info!("yield: {:?}", value_str);
yield Result::<_, String>::Ok(value_str);
} else {
let err_str = value_maybe.unwrap_err();
error!("response_spontaneous error: {}", err_str);
let value_str = format!("data: {}\n\n", serde_json::to_string(&json!({"detail": err_str})).unwrap());
yield Result::<_, String>::Ok(value_str);
}
break;
}

let event_source_maybe = if endpoint_style == "hf" {
forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming(
&mut save_url,
Expand Down
2 changes: 2 additions & 0 deletions src/scratchpad_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub trait ScratchpadAbstract: Send {
stop_toks: bool,
stop_length: bool,
) -> Result<(serde_json::Value, bool), String>;

fn response_spontaneous(&mut self) -> Result<serde_json::Value, String>;
}


Expand Down
12 changes: 9 additions & 3 deletions src/scratchpads/chat_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::scratchpad_abstract::HasTokenizerAndEot;
use crate::scratchpad_abstract::ScratchpadAbstract;
use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer;
use crate::scratchpads::chat_utils_limit_history::limit_messages_history;
use crate::scratchpads::chat_utils_rag::embed_vecdb_results;
use crate::scratchpads::chat_utils_rag::{embed_vecdb_results, HasVecdb, HasVecdbResults};
use crate::vecdb::structs::VecdbSearch;

const DEBUG: bool = true;
Expand All @@ -27,6 +27,7 @@ pub struct GenericChatScratchpad<T> {
pub keyword_asst: String,
pub default_system_message: String,
pub vecdb_search: Arc<AMutex<Option<T>>>,
pub has_vecdb_results: HasVecdbResults,
}

impl<T: Send + Sync + VecdbSearch> GenericChatScratchpad<T> {
Expand All @@ -44,7 +45,8 @@ impl<T: Send + Sync + VecdbSearch> GenericChatScratchpad<T> {
keyword_user: "".to_string(),
keyword_asst: "".to_string(),
default_system_message: "".to_string(),
vecdb_search
vecdb_search,
has_vecdb_results: HasVecdbResults::new(),
}
}
}
Expand Down Expand Up @@ -85,7 +87,7 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for GenericChatScratchpad<
sampling_parameters_to_patch: &mut SamplingParameters,
) -> Result<String, String> {
match *self.vecdb_search.lock().await {
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6).await,
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6, &mut self.has_vecdb_results).await,
None => {}
}

Expand Down Expand Up @@ -149,5 +151,9 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for GenericChatScratchpad<
) -> Result<(serde_json::Value, bool), String> {
self.dd.response_streaming(delta, stop_toks)
}

/// Reports whatever the vecdb results holder has queued for the client by
/// delegating to `has_vecdb_results.response_streaming()`.
fn response_spontaneous(&mut self) -> Result<serde_json::Value, String> {
    self.has_vecdb_results.response_streaming()
}
}

12 changes: 9 additions & 3 deletions src/scratchpads/chat_llama2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::scratchpad_abstract::HasTokenizerAndEot;
use crate::scratchpad_abstract::ScratchpadAbstract;
use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer;
use crate::scratchpads::chat_utils_limit_history::limit_messages_history;
use crate::scratchpads::chat_utils_rag::embed_vecdb_results;
use crate::scratchpads::chat_utils_rag::{embed_vecdb_results, HasVecdb, HasVecdbResults};
use crate::vecdb::structs::VecdbSearch;

const DEBUG: bool = true;
Expand All @@ -26,6 +26,7 @@ pub struct ChatLlama2<T> {
pub keyword_slash_s: String,
pub default_system_message: String,
pub vecdb_search: Arc<AMutex<Option<T>>>,
pub has_vecdb_results: HasVecdbResults,
}


Expand All @@ -42,7 +43,8 @@ impl<T: Send + Sync + VecdbSearch> ChatLlama2<T> {
keyword_s: "<s>".to_string(),
keyword_slash_s: "</s>".to_string(),
default_system_message: "".to_string(),
vecdb_search
vecdb_search,
has_vecdb_results: HasVecdbResults::new(),
}
}
}
Expand Down Expand Up @@ -71,7 +73,7 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for ChatLlama2<T> {
sampling_parameters_to_patch: &mut SamplingParameters,
) -> Result<String, String> {
match *self.vecdb_search.lock().await {
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6).await,
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6, &mut self.has_vecdb_results).await,
None => {}
}
let limited_msgs: Vec<ChatMessage> = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?;
Expand Down Expand Up @@ -137,5 +139,9 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for ChatLlama2<T> {
) -> Result<(serde_json::Value, bool), String> {
self.dd.response_streaming(delta, stop_toks)
}

/// Reports whatever the vecdb results holder has queued for the client by
/// delegating to `has_vecdb_results.response_streaming()`.
fn response_spontaneous(&mut self) -> Result<serde_json::Value, String> {
    self.has_vecdb_results.response_streaming()
}
}

9 changes: 7 additions & 2 deletions src/scratchpads/chat_passthrough.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::info;
use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters};
use crate::scratchpad_abstract::ScratchpadAbstract;
use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes;
use crate::scratchpads::chat_utils_rag::embed_vecdb_results;
use crate::scratchpads::chat_utils_rag::{embed_vecdb_results, HasVecdb, HasVecdbResults};
use crate::vecdb::structs::VecdbSearch;

const DEBUG: bool = true;
Expand All @@ -19,6 +19,7 @@ pub struct ChatPassthrough<T> {
pub default_system_message: String,
pub limit_bytes: usize,
pub vecdb_search: Arc<AMutex<Option<T>>>,
pub has_vecdb_results: HasVecdbResults,
}

const DEFAULT_LIMIT_BYTES: usize = 4096*6;
Expand All @@ -33,6 +34,7 @@ impl<T: Send + Sync + VecdbSearch> ChatPassthrough<T> {
default_system_message: "".to_string(),
limit_bytes: DEFAULT_LIMIT_BYTES, // one token translates to 3 bytes (not unicode chars)
vecdb_search,
has_vecdb_results: HasVecdbResults::new(),
}
}
}
Expand All @@ -54,7 +56,7 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for ChatPassthrough<T> {
_sampling_parameters_to_patch: &mut SamplingParameters,
) -> Result<String, String> {
match *self.vecdb_search.lock().await {
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6).await,
Some(ref db) => embed_vecdb_results(db, &mut self.post, 6, &mut self.has_vecdb_results).await,
None => {}
}
let limited_msgs: Vec<ChatMessage> = limit_messages_history_in_bytes(&self.post, self.limit_bytes, &self.default_system_message)?;
Expand Down Expand Up @@ -123,4 +125,7 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for ChatPassthrough<T> {
});
Ok((ans, finished))
}
/// Reports whatever the vecdb results holder has queued for the client by
/// delegating to `has_vecdb_results.response_streaming()`.
fn response_spontaneous(&mut self) -> Result<serde_json::Value, String> {
    self.has_vecdb_results.response_streaming()
}
}
118 changes: 104 additions & 14 deletions src/scratchpads/chat_utils_rag.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
use async_trait::async_trait;
use serde_json::json;
use tracing::info;

use crate::call_validation::{ChatMessage, ChatPost, ContextFile};
use crate::vecdb::structs::{SearchResult, VecdbSearch};


pub async fn embed_vecdb_results<T>(
vecdb: &T,
post: &mut ChatPost,
limit_examples_cnt: usize,
has_vecdb: &mut dyn HasVecdb,
) where T: VecdbSearch {
let latest_msg_cont = &post.messages.last().unwrap().content;
let vdb_resp = vecdb.search(latest_msg_cont.clone(), limit_examples_cnt).await;
let vdb_cont = vecdb_resp_to_prompt(&vdb_resp);
if vdb_cont.is_ok() {
post.messages = [
&post.messages[..post.messages.len() - 1],
&[ChatMessage {
role: "context_file".to_string(),
content: vdb_cont.unwrap(),
}],
&post.messages[post.messages.len() - 1..],
].concat();
}

has_vecdb.add2messages(
vdb_resp,
&mut post.messages,
).await;
}

fn vecdb_resp_to_prompt(
fn vecdb_resp_to_json(
resp: &Result<SearchResult, String>
) -> serde_json::Result<String> {
) -> serde_json::Result<serde_json::Value> {
let context_files: Vec<ContextFile> = match resp {
Ok(search_res) => {
search_res.results.iter().map(
Expand All @@ -35,5 +35,95 @@ fn vecdb_resp_to_prompt(
}
Err(_) => vec![]
};
serde_json::to_string(&context_files)
serde_json::to_value(&context_files)
}

/// Renders a vecdb search outcome as a plain-text prompt section.
///
/// On `Ok`, the text starts with `CONTEXT:`, lists every result as a
/// `FILENAME:` / `TEXT:` pair and ends with an instruction to refer to the
/// context. The finished prompt is also written to the log.
///
/// On `Err`, logs "VECDB ERR" and returns an empty string.
fn vecdb_resp_to_prompt(
    resp_mb: &Result<SearchResult, String>
) -> String {
    let resp = match resp_mb {
        Ok(resp) => resp,
        Err(_) => {
            info!("VECDB ERR");
            return String::new();
        }
    };

    let mut cont = String::from("CONTEXT:\n");
    for res in resp.results.iter() {
        cont.push_str("FILENAME:\n");
        // A path that `to_str()` cannot convert is rendered as an empty name.
        cont.push_str(res.file_path.to_str().unwrap_or(""));
        cont.push_str("\nTEXT:");
        cont.push_str(res.window_text.as_str());
        cont.push_str("\n");
    }
    cont.push_str("\nRefer to the context to answer my next question.\n");
    info!("VECDB prompt:\n{}", cont);
    cont
}


pub struct HasVecdbResults {
pub was_sent: bool,
pub in_json: serde_json::Value,
}

impl HasVecdbResults {
pub fn new() -> Self {
HasVecdbResults {
was_sent: false,
in_json: json!(null)
}
}
}

/// Interface for an object that receives vecdb search results for a chat and
/// can later produce a JSON value to stream back.
#[async_trait]
pub trait HasVecdb: Send {
    /// Receives the outcome of a vecdb search (`vdb_result_mb`) together with
    /// mutable access to the chat `messages`, which the implementor may edit
    /// in place.
    async fn add2messages(
        &mut self,
        vdb_result_mb: Result<SearchResult, String>,
        messages: &mut Vec<ChatMessage>,
    );
    /// Returns a JSON value to send to the client, or an error message.
    /// NOTE(review): the `HasVecdbResults` implementation returns JSON `null`
    /// when there is nothing to send — confirm other implementors follow suit.
    fn response_streaming(&mut self) -> Result<serde_json::Value, String>;
}

#[async_trait]
impl HasVecdb for HasVecdbResults {
    /// Puts the vecdb context into the chat and remembers the results as JSON.
    ///
    /// When the search succeeded, a `"user"` message holding the rendered
    /// context prompt is inserted directly before the last message (or as the
    /// only message when `messages` is empty). When the search failed the
    /// rendered prompt is empty and no message is inserted.
    ///
    /// In both cases `self.in_json` is overwritten with the JSON form of the
    /// results, falling back to JSON `null` if the conversion fails.
    async fn add2messages(
        &mut self,
        result_mb: Result<SearchResult, String>,
        messages: &mut Vec<ChatMessage>,
    ) {
        let prompt = vecdb_resp_to_prompt(&result_mb);
        // An empty prompt means the search failed; an empty "user" message
        // would only add noise to the conversation.
        if !prompt.is_empty() {
            // `saturating_sub` avoids the underflow panic on an empty list, and
            // `insert` avoids cloning every existing message.
            let insert_at = messages.len().saturating_sub(1);
            messages.insert(insert_at, ChatMessage {
                role: "user".to_string(),
                content: prompt,
            });
        }

        self.in_json = vecdb_resp_to_json(&result_mb).unwrap_or_else(|_| json!(null));
    }

    /// Returns the stored results once, wrapped as a streaming chat chunk with
    /// role `"context_file"`.
    ///
    /// Returns JSON `null` when the results were already returned or when no
    /// results are stored. Never returns `Err`.
    fn response_streaming(&mut self) -> Result<serde_json::Value, String> {
        if self.was_sent || self.in_json.is_null() {
            return Ok(json!(null));
        }
        self.was_sent = true;
        Ok(json!({
            "choices": [{
                "delta": {
                    "content": self.in_json.clone(),
                    "role": "context_file"
                },
                "finish_reason": serde_json::Value::Null,
                "index": 0
            }],
        }))
    }
}
3 changes: 3 additions & 0 deletions src/scratchpads/completion_single_file_fim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ impl<T: Send + Sync + VecdbSearch> ScratchpadAbstract for SingleFileFIM<T> {
});
Ok((ans, finished))
}
/// Code completion has nothing to announce before the model output.
///
/// Returns JSON `null`, the "nothing to send" value. Returning `Err` here
/// would make a streaming caller that reports `Err` to the client emit a
/// `{"detail": ""}` error event even though nothing went wrong.
fn response_spontaneous(&mut self) -> Result<serde_json::Value, String> {
    Ok(serde_json::Value::Null)
}
}

fn get_context_near_cursor(text: &Rope, line_pos: usize, max_lines_count: usize) -> String {
Expand Down
4 changes: 3 additions & 1 deletion src/vecdb/vecdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct VecDb {

model_name: String,
endpoint_template: String,
endpoint_embeddings_style: String,
}


Expand Down Expand Up @@ -176,6 +177,7 @@ impl VecDb {

model_name,
endpoint_template,
endpoint_embeddings_style,
})
}

Expand Down Expand Up @@ -206,7 +208,7 @@ impl VecDb {
impl VecdbSearch for VecDb {
async fn search(&self, query: String, top_n: usize) -> Result<SearchResult, String> {
let embedding_mb = get_embedding(
&self.cmdline.address_url,
&self.endpoint_embeddings_style,
&self.model_name,
&self.endpoint_template,
query.clone(),
Expand Down