This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

VecDB #10

Draft · wants to merge 6 commits into base: multifile1
6 changes: 4 additions & 2 deletions src/global_context.rs
@@ -3,6 +3,7 @@ use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::sync::RwLock as StdRwLock;
use tokio::sync::Mutex as AMutex;
use tokio::sync::RwLock as ARwLock;
use tokenizers::Tokenizer;
use structopt::StructOpt;
@@ -34,7 +35,7 @@ pub struct CommandLine {
}


#[derive(Debug)]
// #[derive(Debug)]
pub struct GlobalContext {
pub http_client: reqwest::Client,
pub ask_shutdown_sender: Arc<Mutex<std::sync::mpsc::Sender<String>>>,
@@ -45,7 +46,7 @@ pub struct GlobalContext {
pub cmdline: CommandLine,
pub completions_cache: Arc<StdRwLock<CompletionCache>>,
pub telemetry: Arc<StdRwLock<telemetry_storage::Storage>>,
pub vecdb_search: Arc<Mutex<dyn VecdbSearch>>,
pub vecdb_search: Arc<AMutex<Box<dyn VecdbSearch + Send>>>,
}


@@ -124,6 +125,7 @@ pub async fn create_global_context(
cmdline: cmdline.clone(),
completions_cache: Arc::new(StdRwLock::new(CompletionCache::new())),
telemetry: Arc::new(StdRwLock::new(telemetry_storage::Storage::new())),
vecdb_search: Arc::new(AMutex::new(Box::new(crate::vecdb_search::VecdbSearchTest::new()))),
};
(Arc::new(ARwLock::new(cx)), ask_shutdown_receiver, cmdline)
}
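
The new field stores the search backend as a boxed trait object behind Tokio's async mutex. A minimal standalone sketch of the pattern, for orientation: the trait and test-struct names come from this diff, while the search() signature, return type, and helper names are assumptions.

use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Mutex as AMutex;

#[async_trait]
pub trait VecdbSearch: Send {
    // Signature assumed for illustration; the real trait lives in vecdb_search.rs.
    async fn search(&mut self, query: &str) -> Vec<String>;
}

pub struct VecdbSearchTest;

impl VecdbSearchTest {
    pub fn new() -> Self { VecdbSearchTest }
}

#[async_trait]
impl VecdbSearch for VecdbSearchTest {
    async fn search(&mut self, _query: &str) -> Vec<String> {
        vec![]  // test stub: no results
    }
}

// The shape stored in GlobalContext: clone the Arc cheaply per request,
// lock the async mutex only around an actual search call.
pub type SharedVecdb = Arc<AMutex<Box<dyn VecdbSearch + Send>>>;

pub fn make_shared_vecdb() -> SharedVecdb {
    Arc::new(AMutex::new(Box::new(VecdbSearchTest::new())))
}
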
7 changes: 4 additions & 3 deletions src/http_server.rs
@@ -21,6 +21,7 @@ use crate::custom_error::ScratchError;
use crate::telemetry_basic;
use crate::telemetry_snippets;
use crate::completion_cache;
// use crate::vecdb_search::VecdbSearch;


async fn _get_caps_and_tokenizer(
@@ -155,7 +156,7 @@ async fn handle_v1_code_completion(
let prompt = scratchpad.prompt(
2048,
&mut code_completion_post.parameters,
).map_err(|e|
).await.map_err(|e|
ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Prompt: {}", e))
)?;
// info!("prompt {:?}\n{}", t1.elapsed(), prompt);
@@ -193,7 +194,7 @@ async fn handle_v1_chat(
ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR,format!("Tokenizer: {}", e))
)?;

let vecdb_search = ;
let vecdb_search = global_context.read().await.vecdb_search.clone();
let mut scratchpad = scratchpads::create_chat_scratchpad(
chat_post.clone(),
&scratchpad_name,
@@ -207,7 +208,7 @@ let prompt = scratchpad.prompt(
let prompt = scratchpad.prompt(
2048,
&mut chat_post.parameters,
).map_err(|e|
).await.map_err(|e|
ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Prompt: {}", e))
)?;
// info!("chat prompt {:?}\n{}", t1.elapsed(), prompt);
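
On the handler side the interesting step is getting the shared handle out of the global context: clone the Arc under a brief read lock, then hand the clone to the scratchpad factory. A reduced sketch, with GlobalContext trimmed to the one field that matters here and the trait stubbed out (search() signature assumed):

use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::{Mutex as AMutex, RwLock as ARwLock};

#[async_trait]
trait VecdbSearch: Send {
    async fn search(&mut self, query: &str) -> Vec<String>;
}

struct GlobalContext {
    vecdb_search: Arc<AMutex<Box<dyn VecdbSearch + Send>>>,
}

async fn clone_vecdb_handle(
    gcx: &Arc<ARwLock<GlobalContext>>,
) -> Arc<AMutex<Box<dyn VecdbSearch + Send>>> {
    // The read guard lives only for this expression; the Arc clone is cheap
    // and can then be passed to create_chat_scratchpad() without holding any lock.
    gcx.read().await.vecdb_search.clone()
}
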
2 changes: 1 addition & 1 deletion src/lsp.rs
@@ -61,7 +61,7 @@ impl Document {
}
}

#[derive(Debug)]
// #[derive(Debug)] GlobalContext does not implement Debug
pub struct Backend {
pub gcx: Arc<ARwLock<global_context::GlobalContext>>,
pub client: tower_lsp::Client,
11 changes: 0 additions & 11 deletions src/main.rs
@@ -20,15 +20,6 @@ mod telemetry_snippets;
mod telemetry_storage;
mod vecdb_search;
mod lsp;
use crate::vecdb_search::VecdbSearch;


async fn test_vecdb()
{
let mut v = vecdb_search::VecdbSearchTest::new();
let res = v.search("ParallelTasksV3").await;
info!("{:?}", res);
}


#[tokio::main]
@@ -54,8 +45,6 @@ async fn main() {
.init();
info!("started");
info!("cache dir: {}", cache_dir.display());
test_vecdb().await;
return;

let gcx2 = gcx.clone();
let gcx3 = gcx.clone();
4 changes: 3 additions & 1 deletion src/scratchpad_abstract.rs
@@ -3,15 +3,17 @@ use std::sync::Arc;
use std::sync::RwLock;
use tokenizers::Tokenizer;
use crate::call_validation::SamplingParameters;
use async_trait::async_trait;


#[async_trait]
pub trait ScratchpadAbstract: Send {
fn apply_model_adaptation_patch(
&mut self,
patch: &serde_json::Value,
) -> Result<(), String>;

fn prompt(
async fn prompt(
&mut self,
context_size: usize,
sampling_parameters_to_patch: &mut SamplingParameters,
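
Seen from the trait side, #[async_trait] is what lets prompt() become async while the trait stays usable through Box<dyn ScratchpadAbstract>. A compilable sketch under those assumptions; SamplingParameters is left out here to keep it self-contained:

use async_trait::async_trait;

#[async_trait]
pub trait ScratchpadAbstract: Send {
    fn apply_model_adaptation_patch(&mut self, patch: &serde_json::Value) -> Result<(), String>;

    // async fn in a trait needs #[async_trait] (or a hand-rolled boxed future)
    // to remain callable through a Box<dyn ScratchpadAbstract>.
    async fn prompt(&mut self, context_size: usize) -> Result<String, String>;
}

struct EchoScratchpad;

#[async_trait]
impl ScratchpadAbstract for EchoScratchpad {
    fn apply_model_adaptation_patch(&mut self, _patch: &serde_json::Value) -> Result<(), String> {
        Ok(())
    }
    async fn prompt(&mut self, context_size: usize) -> Result<String, String> {
        Ok(format!("prompt built for context_size={}", context_size))
    }
}

async fn build(mut sp: Box<dyn ScratchpadAbstract>) -> Result<String, String> {
    sp.prompt(2048).await  // every caller gains an .await
}
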
12 changes: 10 additions & 2 deletions src/scratchpads/chat_generic.rs
@@ -5,16 +5,19 @@ use crate::call_validation::ChatPost;
use crate::call_validation::ChatMessage;
use crate::call_validation::SamplingParameters;
use crate::scratchpads::chat_utils_limit_history::limit_messages_history;
use crate::vecdb_search::{VecdbSearch, embed_vecdb_results};

use std::sync::Arc;
use std::sync::RwLock;
use async_trait::async_trait;
use tokio::sync::Mutex as AMutex;

use tokenizers::Tokenizer;
use tracing::info;

const DEBUG: bool = true;


#[derive(Debug)]
pub struct GenericChatScratchpad {
pub t: HasTokenizerAndEot,
pub dd: DeltaDeltaChatStreamer,
@@ -24,12 +27,14 @@ pub struct GenericChatScratchpad {
pub keyword_user: String,
pub keyword_asst: String,
pub default_system_message: String,
pub vecdb_search: Arc<AMutex<Box<dyn VecdbSearch + Send>>>,
}

impl GenericChatScratchpad {
pub fn new(
tokenizer: Arc<RwLock<Tokenizer>>,
post: ChatPost,
vecdb_search: Arc<AMutex<Box<dyn VecdbSearch + Send>>>,
) -> Self {
GenericChatScratchpad {
t: HasTokenizerAndEot::new(tokenizer),
@@ -40,10 +45,12 @@ keyword_user: "".to_string(),
keyword_user: "".to_string(),
keyword_asst: "".to_string(),
default_system_message: "".to_string(),
vecdb_search
}
}
}

#[async_trait]
impl ScratchpadAbstract for GenericChatScratchpad {
fn apply_model_adaptation_patch(
&mut self,
@@ -68,11 +75,12 @@ impl ScratchpadAbstract for GenericChatScratchpad {
Ok(())
}

fn prompt(
async fn prompt(
&mut self,
context_size: usize,
sampling_parameters_to_patch: &mut SamplingParameters,
) -> Result<String, String> {
embed_vecdb_results(self.vecdb_search.clone(), &mut self.post, 3).await;
let limited_msgs: Vec<ChatMessage> = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?;
sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone());
// adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24
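
embed_vecdb_results() is imported and awaited at the top of prompt(), but its body is not part of this diff. The following is only a hypothetical sketch of what such a helper could do; every type, field, and signature below is assumed rather than taken from the repository:

use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Mutex as AMutex;

#[async_trait]
pub trait VecdbSearch: Send {
    async fn search(&mut self, query: &str) -> Vec<String>;
}

// Stand-ins for call_validation::ChatMessage / ChatPost (fields assumed).
pub struct ChatMessage { pub role: String, pub content: String }
pub struct ChatPost { pub messages: Vec<ChatMessage> }

// Hypothetical: search on the last message and prepend the hits, so the
// later limit_messages_history() call budgets them like any other message.
pub async fn embed_vecdb_results(
    vecdb: Arc<AMutex<Box<dyn VecdbSearch + Send>>>,
    post: &mut ChatPost,
    top_n: usize,
) {
    let query = match post.messages.last() {
        Some(m) => m.content.clone(),
        None => return,
    };
    let results = vecdb.lock().await.search(&query).await;
    for hit in results.into_iter().take(top_n) {
        post.messages.insert(0, ChatMessage {
            role: "context".to_string(),
            content: hit,
        });
    }
}
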
26 changes: 16 additions & 10 deletions src/scratchpads/chat_llama2.rs
@@ -1,17 +1,19 @@
use tracing::info;
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use tokio::sync::Mutex as AMutex;
use tokenizers::Tokenizer;
use async_trait::async_trait;

use crate::scratchpad_abstract::ScratchpadAbstract;
use crate::scratchpad_abstract::HasTokenizerAndEot;
use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer;
use crate::call_validation::ChatPost;
use crate::call_validation::ChatMessage;
use crate::call_validation::SamplingParameters;
use crate::scratchpads::chat_utils_limit_history::limit_messages_history;
use crate::vecdb_search;
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use std::sync::Mutex;
use crate::vecdb_search::{VecdbSearch, embed_vecdb_results};

use tokenizers::Tokenizer;
use tracing::info;

const DEBUG: bool = true;

@@ -24,14 +26,15 @@ pub struct ChatLlama2 {
pub keyword_s: String, // "SYSTEM:" keyword means it's not one token
pub keyword_slash_s: String,
pub default_system_message: String,
pub vecdb_search: Arc<Mutex<Box<dyn vecdb_search::VecdbSearch>>>,
pub vecdb_search: Arc<AMutex<Box<dyn VecdbSearch + Send>>>,
}


impl ChatLlama2 {
pub fn new(
tokenizer: Arc<StdRwLock<Tokenizer>>,
post: ChatPost,
vecdb_search: Arc<Mutex<Box<dyn vecdb_search::VecdbSearch>>>,
vecdb_search: Arc<AMutex<Box<dyn VecdbSearch + Send>>>,
) -> Self {
ChatLlama2 {
t: HasTokenizerAndEot::new(tokenizer),
@@ -40,11 +43,12 @@ keyword_s: "<s>".to_string(),
keyword_s: "<s>".to_string(),
keyword_slash_s: "</s>".to_string(),
default_system_message: "".to_string(),
vecdb_search: vecdb_search
vecdb_search
}
}
}

#[async_trait]
impl ScratchpadAbstract for ChatLlama2 {
fn apply_model_adaptation_patch(
&mut self,
@@ -62,11 +66,12 @@ impl ScratchpadAbstract for ChatLlama2 {
Ok(())
}

fn prompt(
async fn prompt(
&mut self,
context_size: usize,
sampling_parameters_to_patch: &mut SamplingParameters,
) -> Result<String, String> {
embed_vecdb_results(self.vecdb_search.clone(), &mut self.post, 3).await;
let limited_msgs: Vec<ChatMessage> = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?;
sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone());
// loosely adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24
@@ -101,6 +106,7 @@ impl ScratchpadAbstract for ChatLlama2 {
// This only supports assistant, not suggestions for user
self.dd.role = "assistant".to_string();
if DEBUG {
// info!("llama2 chat vdb_suggestion {:?}", vdb_suggestion);
info!("llama2 chat prompt\n{}", prompt);
info!("llama2 chat re-encode whole prompt again gives {} tokes", self.t.count_tokens(prompt.as_str())?);
}
5 changes: 4 additions & 1 deletion src/scratchpads/completion_single_file_fim.rs
@@ -8,6 +8,8 @@ use std::sync::RwLock as StdRwLock;
use tokenizers::Tokenizer;
use ropey::Rope;
use tracing::info;
use async_trait::async_trait;

use crate::completion_cache;
use crate::telemetry_storage;
use crate::telemetry_snippets;
@@ -42,6 +44,7 @@ impl SingleFileFIM {
}


#[async_trait]
impl ScratchpadAbstract for SingleFileFIM {
fn apply_model_adaptation_patch(
&mut self,
@@ -59,7 +62,7 @@ impl ScratchpadAbstract for SingleFileFIM {
Ok(())
}

fn prompt(
async fn prompt(
&mut self,
context_size: usize,
sampling_parameters_to_patch: &mut SamplingParameters,
7 changes: 4 additions & 3 deletions src/scratchpads/mod.rs
@@ -1,6 +1,7 @@
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use std::sync::Mutex;
use tokio::sync::Mutex as AMutex;

use tokenizers::Tokenizer;

pub mod completion_single_file_fim;
@@ -46,11 +47,11 @@ pub fn create_chat_scratchpad(
scratchpad_name: &str,
scratchpad_patch: &serde_json::Value,
tokenizer_arc: Arc<StdRwLock<Tokenizer>>,
vecdb_search: Arc<Mutex<Box<dyn vecdb_search::VecdbSearch>>>,
vecdb_search: Arc<AMutex<Box<dyn vecdb_search::VecdbSearch + Send>>>,
) -> Result<Box<dyn ScratchpadAbstract>, String> {
let mut result: Box<dyn ScratchpadAbstract>;
if scratchpad_name == "CHAT-GENERIC" {
result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post));
result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post, vecdb_search));
} else if scratchpad_name == "CHAT-LLAMA2" {
result = Box::new(chat_llama2::ChatLlama2::new(tokenizer_arc, post, vecdb_search));
} else {