From fded15e2beb20d567ae89fc500093632994d73de Mon Sep 17 00:00:00 2001 From: reymondzzzz Date: Thu, 26 Oct 2023 13:39:53 +0300 Subject: [PATCH 01/54] WIP initial vecdb framework comments compilable thread -> tokyo coroutine debug fixes rebase fixes rm extra formatting changes using new splitter cosmetic changes, finishing some stuff status handler forward vecsearch to completion scratchpads add protobuf-compiler to gh actions add protobuf-compiler to gh actions add protobuf-compiler to gh actions add protobuf-compiler to gh actions add protobuf-compiler to gh actions add dummy handlers add impl for JB lsp handlers ask CSV first for the project files, so making it much robust status unprocessed_files_count fix vectordb cache simple prompts for chat / completion add protobuf build step add protobuf build step add protobuf build step add protobuf build step add protobuf build step add protobuf build step add protobuf build step add protobuf build step add protobuf build step review fixes, getting rid of unwrap() as much as possible review fixes, rebase try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake try build protobuf via cmake fix compile errors add stats for records and small rework add cleanup thread rebase fixes add cmdline arg to enable vecdb augment messages on stack --- .github/workflows/build.yml | 38 +- Cargo.toml | 16 + examples/http_chat_passthrough.sh | 13 + examples/http_vecdb_status.sh | 2 + src/background_tasks.rs | 16 +- src/call_validation.rs | 2 + src/global_context.rs | 31 +- src/http/routers/v1.rs | 7 +- src/http/routers/v1/chat.rs | 2 +- src/http/routers/v1/code_completion.rs | 2 + src/http/routers/v1/lsp_like_handlers.rs | 10 + src/http/routers/v1/vecdb.rs | 66 +++ src/lsp.rs | 95 ++- src/main.rs | 12 +- src/scratchpads/chat_generic.rs | 38 +- src/scratchpads/chat_llama2.rs | 34 +- src/scratchpads/chat_passthrough.rs | 34 +- src/scratchpads/chat_utils_limit_history.rs | 28 +- src/scratchpads/chat_utils_rag.rs | 43 ++ src/scratchpads/completion_single_file_fim.rs | 109 +++- src/scratchpads/mod.rs | 20 +- src/vecdb/file_filter.rs | 132 +++++ src/vecdb/file_splitter.rs | 85 +++ src/vecdb/file_watcher_service.rs | 107 ++++ src/vecdb/handler.rs | 554 ++++++++++++++++++ src/vecdb/mod.rs | 10 + src/vecdb/req_client.rs | 201 +++++++ src/vecdb/structs.rs | 59 ++ src/vecdb/vecdb.rs | 95 +++ src/vecdb/vecdb_remote.rs | 46 ++ src/vecdb/vectorizer_service.rs | 258 ++++++++ src/vecdb_search.rs | 133 ----- 32 files changed, 2057 insertions(+), 241 deletions(-) create mode 100644 examples/http_chat_passthrough.sh create mode 100644 examples/http_vecdb_status.sh create mode 100644 src/http/routers/v1/vecdb.rs create mode 100644 src/scratchpads/chat_utils_rag.rs create mode 100644 src/vecdb/file_filter.rs create mode 100644 src/vecdb/file_splitter.rs create mode 100644 src/vecdb/file_watcher_service.rs create mode 100644 src/vecdb/handler.rs create mode 100644 src/vecdb/mod.rs create mode 100644 src/vecdb/req_client.rs create mode 100644 src/vecdb/structs.rs create mode 100644 src/vecdb/vecdb.rs create mode 100644 src/vecdb/vecdb_remote.rs create mode 100644 src/vecdb/vectorizer_service.rs delete mode 100644 src/vecdb_search.rs diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 709dd23ff..9f5178bf9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -34,14 +34,14 @@ jobs: - os: windows-latest target: aarch64-pc-windows-msvc code-target: win32-arm64 - - os: ubuntu-20.04 + - os: ubuntu-22.04 target: x86_64-unknown-linux-gnu code-target: linux-x64 container: ubuntu:18.04 - - os: ubuntu-20.04 + - os: ubuntu-22.04 target: aarch64-unknown-linux-gnu code-target: linux-arm64 - - os: ubuntu-20.04 + - os: ubuntu-22.04 target: arm-unknown-linux-gnueabihf code-target: linux-armhf - os: macos-11 @@ -68,7 +68,12 @@ jobs: if: matrix.container == 'ubuntu:18.04' shell: bash run: | - apt-get update && apt-get install -y build-essential curl libssl-dev pkg-config + apt-get update && apt install -y build-essential manpages-dev software-properties-common + add-apt-repository ppa:ubuntu-toolchain-r/test + apt-get install -y sudo build-essential curl libssl-dev pkg-config wget git zlib1g zlib1g-dev gcc-11 g++-11 + unlink /usr/bin/gcc && ln -s /usr/bin/gcc-11 /usr/bin/gcc && unlink /usr/bin/g++ && ln -s /usr/bin/g++-11 /usr/bin/g++ + wget https://github.com/Kitware/CMake/releases/download/v3.20.0/cmake-3.20.0.tar.gz + tar -zxvf cmake-3.20.0.tar.gz && cd cmake-3.20.0 && ./bootstrap && make && make install && cd.. && cmake --version curl --proto '=https' --tlsv1.2 --retry 10 --retry-connrefused -fsSL "https://sh.rustup.rs" | sh -s -- --profile minimal --default-toolchain none -y echo "${CARGO_HOME:-$HOME/.cargo}/bin" >> $GITHUB_PATH @@ -88,8 +93,22 @@ jobs: - name: Install ARM target toolchain if: matrix.target == 'arm-unknown-linux-gnueabihf' - run: sudo apt-get install gcc-multilib-arm-linux-gnueabihf - + run: sudo apt-get install gcc-arm-linux-gnueabihf + + - name: Install toolchain dependencies (MacOS) + if: startsWith(matrix.os, 'macos') + run: brew install protobuf + + - name: Install toolchain dependencies (Windows) + if: startsWith(matrix.os, 'windows') + run: choco install protoc + + - name: Build and install protobuf library + if: matrix.target == 'arm-unknown-linux-gnueabihf' || matrix.target == 'x86_64-unknown-linux-gnu' || matrix.target == 'aarch64-unknown-linux-gnu' || matrix.target == 'x86_64-unknown-linux-musl' + run: | + git clone --depth=1 --recursive -b v25.1 https://github.com/protocolbuffers/protobuf + cd protobuf && cmake . && cmake --build . --parallel 10 && sudo cmake --install . + - name: Build and test artifact if: matrix.target == 'x86_64-unknown-linux-gnu' run: | @@ -131,7 +150,12 @@ jobs: steps: - name: Install dependencies - run: apk add --no-cache git clang lld musl-dev nodejs npm openssl-dev pkgconfig g++ + run: apk add --no-cache sudo git clang lld musl-dev nodejs npm openssl-dev pkgconfig g++ make cmake zlib linux-headers + + - name: Build and install protobuf library + run: | + git clone --depth=1 --recursive -b v25.1 https://github.com/protocolbuffers/protobuf + cd protobuf && cmake . && cmake --build . --parallel 10 && sudo cmake --install . - name: Checkout repository uses: actions/checkout@v3 diff --git a/Cargo.toml b/Cargo.toml index cae702a6f..e3e00a641 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,3 +38,19 @@ async-trait = "0.1.73" similar = "2.3.0" axum = "0.6.20" uuid = { version = "1", features = ["v4"] } + +arrow = "47.0.0" +arrow-array = "47.0.0" +arrow-schema= "47.0.0" +async_once= "0.2.6" +async-process = "2.0.1" +lance = "=0.9.0" +log = "0.4.20" +md5 = "0.7" +mockito = "0.28.0" +notify = "6.1.1" +tempfile = "3.8.1" +time = "0.3.30" +vectordb = "=0.4.0" +walkdir = "2.3" +which = "5.0.0" diff --git a/examples/http_chat_passthrough.sh b/examples/http_chat_passthrough.sh new file mode 100644 index 000000000..041298cbf --- /dev/null +++ b/examples/http_chat_passthrough.sh @@ -0,0 +1,13 @@ +curl http://127.0.0.1:8001/v1/chat -k \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [ + {"role": "user", "content": "Who is Bill Clinton? What is his favorite programming language?"} + ], + "model": "gpt-3.5-turbo", + "parameters": { + "temperature": 0.1, + "max_new_tokens": 10 + } +}' + diff --git a/examples/http_vecdb_status.sh b/examples/http_vecdb_status.sh new file mode 100644 index 000000000..4a074de63 --- /dev/null +++ b/examples/http_vecdb_status.sh @@ -0,0 +1,2 @@ +curl http://127.0.0.1:8001/v1/vdb-status + diff --git a/src/background_tasks.rs b/src/background_tasks.rs index 42c349859..8f3edc3e1 100644 --- a/src/background_tasks.rs +++ b/src/background_tasks.rs @@ -1,13 +1,14 @@ +use std::iter::IntoIterator; use std::sync::Arc; use std::vec; use tokio::sync::RwLock as ARwLock; use tokio::task::JoinHandle; -use crate::global_context; -use crate::telemetry::basic_transmit; +use crate::{global_context, vecdb}; use crate::global_context::GlobalContext; use crate::snippets_transmit; +use crate::telemetry::basic_transmit; pub struct BackgroundTasksHolder { tasks: Vec>, @@ -21,7 +22,15 @@ impl BackgroundTasksHolder { } pub fn push_back(&mut self, task: JoinHandle<()>) { - self.tasks.push(task) + self.tasks.push(task); + } + + + pub fn extend(&mut self, tasks: T) + where + T: IntoIterator>, + { + self.tasks.extend(tasks); } pub async fn abort(self) { @@ -36,5 +45,6 @@ pub fn start_background_tasks(global_context: Arc>) -> Ba BackgroundTasksHolder::new(vec![ tokio::spawn(basic_transmit::telemetry_background_task(global_context.clone())), tokio::spawn(snippets_transmit::tele_snip_background_task(global_context.clone())), + tokio::spawn(vecdb::file_watcher_service::file_watcher_task(global_context.clone())), ]) } diff --git a/src/call_validation.rs b/src/call_validation.rs index 03347800d..fb469d6f8 100644 --- a/src/call_validation.rs +++ b/src/call_validation.rs @@ -178,6 +178,8 @@ mod tests { pub struct ContextFile { pub file_name: String, pub file_content: String, + pub line1: i32, + pub line2: i32, } #[derive(Debug, Serialize, Deserialize, Clone)] diff --git a/src/global_context.rs b/src/global_context.rs index 5654482e6..1c5164576 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -15,12 +15,12 @@ use async_trait::async_trait; use crate::caps::CodeAssistantCaps; use crate::completion_cache::CompletionCache; use crate::telemetry::telemetry_structs; -use crate::vecdb_search::VecdbSearch; use crate::custom_error::ScratchError; use hyper::StatusCode; use tokio::sync::mpsc::Permit; use tower_lsp::lsp_types::WorkspaceFolder; use crate::receive_workspace_changes::Document; +use crate::vecdb::vecdb::VecDb; #[derive(Debug, StructOpt, Clone)] @@ -47,6 +47,10 @@ pub struct CommandLine { pub insecure: bool, #[structopt(long, short="v", help="Verbose logging, lots of output")] pub verbose: bool, + #[structopt(long, help="Whether to use a vector database")] + pub vecdb: bool, + #[structopt(long, short = "f", default_value = "", help = "The path to jsonl file which contains filtered source files")] + pub files_set_path: String, } impl CommandLine { @@ -76,7 +80,7 @@ pub struct GlobalContext { pub tokenizer_map: HashMap< String, Arc>>, pub completions_cache: Arc>, pub telemetry: Arc>, - pub vecdb_search: Arc>>, + pub vec_db: Arc>>, pub ask_shutdown_sender: Arc>>, pub lsp_backend_document_state: LSPBackendDocumentState, } @@ -160,6 +164,27 @@ pub async fn create_global_context( http_client_builder = http_client_builder.danger_accept_invalid_certs(true) } let http_client = http_client_builder.build().unwrap(); + let mut vec_db: Option = None; + if cmdline.vecdb { + vec_db = match VecDb::init( + cache_dir.clone(), cmdline.clone(), + 384, 60, 512, 1024, + "BAAI/bge-small-en-v1.5".to_string(), + ).await { + Ok(res) => Some(res), + Err(err) => { + error!("Ooops database is broken! + Last error message: {} + You can report this issue here: + https://github.com/smallcloudai/refact-lsp/issues + Also you can run this to erase your db: + `rm -rf ~/.cache/refact/refact_vecdb_cache` + After that restart this LSP server or your IDE.", err); + None + } + }; + } + let cx = GlobalContext { cmdline: cmdline.clone(), @@ -171,7 +196,7 @@ pub async fn create_global_context( tokenizer_map: HashMap::new(), completions_cache: Arc::new(StdRwLock::new(CompletionCache::new())), telemetry: Arc::new(StdRwLock::new(telemetry_structs::Storage::new())), - vecdb_search: Arc::new(AMutex::new(Box::new(crate::vecdb_search::VecdbSearchTest::new()))), + vec_db: Arc::new(AMutex::new(vec_db)), ask_shutdown_sender: Arc::new(Mutex::new(ask_shutdown_sender)), lsp_backend_document_state: LSPBackendDocumentState { document_map: Arc::new(ARwLock::new(HashMap::new())), diff --git a/src/http/routers/v1.rs b/src/http/routers/v1.rs index 043194189..4d0f1a88d 100644 --- a/src/http/routers/v1.rs +++ b/src/http/routers/v1.rs @@ -20,6 +20,7 @@ use crate::http::routers::v1::telemetry_network::handle_v1_telemetry_network; use crate::http::routers::v1::lsp_like_handlers::handle_v1_lsp_initialize; use crate::http::routers::v1::lsp_like_handlers::handle_v1_lsp_did_change; use crate::http::utils::telemetry_wrapper; +use crate::http::routers::v1::vecdb::{handle_v1_vecdb_search, handle_v1_vecdb_status}; pub mod code_completion; pub mod chat; @@ -27,7 +28,8 @@ pub mod telemetry_network; pub mod snippet_accepted; pub mod caps; pub mod graceful_shutdown; -mod lsp_like_handlers; +pub mod lsp_like_handlers; +pub mod vecdb; pub fn make_v1_router() -> Router { Router::new() @@ -39,6 +41,9 @@ pub fn make_v1_router() -> Router { .route("/caps", telemetry_get!(handle_v1_caps)) .route("/graceful-shutdown", telemetry_get!(handle_v1_graceful_shutdown)) + .route("/vdb-search", telemetry_get!(handle_v1_vecdb_search)) + .route("/vdb-status", telemetry_get!(handle_v1_vecdb_status)) + .route("/lsp-initialize", telemetry_post!(handle_v1_lsp_initialize)) .route("/lsp-did-changed", telemetry_post!(handle_v1_lsp_did_change)) } \ No newline at end of file diff --git a/src/http/routers/v1/chat.rs b/src/http/routers/v1/chat.rs index 75943beb6..ad777dc7f 100644 --- a/src/http/routers/v1/chat.rs +++ b/src/http/routers/v1/chat.rs @@ -55,7 +55,7 @@ pub async fn handle_v1_chat( let cx_locked = global_context.write().await; (cx_locked.http_client.clone(), cx_locked.cmdline.api_key.clone()) }; - let vecdb_search = global_context.read().await.vecdb_search.clone(); + let vecdb_search = global_context.read().await.vec_db.clone(); let mut scratchpad = scratchpads::create_chat_scratchpad( global_context.clone(), caps, diff --git a/src/http/routers/v1/code_completion.rs b/src/http/routers/v1/code_completion.rs index 7b00ed0ad..06615da05 100644 --- a/src/http/routers/v1/code_completion.rs +++ b/src/http/routers/v1/code_completion.rs @@ -81,6 +81,7 @@ pub async fn handle_v1_code_completion( } } + let vecdb_search = global_context.read().await.vec_db.clone(); let mut scratchpad = scratchpads::create_code_completion_scratchpad( global_context.clone(), caps, @@ -90,6 +91,7 @@ pub async fn handle_v1_code_completion( &scratchpad_patch, cache_arc.clone(), tele_storage.clone(), + vecdb_search ).await.map_err(|e| ScratchError::new(StatusCode::BAD_REQUEST, e) )?; diff --git a/src/http/routers/v1/lsp_like_handlers.rs b/src/http/routers/v1/lsp_like_handlers.rs index a444026d4..e2390f9fb 100644 --- a/src/http/routers/v1/lsp_like_handlers.rs +++ b/src/http/routers/v1/lsp_like_handlers.rs @@ -1,3 +1,4 @@ +use std::path::PathBuf; use axum::Extension; use axum::response::Result; use hyper::{Body, Response, StatusCode}; @@ -7,6 +8,7 @@ use url::Url; use crate::custom_error::ScratchError; use crate::global_context::SharedGlobalContext; use crate::receive_workspace_changes; +use crate::vecdb::file_filter::retrieve_files_by_proj_folders; #[derive(Serialize, Deserialize, Clone)] @@ -29,6 +31,14 @@ pub async fn handle_v1_lsp_initialize( ScratchError::new(StatusCode::BAD_REQUEST, format!("JSON problem: {}", e)) })?; + let files = retrieve_files_by_proj_folders( + post.project_roots.iter().map(|x| PathBuf::from(x.path())).collect() + ).await; + match *global_context.read().await.vec_db.lock().await { + Some(ref mut db) => db.add_or_update_files(files, true).await, + None => {} + }; + Ok(Response::builder() .status(StatusCode::OK) .body(Body::from(json!({"success": 1}).to_string())) diff --git a/src/http/routers/v1/vecdb.rs b/src/http/routers/v1/vecdb.rs new file mode 100644 index 000000000..6f38c9cc0 --- /dev/null +++ b/src/http/routers/v1/vecdb.rs @@ -0,0 +1,66 @@ +use axum::Extension; +use axum::response::Result; +use hyper::{Body, Response, StatusCode}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use crate::custom_error::ScratchError; +use crate::global_context::SharedGlobalContext; +use crate::vecdb::structs::VecdbSearch; + +#[derive(Serialize, Deserialize, Clone)] +struct VecDBPost { + query: String, + top_n: usize, +} + +pub async fn handle_v1_vecdb_search( + Extension(global_context): Extension, + body_bytes: hyper::body::Bytes, +) -> Result, ScratchError> { + let post = serde_json::from_slice::(&body_bytes).map_err(|e| { + ScratchError::new(StatusCode::BAD_REQUEST, format!("JSON problem: {}", e)) + })?; + + let cx_locked = global_context.read().await; + let search_res = match *cx_locked.vec_db.lock().await { + Some(ref db) => db.search(post.query.to_string(), post.top_n).await, + None => { + return Err(ScratchError::new( + StatusCode::INTERNAL_SERVER_ERROR, "Vector db is not available".to_string() + )); + } + }; + + match search_res { + Ok(search_res) => { + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(json!(search_res).to_string())) + .unwrap()) + } + Err(e) => { + Err(ScratchError::new(StatusCode::BAD_REQUEST, e)) + } + } +} + +pub async fn handle_v1_vecdb_status( + Extension(global_context): Extension, + _: hyper::body::Bytes, +) -> Result, ScratchError> { + let cx_locked = global_context.read().await; + let status = match *cx_locked.vec_db.lock().await { + Some(ref db) => db.get_status().await, + None => { + return Err(ScratchError::new( + StatusCode::INTERNAL_SERVER_ERROR, "Vector db is not available".to_string() + )); + } + }; + + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(json!(status).to_string())) + .unwrap()) +} diff --git a/src/lsp.rs b/src/lsp.rs index 64aad047d..86bbd0924 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::fmt::Display; +use std::path::PathBuf; use std::sync::Arc; use serde::{Deserialize, Serialize}; @@ -15,10 +16,10 @@ use crate::call_validation::{CodeCompletionInputs, CodeCompletionPost, CursorPos use crate::global_context; use crate::global_context::CommandLine; use crate::http::routers::v1::code_completion::handle_v1_code_completion; -use crate::telemetry; use crate::receive_workspace_changes; +use crate::telemetry; use crate::telemetry::snippets_collection; - +use crate::vecdb::file_filter::{is_valid_file, retrieve_files_by_proj_folders}; const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -35,7 +36,6 @@ impl Display for APIError { } - // #[derive(Debug)] GlobalContext does not implement Debug pub struct Backend { pub gcx: Arc>, @@ -182,10 +182,10 @@ impl Backend { if is_valid { unchanged_percentage = telemetry::utils::unchanged_percentage( ¶ms.orig_grey_text, - &grey_corrected + &grey_corrected, ); } - Ok(TestHeadTailAddedTextRes{is_valid, grey_corrected, unchanged_percentage}) + Ok(TestHeadTailAddedTextRes { is_valid, grey_corrected, unchanged_percentage }) } } @@ -196,10 +196,22 @@ impl LanguageServer for Backend { info!("LSP client_info {:?}", params.client_info); { let gcx_locked = self.gcx.write().await; - *gcx_locked.lsp_backend_document_state.workspace_folders.write().await = params.workspace_folders; + *gcx_locked.lsp_backend_document_state.workspace_folders.write().await = params.workspace_folders.clone(); info!("LSP workspace_folders {:?}", gcx_locked.lsp_backend_document_state.workspace_folders); } + if let Some(folders) = params.workspace_folders { + let files = retrieve_files_by_proj_folders( + folders.iter().map(|x| PathBuf::from(x.uri.path())).collect() + ).await; + match *self.gcx.read().await.vec_db.lock().await { + Some(ref db) => db.add_or_update_files(files, true).await, + None => { + info!("LSP no vec_db"); + } + }; + } + let completion_options: CompletionOptions; completion_options = CompletionOptions { resolve_provider: Some(false), @@ -240,6 +252,14 @@ impl LanguageServer for Backend { } async fn did_change(&self, params: DidChangeTextDocumentParams) { + let file_path = PathBuf::from(params.text_document.uri.path()); + if is_valid_file(&file_path) { + match *self.gcx.read().await.vec_db.lock().await { + Some(ref mut db) => db.add_or_update_file(file_path, false).await, + None => {} + }; + } + receive_workspace_changes::on_did_change( self.gcx.clone(), ¶ms.text_document.uri.to_string(), @@ -247,13 +267,34 @@ impl LanguageServer for Backend { ).await } - async fn did_close(&self, params: DidCloseTextDocumentParams) { - let _uri = params.text_document.uri.to_string(); - // TODO: remove text from memory + async fn did_save(&self, params: DidSaveTextDocumentParams) { + self.client + .log_message(MessageType::INFO, "{refact-lsp} file saved") + .await; + let uri = params.text_document.uri.to_string(); + let file_path = PathBuf::from(params.text_document.uri.path()); + if is_valid_file(&file_path) { + match *self.gcx.read().await.vec_db.lock().await { + Some(ref mut db) => db.add_or_update_file(file_path, false).await, + None => {} + }; + } + info!("{uri} saved"); } - async fn did_save(&self, params: DidSaveTextDocumentParams) { - let _uri = params.text_document.uri.to_string(); + async fn did_close(&self, params: DidCloseTextDocumentParams) { + self.client + .log_message(MessageType::INFO, "{refact-lsp} file closed") + .await; + let uri = params.text_document.uri.to_string(); + let file_path = PathBuf::from(params.text_document.uri.path()); + if is_valid_file(&file_path) { + match *self.gcx.read().await.vec_db.lock().await { + Some(ref mut db) => db.add_or_update_file(file_path, false).await, + None => {} + }; + } + info!("{uri} closed"); } async fn shutdown(&self) -> Result<()> { @@ -266,6 +307,34 @@ impl LanguageServer for Backend { Ok(Some(CompletionResponse::Array(vec![ ]))) } + + async fn did_delete_files(&self, params: DeleteFilesParams) { + let files = params.files + .into_iter() + .map(|x| PathBuf::from(x.uri.replace("file://", ""))) + .filter(|x| is_valid_file(&x)); + + match *self.gcx.read().await.vec_db.lock().await { + Some(ref mut db) => { + for file in files { + db.remove_file(&file).await; + } + } + None => {} + }; + } + + async fn did_create_files(&self, params: CreateFilesParams) { + let files = params.files + .into_iter() + .map(|x| PathBuf::from(x.uri.replace("file://", ""))) + .filter(|x| is_valid_file(&x)) + .collect(); + match *self.gcx.read().await.vec_db.lock().await { + Some(ref mut db) => db.add_or_update_files(files, false).await, + None => {} + }; + } } async fn build_lsp_service( @@ -289,7 +358,7 @@ pub async fn spawn_lsp_task( if cmdline.lsp_stdin_stdout == 0 && cmdline.lsp_port > 0 { let gcx_t = gcx.clone(); let addr: std::net::SocketAddr = ([127, 0, 0, 1], cmdline.lsp_port).into(); - return Some(tokio::spawn( async move { + return Some(tokio::spawn(async move { let listener: TcpListener = TcpListener::bind(&addr).await.unwrap(); info!("LSP listening on {}", listener.local_addr().unwrap()); loop { @@ -312,7 +381,7 @@ pub async fn spawn_lsp_task( if cmdline.lsp_stdin_stdout != 0 && cmdline.lsp_port == 0 { let gcx_t = gcx.clone(); - return Some(tokio::spawn( async move { + return Some(tokio::spawn(async move { let stdin = tokio::io::stdin(); let stdout = tokio::io::stdout(); let (lsp_service, socket) = build_lsp_service(gcx_t.clone()).await; diff --git a/src/main.rs b/src/main.rs index 63423ec65..ebd1dcd73 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,11 +20,12 @@ mod restream; mod custom_error; mod completion_cache; mod telemetry; -mod vecdb_search; mod lsp; mod http; mod background_tasks; mod receive_workspace_changes; +mod vecdb; + #[tokio::main] async fn main() { @@ -60,8 +61,13 @@ async fn main() { } let mut background_tasks = start_background_tasks(gcx.clone()); + background_tasks.extend(match *gcx.read().await.vec_db.lock().await { + Some(ref db) => db.start_background_tasks().await, + None => vec![] + }); + let should_start_http = cmdline.http_port != 0; - let should_start_lsp = (cmdline.lsp_port == 0 && cmdline.lsp_stdin_stdout == 1) || + let should_start_lsp = (cmdline.lsp_port == 0 && cmdline.lsp_stdin_stdout == 1) || (cmdline.lsp_port != 0 && cmdline.lsp_stdin_stdout == 0); let mut main_handle: Option> = None; @@ -78,7 +84,7 @@ async fn main() { if main_handle.is_some() { let _ = main_handle.unwrap().await; } - + background_tasks.abort().await; info!("saving telemetry without sending, so should be quick"); basic_transmit::basic_telemetry_compress(gcx.clone()).await; diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index 544fb221f..46c3cf969 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -1,22 +1,23 @@ -use crate::scratchpad_abstract::ScratchpadAbstract; -use crate::scratchpad_abstract::HasTokenizerAndEot; -use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; -use crate::call_validation::{ChatPost, ChatMessage, SamplingParameters, ContextFile}; -use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::vecdb_search::{VecdbSearch, embed_vecdb_results}; - use std::sync::Arc; use std::sync::RwLock; -use async_trait::async_trait; -use tokio::sync::Mutex as AMutex; +use async_trait::async_trait; use tokenizers::Tokenizer; +use tokio::sync::Mutex as AMutex; use tracing::info; +use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; +use crate::scratchpad_abstract::HasTokenizerAndEot; +use crate::scratchpad_abstract::ScratchpadAbstract; +use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; +use crate::scratchpads::chat_utils_limit_history::limit_messages_history; +use crate::scratchpads::chat_utils_rag::embed_vecdb_results; +use crate::vecdb::structs::VecdbSearch; + const DEBUG: bool = true; -pub struct GenericChatScratchpad { +pub struct GenericChatScratchpad { pub t: HasTokenizerAndEot, pub dd: DeltaDeltaChatStreamer, pub post: ChatPost, @@ -25,15 +26,15 @@ pub struct GenericChatScratchpad { pub keyword_user: String, pub keyword_asst: String, pub default_system_message: String, - pub vecdb_search: Arc>>, + pub vecdb_search: Arc>>, } -impl GenericChatScratchpad { +impl GenericChatScratchpad { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: Arc>>, - ) -> Self { + vecdb_search: Arc>>, + ) -> Self where T: VecdbSearch + 'static + Sync { GenericChatScratchpad { t: HasTokenizerAndEot::new(tokenizer), dd: DeltaDeltaChatStreamer::new(), @@ -49,7 +50,7 @@ impl GenericChatScratchpad { } #[async_trait] -impl ScratchpadAbstract for GenericChatScratchpad { +impl ScratchpadAbstract for GenericChatScratchpad { fn apply_model_adaptation_patch( &mut self, patch: &serde_json::Value, @@ -83,8 +84,11 @@ impl ScratchpadAbstract for GenericChatScratchpad { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - // embed_vecdb_results(self.vecdb_search.clone(), &mut self.post, 3).await; - let limited_msgs: Vec = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?; + let augmented_msgs = match *self.vecdb_search.lock().await { + Some(ref db) => embed_vecdb_results(db, &self.post.messages, 6).await, + None => { self.post.messages.clone() } + }; + let limited_msgs: Vec = limit_messages_history(&self.t, &augmented_msgs, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 let mut prompt = "".to_string(); diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index b57357c42..1732f7c9a 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -1,39 +1,40 @@ -use tracing::info; use std::sync::Arc; use std::sync::RwLock as StdRwLock; -use tokio::sync::Mutex as AMutex; -use tokenizers::Tokenizer; + use async_trait::async_trait; +use tokenizers::Tokenizer; +use tokio::sync::Mutex as AMutex; +use tracing::info; -use crate::scratchpad_abstract::ScratchpadAbstract; +use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; use crate::scratchpad_abstract::HasTokenizerAndEot; +use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; -use crate::call_validation::{ChatPost, ChatMessage, SamplingParameters, ContextFile}; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::vecdb_search::{VecdbSearch, embed_vecdb_results}; - +use crate::scratchpads::chat_utils_rag::embed_vecdb_results; +use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; // #[derive(Debug)] -pub struct ChatLlama2 { +pub struct ChatLlama2 { pub t: HasTokenizerAndEot, pub dd: DeltaDeltaChatStreamer, pub post: ChatPost, pub keyword_s: String, // "SYSTEM:" keyword means it's not one token pub keyword_slash_s: String, pub default_system_message: String, - pub vecdb_search: Arc>>, + pub vecdb_search: Arc>>, } -impl ChatLlama2 { +impl ChatLlama2 { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: Arc>>, - ) -> Self { + vecdb_search: Arc>>, + ) -> Self where T: VecdbSearch + Send { ChatLlama2 { t: HasTokenizerAndEot::new(tokenizer), dd: DeltaDeltaChatStreamer::new(), @@ -47,7 +48,7 @@ impl ChatLlama2 { } #[async_trait] -impl ScratchpadAbstract for ChatLlama2 { +impl ScratchpadAbstract for ChatLlama2 { fn apply_model_adaptation_patch( &mut self, patch: &serde_json::Value, @@ -69,8 +70,11 @@ impl ScratchpadAbstract for ChatLlama2 { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - // embed_vecdb_results(self.vecdb_search.clone(), &mut self.post, 3).await; - let limited_msgs: Vec = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?; + let augmented_msgs = match *self.vecdb_search.lock().await { + Some(ref db) => embed_vecdb_results(db, &self.post.messages, 6).await, + None => { self.post.messages.clone() } + }; + let limited_msgs: Vec = limit_messages_history(&self.t, &augmented_msgs, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // loosely adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 let mut prompt = "".to_string(); diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index a93ad30c0..7b8dc972f 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -1,33 +1,33 @@ -use tracing::info; use std::sync::Arc; -use tokio::sync::Mutex as AMutex; + use async_trait::async_trait; +use tokio::sync::Mutex as AMutex; +use tracing::info; +use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; use crate::scratchpad_abstract::ScratchpadAbstract; -use crate::call_validation::{ChatPost, ChatMessage, SamplingParameters, ContextFile}; use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes; -// use crate::vecdb_search::{VecdbSearch, embed_vecdb_results}; -use crate::vecdb_search::VecdbSearch; - +use crate::scratchpads::chat_utils_rag::embed_vecdb_results; +use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; // #[derive(Debug)] -pub struct ChatPassthrough { +pub struct ChatPassthrough { pub post: ChatPost, pub default_system_message: String, pub limit_bytes: usize, - pub vecdb_search: Arc>>, + pub vecdb_search: Arc>>, } -const DEFAULT_LIMIT_BYTES: usize = 4096*3; +const DEFAULT_LIMIT_BYTES: usize = 4096*6; -impl ChatPassthrough { +impl ChatPassthrough { pub fn new( post: ChatPost, - vecdb_search: Arc>>, - ) -> Self { + vecdb_search: Arc>>, + ) -> Self where T: VecdbSearch + 'static + Sync { ChatPassthrough { post, default_system_message: "".to_string(), @@ -38,7 +38,7 @@ impl ChatPassthrough { } #[async_trait] -impl ScratchpadAbstract for ChatPassthrough { +impl ScratchpadAbstract for ChatPassthrough { fn apply_model_adaptation_patch( &mut self, patch: &serde_json::Value, @@ -53,7 +53,11 @@ impl ScratchpadAbstract for ChatPassthrough { _context_size: usize, _sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - let limited_msgs: Vec = limit_messages_history_in_bytes(&self.post, self.limit_bytes, &self.default_system_message)?; + let augmented_msgs = match *self.vecdb_search.lock().await { + Some(ref db) => embed_vecdb_results(db, &self.post.messages, 6).await, + None => { self.post.messages.clone() } + }; + let limited_msgs: Vec = limit_messages_history_in_bytes(&augmented_msgs, self.limit_bytes, &self.default_system_message)?; info!("chat passthrough {} messages -> {} messages after applying limits and possibly adding the default system message", &limited_msgs.len(), &limited_msgs.len()); let mut filtered_msgs: Vec = Vec::::new(); for msg in &limited_msgs { @@ -64,7 +68,7 @@ impl ScratchpadAbstract for ChatPassthrough { for context_file in &vector_of_context_files { filtered_msgs.push(ChatMessage { role: "user".to_string(), - content: format!("{}\n```\n{}```", context_file.file_name, context_file.file_content), + content: format!("{}:{}-{}\n```\n{}```", context_file.file_name, context_file.line1, context_file.line2, context_file.file_content), }); } } diff --git a/src/scratchpads/chat_utils_limit_history.rs b/src/scratchpads/chat_utils_limit_history.rs index 9c13b38bb..efd6a7314 100644 --- a/src/scratchpads/chat_utils_limit_history.rs +++ b/src/scratchpads/chat_utils_limit_history.rs @@ -1,21 +1,21 @@ use crate::scratchpad_abstract::HasTokenizerAndEot; -use crate::call_validation::ChatPost; use crate::call_validation::ChatMessage; pub fn limit_messages_history( t: &HasTokenizerAndEot, - post: &ChatPost, + messages: &Vec, + max_new_tokens: usize, context_size: usize, default_system_mesage: &String, ) -> Result, String> { - let tokens_limit: i32 = context_size as i32 - post.parameters.max_new_tokens as i32; + let tokens_limit: i32 = context_size as i32 - max_new_tokens as i32; let mut tokens_used: i32 = 0; - let mut message_token_count: Vec = vec![0; post.messages.len()]; - let mut message_take: Vec = vec![false; post.messages.len()]; + let mut message_token_count: Vec = vec![0; messages.len()]; + let mut message_take: Vec = vec![false; messages.len()]; let mut have_system = false; - for (i, msg) in post.messages.iter().enumerate() { + for (i, msg) in messages.iter().enumerate() { let tcnt = (3 + t.count_tokens(msg.content.as_str())?) as i32; // 3 for role "\n\nASSISTANT:" kind of thing message_token_count[i] = tcnt; if i==0 && msg.role == "system" { @@ -29,7 +29,7 @@ pub fn limit_messages_history( let tcnt = t.count_tokens(default_system_mesage.as_str())? as i32; tokens_used += tcnt; } - for i in (0..post.messages.len()).rev() { + for i in (0..messages.len()).rev() { let tcnt = 3 + message_token_count[i]; if !message_take[i] { if tokens_used + tcnt < tokens_limit { @@ -40,7 +40,7 @@ pub fn limit_messages_history( } } } - let mut messages_out: Vec = post.messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect(); + let mut messages_out: Vec = messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect(); if need_default_system_msg { messages_out.insert(0, ChatMessage { role: "system".to_string(), @@ -52,15 +52,15 @@ pub fn limit_messages_history( pub fn limit_messages_history_in_bytes( - post: &ChatPost, + messages: &Vec, bytes_limit: usize, default_system_mesage: &String, ) -> Result, String> { let mut bytes_used: usize = 0; let mut have_system = false; - let mut message_take: Vec = vec![false; post.messages.len()]; - for (i, msg) in post.messages.iter().enumerate() { + let mut message_take: Vec = vec![false; messages.len()]; + for (i, msg) in messages.iter().enumerate() { bytes_used += msg.content.as_bytes().len(); if i==0 && msg.role == "system" { message_take[i] = true; @@ -71,8 +71,8 @@ pub fn limit_messages_history_in_bytes( if need_default_system_msg { bytes_used += default_system_mesage.as_bytes().len(); } - for i in (0..post.messages.len()).rev() { - let bytes = post.messages[i].content.len(); + for i in (0..messages.len()).rev() { + let bytes = messages[i].content.len(); if !message_take[i] { if bytes_used + bytes < bytes_limit { message_take[i] = true; @@ -82,7 +82,7 @@ pub fn limit_messages_history_in_bytes( } } } - let mut messages_out: Vec = post.messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect(); + let mut messages_out: Vec = messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect(); if need_default_system_msg { messages_out.insert(0, ChatMessage { role: "system".to_string(), diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs new file mode 100644 index 000000000..3e95f2a92 --- /dev/null +++ b/src/scratchpads/chat_utils_rag.rs @@ -0,0 +1,43 @@ +use crate::call_validation::{ChatMessage, ChatPost, ContextFile}; +use crate::vecdb::structs::{SearchResult, VecdbSearch}; + +pub async fn embed_vecdb_results( + vecdb: &T, + msgs: &Vec, + limit_examples_cnt: usize, +) -> Vec where T: VecdbSearch { + let latest_msg_cont = &msgs.last().unwrap().content; + let vdb_resp = vecdb.search(latest_msg_cont.clone(), limit_examples_cnt).await; + let vdb_cont = vecdb_resp_to_prompt(&vdb_resp); + if vdb_cont.is_ok() { + return [ + &msgs[..msgs.len() - 1], + &[ChatMessage { + role: "context_file".to_string(), + content: vdb_cont.unwrap(), + }], + &msgs[msgs.len() - 1..], + ].concat(); + } else { + return msgs.clone(); + } +} + +fn vecdb_resp_to_prompt( + resp: &Result +) -> serde_json::Result { + let context_files: Vec = match resp { + Ok(search_res) => { + search_res.results.iter().map( + |x| ContextFile { + file_name: x.file_path.to_str().unwrap().to_string(), + file_content: x.window_text.clone(), + line1: x.start_line as i32, + line2: x.end_line as i32, + } + ).collect() + } + Err(_) => vec![] + }; + serde_json::to_string(&context_files) +} diff --git a/src/scratchpads/completion_single_file_fim.rs b/src/scratchpads/completion_single_file_fim.rs index 814798224..da81fbbd4 100644 --- a/src/scratchpads/completion_single_file_fim.rs +++ b/src/scratchpads/completion_single_file_fim.rs @@ -4,6 +4,8 @@ use crate::call_validation::CodeCompletionPost; use crate::call_validation::SamplingParameters; use std::sync::Arc; use std::sync::RwLock as StdRwLock; +use std::vec; +use tokio::sync::Mutex as AMutex; // use ropey::RopeSlice; use tokenizers::Tokenizer; use ropey::Rope; @@ -13,12 +15,13 @@ use async_trait::async_trait; use crate::completion_cache; use crate::telemetry::telemetry_structs; use crate::telemetry::snippets_collection; +use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = false; #[derive(Debug)] -pub struct SingleFileFIM { +pub struct SingleFileFIM { pub t: HasTokenizerAndEot, pub post: CodeCompletionPost, pub order: String, @@ -27,19 +30,24 @@ pub struct SingleFileFIM { pub fim_middle: String, pub data4cache: completion_cache::CompletionSaveToCache, pub data4snippet: snippets_collection::SaveSnippet, + pub vecdb_search: Arc>>, } -impl SingleFileFIM { +impl SingleFileFIM { pub fn new( tokenizer: Arc>, post: CodeCompletionPost, order: String, cache_arc: Arc>, tele_storage: Arc>, - ) -> Self { + vecdb_search: Arc>>, + ) -> Self where T: VecdbSearch + Send { let data4cache = completion_cache::CompletionSaveToCache::new(cache_arc, &post); let data4snippet = snippets_collection::SaveSnippet::new(tele_storage, &post); - SingleFileFIM { t: HasTokenizerAndEot::new(tokenizer), post, order, fim_prefix: String::new(), fim_suffix: String::new(), fim_middle: String::new(), data4cache, data4snippet } + SingleFileFIM { t: HasTokenizerAndEot::new(tokenizer), post, order, fim_prefix: String::new(), + fim_suffix: String::new(), fim_middle: String::new(), data4cache, data4snippet, + vecdb_search + } } fn cleanup_prompt(&mut self, text: &String) -> String { @@ -53,7 +61,7 @@ impl SingleFileFIM { #[async_trait] -impl ScratchpadAbstract for SingleFileFIM { +impl ScratchpadAbstract for SingleFileFIM { fn apply_model_adaptation_patch( &mut self, patch: &serde_json::Value, @@ -74,6 +82,7 @@ impl ScratchpadAbstract for SingleFileFIM { Ok(()) } + async fn prompt( &mut self, context_size: usize, @@ -98,6 +107,23 @@ impl ScratchpadAbstract for SingleFileFIM { let pos = &self.post.inputs.cursor; let mut before_iter = text.lines_at(pos.line as usize).reversed(); let mut after_iter = text.lines_at(pos.line as usize + 1); + let (extra_context, mut tokens_used) = match *self.vecdb_search.lock().await { + Some(ref db) => { + match self.post.no_cache || self.post.inputs.multiline { + true => { + let text_near_cursor = get_context_near_cursor(&text, pos.line as usize, 20); + search_vecdb( + db, + self.t.clone(), + text_near_cursor, + (limit as f32 * 0.5) as usize + ).await + } + false => (String::new(), 0) + } + } + None => (String::new(), 0) + }; let mut before_line = before_iter.next(); @@ -117,7 +143,7 @@ impl ScratchpadAbstract for SingleFileFIM { let mut before = vec![]; let mut after = String::new(); - let mut tokens_used = self.t.count_tokens( + tokens_used += self.t.count_tokens( (cursor_line1.clone() + &cursor_line2).as_str() )?; while before_line.is_some() || after_line.is_some() { @@ -146,9 +172,10 @@ impl ScratchpadAbstract for SingleFileFIM { let prompt: String; if self.order == "PSM" { prompt = format!( - "{}{}{}{}{}{}{}{}", + "{}{}{}{}{}{}{}{}{}", self.t.eos, self.fim_prefix, + extra_context, before.into_iter().rev().collect::>().join(""), cursor_line1, self.fim_suffix, @@ -158,9 +185,10 @@ impl ScratchpadAbstract for SingleFileFIM { ); } else if self.order == "SPM" { prompt = format!( - "{}{}{}{}{}{}{}{}", + "{}{}{}{}{}{}{}{}{}", self.t.eos, self.fim_suffix, + extra_context, cursor_line2, after, self.fim_prefix, @@ -259,6 +287,71 @@ impl ScratchpadAbstract for SingleFileFIM { } } +fn get_context_near_cursor(text: &Rope, line_pos: usize, max_lines_count: usize) -> String { + let mut before_iter = text.lines_at(line_pos).reversed(); + let mut after_iter = text.lines_at(line_pos + 1); + + let mut before = vec![]; + let mut after = vec![]; + let mut before_line = before_iter.next(); + let mut after_line = after_iter.next(); + + while (before.len() + after.len() < max_lines_count) && (before_line.is_some() || after_line.is_some()) { + if let Some(before_line) = before_iter.next() { + before.push(before_line.as_str().unwrap_or("")); + } + if let Some(after_line) = after_iter.next() { + after.push(after_line.as_str().unwrap_or("")); + } + before_line = before_iter.next(); + after_line = after_iter.next(); + } + + let before_str = before.into_iter().rev().collect::>().join(""); + let after_str = after.join(""); + + return format!("{}{}", before_str, after_str); +} + +async fn search_vecdb( + vecdb_search: &T, + tokenizer: HasTokenizerAndEot, + text_near_cursor: String, + max_context_size: usize +) -> (String, i32) where T: VecdbSearch + Send { + let search_result = vecdb_search.search(text_near_cursor, 20).await; + + let init_cfc_text = "Here are some relevant code fragments from other files of the repo:\n\n"; + let mut tokens_used = tokenizer.count_tokens(init_cfc_text).expect( + "Tokenization has failed" + ); + + match search_result { + Ok(res) => { + if res.results.is_empty() { + return ("".to_string(), tokens_used); + } + + let mut final_text_vec: Vec = vec![init_cfc_text.to_string()]; + for res in res.results { + let text: String = format!( + "The below code fragment is found in {}\n{}\n\n", + res.file_path.to_str().unwrap_or(""), res.window_text + ); + tokens_used += tokenizer.count_tokens(&text).expect( + "Tokenization has failed" + ); + final_text_vec.push(text); + if tokens_used > max_context_size as i32 { + break + } + } + (final_text_vec.join(""), tokens_used) + } + Err(_) => ("".to_string(), tokens_used) + } +} + fn cut_result(text: &str, eot_token: &str, multiline: bool) -> (String, bool) { let mut cut_at = vec![]; diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index a0bab2d30..a5498623c 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -10,6 +10,7 @@ pub mod chat_llama2; pub mod chat_passthrough; pub mod chat_utils_deltadelta; pub mod chat_utils_limit_history; +mod chat_utils_rag; use crate::call_validation::CodeCompletionPost; use crate::call_validation::ChatPost; @@ -18,14 +19,14 @@ use crate::caps::CodeAssistantCaps; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::completion_cache; use crate::telemetry::telemetry_structs; -use crate::vecdb_search; use crate::cached_tokenizers; +use crate::vecdb::structs::VecdbSearch; fn verify_has_send(_x: &T) {} -pub async fn create_code_completion_scratchpad( +pub async fn create_code_completion_scratchpad( global_context: Arc>, caps: Arc>, model_name_for_tokenizer: String, @@ -34,13 +35,15 @@ pub async fn create_code_completion_scratchpad( scratchpad_patch: &serde_json::Value, cache_arc: Arc>, tele_storage: Arc>, -) -> Result, String> { + vecdb_search: Arc>>, +) -> Result, String> + where T: VecdbSearch + 'static + Sync { let mut result: Box; let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, global_context, model_name_for_tokenizer).await?; if scratchpad_name == "FIM-PSM" { - result = Box::new(completion_single_file_fim::SingleFileFIM::new(tokenizer_arc, post, "PSM".to_string(), cache_arc, tele_storage)); + result = Box::new(completion_single_file_fim::SingleFileFIM::new(tokenizer_arc, post, "PSM".to_string(), cache_arc, tele_storage, vecdb_search)); } else if scratchpad_name == "FIM-SPM" { - result = Box::new(completion_single_file_fim::SingleFileFIM::new(tokenizer_arc, post, "SPM".to_string(), cache_arc, tele_storage)); + result = Box::new(completion_single_file_fim::SingleFileFIM::new(tokenizer_arc, post, "SPM".to_string(), cache_arc, tele_storage, vecdb_search)); } else { return Err(format!("This rust binary doesn't have code completion scratchpad \"{}\" compiled in", scratchpad_name)); } @@ -49,15 +52,16 @@ pub async fn create_code_completion_scratchpad( Ok(result) } -pub async fn create_chat_scratchpad( +pub async fn create_chat_scratchpad( global_context: Arc>, caps: Arc>, model_name_for_tokenizer: String, post: ChatPost, scratchpad_name: &str, scratchpad_patch: &serde_json::Value, - vecdb_search: Arc>>, -) -> Result, String> { + vecdb_search: Arc>>, +) -> Result, String> + where T: VecdbSearch + 'static + Sync { let mut result: Box; if scratchpad_name == "CHAT-GENERIC" { let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, global_context, model_name_for_tokenizer).await?; diff --git a/src/vecdb/file_filter.rs b/src/vecdb/file_filter.rs new file mode 100644 index 000000000..f5f68abfd --- /dev/null +++ b/src/vecdb/file_filter.rs @@ -0,0 +1,132 @@ +use std::fs; +use std::io::Read; +use std::path::PathBuf; + +use async_process::Command; +use walkdir::WalkDir; +use which::which; + +const LARGE_FILE_SIZE_THRESHOLD: u64 = 1_000_000; +// 1 MB +const SMALL_FILE_SIZE_THRESHOLD: u64 = 10; // 10 Bytes + +const SOURCE_FILE_EXTENSIONS: &[&str] = &[ + "c", "cpp", "cc", "h", "hpp", "cs", "java", "py", "rb", "go", "rs", "swift", + "php", "js", "jsx", "ts", "tsx", "lua", "pl", "r", "sh", "bat", "cmd", "ps1", + "m", "kt", "kts", "groovy", "dart", "fs", "fsx", "fsi", "html", "htm", "css", + "scss", "sass", "less", "json", "xml", "yml", "yaml", "md", "sql", "db", "sqlite", + "mdf", "cfg", "conf", "ini", "toml", "dockerfile", "ipynb", "rmd", "swift", "java", + "xml", "kt", "xaml", "unity", "gd", "uproject", "uasset", "asm", "s", "tex", + "makefile", "mk", "cmake", "gradle", +]; + +pub fn is_valid_file(path: &PathBuf) -> bool { + // Check if the path points to a file + if !path.is_file() { + return false; + } + + // Check if the file is in a hidden directory + if path.ancestors().any(|ancestor| { + ancestor.file_name() + .map(|name| name.to_string_lossy().starts_with('.')) + .unwrap_or(false) + }) { + return false; + } + + // Check if the file is a source file + if let Some(extension) = path.extension() { + if !SOURCE_FILE_EXTENSIONS.contains(&extension.to_str().unwrap_or_default()) { + return false; + } + } else { + // No extension, not a source file + return false; + } + + // Check file size + if let Ok(metadata) = fs::metadata(path) { + let file_size = metadata.len(); + if file_size < SMALL_FILE_SIZE_THRESHOLD || file_size > LARGE_FILE_SIZE_THRESHOLD { + return false; + } + } else { + // Unable to access file metadata + return false; + } + + // Check for read permissions + if fs::read(&path).is_err() { + return false; + } + + // Check if the file is not UTF-8 + let mut file = match fs::File::open(&path) { + Ok(file) => file, + Err(_) => return false, + }; + let mut buffer = Vec::new(); + if file.read_to_end(&mut buffer).is_err() { + return false; + } + if String::from_utf8(buffer).is_err() { + return false; + } + + // All checks passed + true +} + +pub async fn get_control_version_files(path: &PathBuf) -> Option> { + if path.join(".git").exists() && which("git").is_ok() { + // Git repository + run_command("git", &["ls-files"], path).await + } else if path.join(".hg").exists() && which("hg").is_ok() { + // Mercurial repository + run_command("hg", &["status", "-c"], path).await + } else if path.join(".svn").exists() && which("svn").is_ok() { + // SVN repository + run_command("svn", &["list", "-R"], path).await + } else { + None + } +} + +async fn run_command(cmd: &str, args: &[&str], path: &PathBuf) -> Option> { + let output = Command::new(cmd) + .args(args) + .current_dir(path) + .output() + .await + .ok()?; + + if !output.status.success() { + return None; + } + + String::from_utf8(output.stdout) + .ok() + .map(|s| s.lines().map(|line| path.join(line)).collect()) +} + + +pub async fn retrieve_files_by_proj_folders(proj_folders: Vec) -> Vec { + let mut all_files: Vec = Vec::new(); + for proj_folder in proj_folders { + let maybe_files = get_control_version_files(&proj_folder).await; + if let Some(files) = maybe_files { + all_files.extend(files); + } else { + let files: Vec = WalkDir::new(proj_folder) + .into_iter() + .filter_map(|e| e.ok()) + .filter(|e| !e.path().is_dir()) + .filter(|e| is_valid_file(&e.path().to_path_buf())) + .map(|e| e.path().to_path_buf()) + .collect::>(); + all_files.extend(files); + } + } + all_files +} diff --git a/src/vecdb/file_splitter.rs b/src/vecdb/file_splitter.rs new file mode 100644 index 000000000..04f31f02c --- /dev/null +++ b/src/vecdb/file_splitter.rs @@ -0,0 +1,85 @@ +use std::path::PathBuf; +use tokio::fs::read_to_string; +use crate::vecdb::structs::SplitResult; +use md5; + +fn str_hash(s: &String) -> String { + let digest = md5::compute(s); + format!("{:x}", digest) +} + +pub struct FileSplitter { + soft_window: usize, + hard_window: usize, +} + +impl FileSplitter { + pub fn new(window_size: usize, soft_limit: usize) -> Self { + Self { + soft_window: window_size, + hard_window: window_size + soft_limit, + } + } + + pub async fn split(&self, file_path: &PathBuf) -> Result, String> { + let text = match read_to_string(file_path).await { + Ok(s) => s, + Err(e) => return Err(e.to_string()) + }; + + let mut chunks = Vec::new(); + let mut batch = Vec::new(); + let mut batch_size = 0; + let mut soft_batch = Vec::new(); + let mut current_line_number: u64 = 0; + for line in text.lines() { + batch_size += line.len(); + if batch_size > self.soft_window { + soft_batch.push(line.to_string()); + } else { + batch.push(line.to_string()); + } + + if batch_size >= self.hard_window { + let best_break_line_n = soft_batch.iter() + .rposition(|l| l.trim().is_empty()) + .unwrap_or(soft_batch.len()); + + let (remaining, to_next_batch) = soft_batch.split_at(best_break_line_n); + batch.extend_from_slice(remaining); + + let start_line = current_line_number; + let end_line = start_line + batch.len() as u64; + current_line_number += batch.len() as u64; + + chunks.push(SplitResult { + file_path: file_path.clone(), + window_text: batch.join("\n"), + window_text_hash: str_hash(&batch.join("\n")), + start_line, + end_line, + }); + + batch = to_next_batch.to_vec(); + soft_batch.clear(); + batch_size = batch.iter().map(|s| s.len()).sum(); + } + } + + if !batch.is_empty() || !soft_batch.is_empty() { + batch.extend(soft_batch); + let start_line = current_line_number; + let end_line = start_line + batch.len() as u64; + + chunks.push(SplitResult { + file_path: file_path.clone(), + window_text: batch.join("\n"), + window_text_hash: str_hash(&batch.join("\n")), + start_line, + end_line, + }); + } + + Ok(chunks) + } +} diff --git a/src/vecdb/file_watcher_service.rs b/src/vecdb/file_watcher_service.rs new file mode 100644 index 000000000..d7df459b2 --- /dev/null +++ b/src/vecdb/file_watcher_service.rs @@ -0,0 +1,107 @@ +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use futures::{ + channel::mpsc::{channel, Receiver}, + SinkExt, StreamExt, +}; +use log::info; +use notify::{Config, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; +use serde_json::Value; +use tokio::fs::File; +use tokio::io::AsyncBufReadExt; +use tokio::io::BufReader; +use tokio::sync::RwLock as ARwLock; + +use crate::global_context::GlobalContext; + +fn make_async_watcher() -> notify::Result<(RecommendedWatcher, Receiver>)> { + let (mut tx, rx) = channel(1); + + let watcher = RecommendedWatcher::new( + move |res| { + futures::executor::block_on(async { + tx.send(res).await.unwrap(); + }) + }, + Config::default(), + )?; + + Ok((watcher, rx)) +} + + +async fn parse_jsonl(path: &PathBuf) -> Result, String> { + let file = File::open(path).await.map_err(|_| format!("File not found: {:?}", path))?; + let reader = BufReader::new(file); + let base_path = path.parent().or(Some(Path::new("/"))).unwrap().to_path_buf(); + + let mut lines = reader.lines(); + let mut paths = Vec::new(); + + while let Some(line) = lines.next_line().await.transpose() { + let line = line.map_err(|_| "Error reading line".to_string())?; + if let Ok(value) = serde_json::from_str::(&line) { + if value.is_object() { + if let Some(filename) = value.get("path").and_then(|v| v.as_str()) { + paths.push(base_path.join(filename)); + } + } + } + } + + Ok(paths) +} + +pub async fn file_watcher_task( + global_context: Arc>, +) -> () { + let (mut watcher, mut rx) = make_async_watcher().expect("Failed to make file watcher"); + let maybe_path = global_context.read().await.cmdline.files_set_path.clone(); + if maybe_path.is_empty() { + info!("file watcher: no files to watch"); + return; + } + let path = PathBuf::from(maybe_path); + let load_data = || async { + let filenames_data = match parse_jsonl(&path).await { + Ok(data) => data, + Err(_) => { + info!("invalid jsonl file: {:?}", path); + vec![] + } + }; + match *global_context.read().await.vec_db.lock().await { + Some(ref mut db) => db.add_or_update_files(filenames_data, true).await, + None => {} + }; + }; + + if watcher.watch(path.as_ref(), RecursiveMode::Recursive).is_err() { + info!("file watcher: {:?} is already watching", path); + return; + } + load_data().await; + while let Some(res) = rx.next().await { + match res { + Ok(event) => { + match event.kind { + EventKind::Any => {} + EventKind::Access(_) => {} + EventKind::Create(_) => { + info!("file {:?} was created", path) + } + EventKind::Modify(_) => { + load_data().await; + } + EventKind::Remove(_) => { + info!("file {:?} was removed", path) + // TODO: should we remove everything inside the database? + } + EventKind::Other => {} + } + } + Err(e) => info!("file watch error: {:?}", e), + } + } +} diff --git a/src/vecdb/handler.rs b/src/vecdb/handler.rs new file mode 100644 index 000000000..38c2c49a9 --- /dev/null +++ b/src/vecdb/handler.rs @@ -0,0 +1,554 @@ +use std::any::Any; +use std::collections::HashSet; +use std::fmt::{Debug, Formatter}; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; +use std::time::SystemTime; + +use arrow::array::ArrayData; +use arrow::buffer::Buffer; +use arrow::compute::concat_batches; +use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray, UInt64Array}; +use arrow_array::cast::{as_fixed_size_list_array, as_primitive_array, as_string_array}; +use arrow_array::types::{Float32Type, UInt64Type}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use futures_util::TryStreamExt; +use lance::dataset::{WriteMode, WriteParams}; +use log::info; +use tempfile::{tempdir, TempDir}; +use tokio::sync::Mutex; +use tracing::error; +use vectordb::database::Database; +use vectordb::table::Table; + +use crate::vecdb::structs::{Record, SplitResult}; + +pub type VecDBHandlerRef = Arc>; + +impl Debug for VecDBHandler { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "VecDBHandler: {:?}", self.cache_database.type_id()) + } +} + +pub struct VecDBHandler { + cache_database: Database, + data_database_temp_dir: TempDir, + cache_table: Table, + data_table: Table, + schema: SchemaRef, + data_table_hashes: HashSet, + embedding_size: i32, +} + +fn cosine_similarity(vec1: &Vec, vec2: &Vec) -> f32 { + let dot_product: f32 = vec1.iter().zip(vec2).map(|(x, y)| x * y).sum(); + let magnitude_vec1: f32 = vec1.iter().map(|x| x.powi(2)).sum::().sqrt(); + let magnitude_vec2: f32 = vec2.iter().map(|x| x.powi(2)).sum::().sqrt(); + dot_product / (magnitude_vec1 * magnitude_vec2) +} + +fn cosine_distance(vec1: &Vec, vec2: &Vec) -> f32 { + 1.0 - cosine_similarity(vec1, vec2) +} + +const TWO_WEEKS: i32 = 2 * 7 * 24 * 3600; +const ONE_MONTH: i32 = 30 * 24 * 3600; +const MIN_LIKES: i32 = 3; + +impl VecDBHandler { + pub async fn init(cache_dir: PathBuf, embedding_size: i32) -> Result { + let cache_dir_str = match cache_dir.join("refact_vecdb_cache").to_str() { + Some(dir) => dir.to_string(), + None => { + return Err(format!("{:?}", "Cache directory is not a valid path")); + } + }; + let data_database_temp_dir = match tempdir() { + Ok(dir) => dir, + Err(_) => return Err(format!("{:?}", "Error creating temp dir")), + }; + let data_database_temp_dir_str = match data_database_temp_dir.path().to_str() { + Some(path) => path, + None => return Err(format!("{:?}", "Temp directory is not a valid path")), + }; + + let cache_database = match Database::connect(cache_dir_str.as_str()).await { + Ok(db) => db, + Err(err) => return Err(format!("{:?}", err)) + }; + let temp_database = match Database::connect(data_database_temp_dir_str).await { + Ok(db) => db, + Err(err) => return Err(format!("{:?}", err)) + }; + + let vec_trait = Arc::new(Field::new("item", DataType::Float32, true)); + let schema = Arc::new(Schema::new(vec![ + Field::new("vector", DataType::FixedSizeList(vec_trait, embedding_size), true), + Field::new("window_text", DataType::Utf8, true), + Field::new("window_text_hash", DataType::Utf8, true), + Field::new("file_path", DataType::Utf8, true), + Field::new("start_line", DataType::UInt64, true), + Field::new("end_line", DataType::UInt64, true), + Field::new("time_added", DataType::UInt64, true), + Field::new("time_last_used", DataType::UInt64, true), + Field::new("model_name", DataType::Utf8, true), + Field::new("used_counter", DataType::UInt64, true), + ])); + let cache_table = match cache_database.open_table("data").await { + Ok(table) => { table } + Err(_) => { + let batches_iter = RecordBatchIterator::new(vec![].into_iter().map(Ok), schema.clone()); + match cache_database.create_table("data", batches_iter, Option::from(WriteParams::default())).await { + Ok(table) => table, + Err(err) => return Err(format!("{:?}", err)) + } + } + }; + let batches_iter = RecordBatchIterator::new(vec![].into_iter().map(Ok), schema.clone()); + let data_table = match temp_database.create_table("data", batches_iter, Option::from(WriteParams::default())).await { + Ok(table) => table, + Err(err) => return Err(format!("{:?}", err)) + }; + + Ok(VecDBHandler { + cache_database, + data_database_temp_dir, + schema, + cache_table, + data_table, + data_table_hashes: HashSet::new(), + embedding_size, + }) + } + + async fn checkout(&mut self) { + match self.data_table.checkout_latest().await { + Ok(table) => { self.data_table = table } + Err(err) => error!("Error while checking out data table: {:?}", err) + } + match self.cache_table.checkout_latest().await { + Ok(table) => { self.cache_table = table } + Err(err) => error!("Error while checking out data table: {:?}", err) + } + } + + pub async fn size(&self) -> Result { + match self.data_table.count_rows().await { + Ok(size) => Ok(size), + Err(err) => Err(format!("{:?}", err)) + } + } + + pub async fn cache_size(&self) -> Result { + match self.cache_table.count_rows().await { + Ok(size) => Ok(size), + Err(err) => Err(format!("{:?}", err)) + } + } + + async fn get_records(&mut self, table: Table, _hashes: Vec) -> (Vec, Vec) { + let mut hashes: HashSet = HashSet::from_iter(_hashes); + let q = hashes.iter().map(|x| format!("'{}'", x)).collect::>().join(", "); + let records = table + .filter(format!("window_text_hash in ({})", q)) + .execute() + .await.unwrap() + .try_collect::>() + .await.unwrap(); + let record_batch = concat_batches(&self.schema, &records).unwrap(); + let records = VecDBHandler::parse_table_iter(record_batch, true, None).unwrap(); + for r in &records { + hashes.remove(&r.window_text_hash); + } + (records, hashes.into_iter().collect()) + } + + pub async fn get_records_from_data(&mut self, hashes: Vec) -> (Vec, Vec) { + self.get_records(self.data_table.clone(), hashes).await + } + pub async fn get_records_from_cache(&mut self, hashes: Vec) -> (Vec, Vec) { + self.get_records(self.cache_table.clone(), hashes).await + } + + async fn get_record(&mut self, table: Table, hash: String) -> vectordb::error::Result { + let records = table + .filter(format!("window_text_hash == '{}'", hash)) + .execute() + .await? + .try_collect::>() + .await?; + let record_batch = concat_batches(&self.schema, &records)?; + let records = VecDBHandler::parse_table_iter(record_batch, true, None)?; + match records.get(0) { + Some(x) => Ok(x.clone()), + None => Err(vectordb::error::Error::Lance { message: format!("No record found for hash: {}", hash) }) + } + } + + pub async fn get_record_from_data(&mut self, hash: String) -> vectordb::error::Result { + self.get_record(self.data_table.clone(), hash).await + } + pub async fn get_record_from_cache(&mut self, hash: String) -> vectordb::error::Result { + self.get_record(self.cache_table.clone(), hash).await + } + + pub async fn try_add_from_cache(&mut self, data: Vec) -> Vec { + let hashes = data.iter().map(|x| x.window_text_hash.clone()).collect(); + let (found_records, left_hashes) = self.get_records_from_cache(hashes).await; + let left_results: Vec = + data.into_iter().filter(|x| left_hashes.contains(&x.window_text_hash)).collect(); + + match self.add_or_update(found_records, false).await { + Ok(_) => {} + Err(err) => info!("Error while adding values from cache: {:?}", err), + }; + left_results + } + + pub async fn add_or_update(&mut self, records: Vec, add_to_cache: bool) -> Result<(), String> { + fn make_emb_data(records: &Vec, embedding_size: i32) -> Result { + let vec_trait = Arc::new(Field::new("item", DataType::Float32, true)); + let mut emb_builder: Vec = vec![]; + + for record in records { + emb_builder.append(&mut record.vector.clone().expect("No embedding is provided")); + } + + let emb_data_res = ArrayData::builder(DataType::Float32) + .add_buffer(Buffer::from_vec(emb_builder)) + .len(records.len() * embedding_size as usize) + .build(); + let emb_data = match emb_data_res { + Ok(res) => res, + Err(err) => { return Err(format!("{:?}", err)); } + }; + + match ArrayData::builder(DataType::FixedSizeList(vec_trait.clone(), embedding_size)) + .len(records.len()) + .add_child_data(emb_data.clone()) + .build() { + Ok(res) => Ok(res), + Err(err) => return Err(format!("{:?}", err)) + } + } + + if records.is_empty() { + return Ok(()); + } + + let vectors: ArrayData = match make_emb_data(&records, self.embedding_size) { + Ok(res) => res, + Err(err) => return Err(format!("{:?}", err)) + }; + let window_texts: Vec = records.iter().map(|x| x.window_text.clone()).collect(); + let window_text_hashes: Vec = records.iter().map(|x| x.window_text_hash.clone()).collect(); + let file_paths: Vec = records.iter().map(|x| x.file_path.to_str().unwrap_or("No filename").to_string()).collect(); + let start_lines: Vec = records.iter().map(|x| x.start_line).collect(); + let end_lines: Vec = records.iter().map(|x| x.end_line).collect(); + let time_adds: Vec = records.iter().map(|x| x.time_added.duration_since(std::time::UNIX_EPOCH).unwrap_or( + Duration::from_secs(0) + ).as_secs()).collect(); + let time_last_used: Vec = records.iter().map(|x| x.time_last_used.duration_since(std::time::UNIX_EPOCH).unwrap_or( + Duration::from_secs(0) + ).as_secs()).collect(); + let model_names: Vec = records.iter().map(|x| x.model_name.clone()).collect(); + let used_counters: Vec = records.iter().map(|x| x.used_counter).collect(); + let data_batches_iter = RecordBatchIterator::new( + vec![RecordBatch::try_new( + self.schema.clone(), + vec![ + Arc::new(FixedSizeListArray::from(vectors.clone())), + Arc::new(StringArray::from(window_texts.clone())), + Arc::new(StringArray::from(window_text_hashes.clone())), + Arc::new(StringArray::from(file_paths.clone())), + Arc::new(UInt64Array::from(start_lines.clone())), + Arc::new(UInt64Array::from(end_lines.clone())), + Arc::new(UInt64Array::from(time_adds.clone())), + Arc::new(UInt64Array::from(time_last_used.clone())), + Arc::new(StringArray::from(model_names.clone())), + Arc::new(UInt64Array::from(used_counters.clone())), + ], + )], + self.schema.clone(), + ); + let cache_batches_iter = RecordBatchIterator::new( + vec![RecordBatch::try_new( + self.schema.clone(), + vec![ + Arc::new(FixedSizeListArray::from(vectors)), + Arc::new(StringArray::from(window_texts)), + Arc::new(StringArray::from(window_text_hashes.clone())), + Arc::new(StringArray::from(file_paths)), + Arc::new(UInt64Array::from(start_lines)), + Arc::new(UInt64Array::from(end_lines)), + Arc::new(UInt64Array::from(time_adds)), + Arc::new(UInt64Array::from(time_last_used)), + Arc::new(StringArray::from(model_names)), + Arc::new(UInt64Array::from(used_counters)), + ], + )], + self.schema.clone(), + ); + + if add_to_cache { + let cache_res = self.cache_table.add( + cache_batches_iter, Option::from(WriteParams { + mode: WriteMode::Append, + ..Default::default() + }), + ); + match cache_res.await { + Ok(_) => {} + Err(err) => return Err(format!("{:?}", err)) + }; + } + + let data_res = self.data_table.add( + data_batches_iter, Option::from(WriteParams { + mode: WriteMode::Append, + ..Default::default() + }), + ); + self.data_table_hashes.extend(window_text_hashes); + match data_res.await { + Ok(_) => Ok(()), + Err(err) => return Err(format!("{:?}", err)) + } + } + + pub async fn remove(&mut self, file_path: &PathBuf) { + let file_path_str = match file_path.to_str() { + None => { + info!("File path is not a string"); + return; + } + Some(res) => res + }; + + match self.cache_table.delete( + format!("(file_path = \"{}\")", file_path_str).as_str() // TODO: Prevent a possible sql injection here + ).await { + Ok(_) => {} + Err(err) => { + info!("Error while deleting from cache: {:?}", err); + } + } + match self.data_table.delete( + format!("(file_path = \"{}\")", file_path_str).as_str() // TODO: Prevent a possible sql injection here + ).await { + Ok(_) => {} + Err(err) => { + info!("Error while deleting from cache: {:?}", err); + } + } + } + + pub fn contains(&self, hash: &str) -> bool { + self.data_table_hashes.contains(hash) + } + + fn parse_table_iter( + record_batch: RecordBatch, + include_embedding: bool, + embedding_to_compare: Option<&Vec>, + ) -> vectordb::error::Result> { + (0..record_batch.num_rows()).map(|idx| { + let gathered_vec = as_primitive_array::( + &as_fixed_size_list_array(record_batch.column_by_name("vector").unwrap()) + .iter() + .map(|x| x.unwrap()) + .collect::>()[idx] + ) + .iter() + .map(|x| x.unwrap()).collect(); + let distance = match embedding_to_compare { + None => { -1.0 } + Some(embedding) => { cosine_distance(&embedding, &gathered_vec) } + }; + let embedding = match include_embedding { + true => Some(gathered_vec), + false => None + }; + + Ok(Record { + vector: embedding, + window_text: as_string_array(record_batch.column_by_name("window_text") + .expect("Missing column 'window_text'")) + .value(idx) + .to_string(), + window_text_hash: as_string_array(record_batch.column_by_name("window_text_hash") + .expect("Missing column 'window_text_hash'")) + .value(idx) + .to_string(), + file_path: PathBuf::from(as_string_array(record_batch.column_by_name("file_path") + .expect("Missing column 'file_path'")) + .value(idx) + .to_string()), + start_line: as_primitive_array::(record_batch.column_by_name("start_line") + .expect("Missing column 'start_line'")) + .value(idx), + end_line: as_primitive_array::(record_batch.column_by_name("end_line") + .expect("Missing column 'end_line'")) + .value(idx), + time_added: std::time::UNIX_EPOCH + Duration::from_secs( + as_primitive_array::( + record_batch.column_by_name("time_added") + .expect("Missing column 'time_added'")) + .value(idx) + ), + time_last_used: std::time::UNIX_EPOCH + Duration::from_secs( + as_primitive_array::( + record_batch.column_by_name("time_last_used") + .expect("Missing column 'time_last_used'")) + .value(idx) + ), + model_name: as_string_array(record_batch.column_by_name("model_name") + .expect("Missing column 'model_name'")) + .value(idx) + .to_string(), + used_counter: as_primitive_array::(record_batch.column_by_name("used_counter") + .expect("Missing column 'used_counter'")) + .value(idx), + distance, + }) + }).collect() + } + + pub async fn search(&mut self, embedding: Vec, top_n: usize) -> vectordb::error::Result> { + let query = self.data_table.clone() + .search(Some(Float32Array::from(embedding.clone()))) + .limit(top_n) + .use_index(true) + .execute() + .await? + .try_collect::>() + .await?; + let record_batch = concat_batches(&self.schema, &query)?; + VecDBHandler::parse_table_iter(record_batch, false, Some(&embedding)) + } + + pub async fn update_record_statistic(&mut self, records: Vec) { + let now = SystemTime::now(); + for record in records { + for mut table in vec![self.data_table.clone(), self.cache_table.clone()] { + let _ = table.update(Some(format!("window_text_hash == '{}'", record.window_text_hash.clone()).as_str()), + vec![ + ("used_counter", &(&record.used_counter + 1).to_string()), + ("time_last_used", &*now.elapsed().unwrap().as_secs().to_string()), + ]).await.unwrap(); + } + self.checkout().await; + } + } + pub async fn cleanup_old_records(&mut self) -> Result<(), String> { + let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap(); + let q = format!("{} - time_last_used > {TWO_WEEKS} AND used_counter < {MIN_LIKES}", now.as_secs()); + self.cache_table.delete(&*q).await.expect("could not delete old records"); + self.data_table.delete(&*q).await.expect("could not delete old records"); + self.checkout().await; + + let q = format!("{} - time_last_used > {ONE_MONTH}", now.as_secs()); + self.cache_table.delete(&*q).await.expect("could not delete old records"); + self.data_table.delete(&*q).await.expect("could not delete old records"); + self.checkout().await; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::time::SystemTime; + + use tempfile::tempdir; + use tokio; + + use super::*; + + #[tokio::test] + async fn test_init() { + let temp_dir = tempdir().unwrap(); + let embedding_size = 2; + let mut handler = VecDBHandler::init( + temp_dir.path().to_path_buf(), + embedding_size, + ).await; + assert_eq!(handler.size().await, 0); + } + + #[tokio::test] + async fn test_add_or_update() { + let temp_dir = tempdir().unwrap(); + let embedding_size = 2; + let mut handler = VecDBHandler::init( + temp_dir.path().to_path_buf(), + embedding_size, + ).await; + let expected_size = 1; + + // Prepare a sample record + let records = vec![ + Record { + vector: Some(vec![1.0, 2.0]), // Example values + window_text: "sample text".to_string(), + window_text_hash: "hash1".to_string(), + file_path: PathBuf::from("/path/to/file"), + start_line: 1, + end_line: 2, + time_added: SystemTime::now(), + time_last_used: SystemTime::now(), + model_name: "model1".to_string(), + used_counter: 0, + distance: 1.0, + }, + ]; + + // Call add_or_update + handler.add_or_update(records, true).await.unwrap(); + + // Validate the records + assert_eq!(handler.size().await, expected_size); + } + + #[tokio::test] + async fn test_search() { + let temp_dir = tempdir().unwrap(); + let embedding_size = 4; + let mut handler = VecDBHandler::init( + temp_dir.path().to_path_buf(), + embedding_size, + ).await; + let top_n = 1; + + let time_added = SystemTime::now(); + let records = vec![ + Record { + vector: Some(vec![1.0, 2.0, 3.0, 4.0]), + window_text: "test text".to_string(), + window_text_hash: "hash2".to_string(), + file_path: PathBuf::from("/path/to/another/file"), + start_line: 3, + end_line: 4, + time_added: time_added, + time_last_used: time_added, + model_name: "model2".to_string(), + used_counter: 0, + distance: 1.0, + }, + ]; + handler.add_or_update(records, true).await.unwrap(); + + let query_embedding = vec![1.0, 2.0, 3.0, 4.0]; + let results = handler.search(query_embedding, top_n).await.unwrap(); + + assert!(!results.is_empty()); + assert_eq!(results[0].window_text, "test text"); + assert_eq!(results[0].window_text_hash, "hash2"); + assert_eq!(results[0].file_path, PathBuf::from("/path/to/another/file")); + assert_eq!(results[0].start_line, 3); + assert_eq!(results[0].end_line, 4); + assert_eq!(results[0].model_name, "model2"); + assert_eq!(results[0].distance, 1.0); + } +} diff --git a/src/vecdb/mod.rs b/src/vecdb/mod.rs new file mode 100644 index 000000000..bc409f285 --- /dev/null +++ b/src/vecdb/mod.rs @@ -0,0 +1,10 @@ +mod file_splitter; +mod handler; +mod req_client; +mod vectorizer_service; + +pub mod file_watcher_service; +pub mod file_filter; +pub mod structs; +pub mod vecdb; +pub mod vecdb_remote; diff --git a/src/vecdb/req_client.rs b/src/vecdb/req_client.rs new file mode 100644 index 000000000..6624df831 --- /dev/null +++ b/src/vecdb/req_client.rs @@ -0,0 +1,201 @@ +use std::time::Duration; +use reqwest; +use serde::Serialize; +use tokio::task::JoinHandle; +use tokio::time::sleep; + +#[derive(Serialize)] +struct Payload { + pub inputs: String, +} + + +fn get_base_url() -> String { + #[cfg(test)] + { + // When running tests, return the mockito server URL + mockito::server_url() + } + + #[cfg(not(test))] + { + // In a regular run, return the actual URL + "https://api-inference.huggingface.co".to_string() + } +} + +pub fn get_embedding( + text: String, + model_name: &String, + api_key: String, +) -> JoinHandle, String>> { + let url = format!("{}/models/{}", get_base_url(), model_name); + let client = reqwest::Client::new(); + let payload = Payload { inputs: text }; + + tokio::spawn(async move { + let mut attempts = 0; + let max_attempts = 3; + let delay = Duration::from_secs(5); + + while attempts < max_attempts { + let maybe_response = client.post(&url) + .bearer_auth("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".clone()) + .json(&payload) + .send() + .await; + + match maybe_response { + Ok(response) => { + if response.status().is_success() { + match response.json::>().await { + Ok(embedding) => return Ok(embedding), + Err(err) => return Err(format!("Failed to parse the response: {:?}", err)), + } + } else if response.status().is_server_error() { + // Retry in case of 5xx server errors + attempts += 1; + sleep(delay).await; + continue; + } else { + return Err(format!("Failed to get a response: {:?}", response.status())); + } + }, + Err(err) => return Err(format!("Failed to send a request: {:?}", err)) + } + } + + Err("Exceeded maximum attempts to reach the server".to_string()) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_valid_request() { + let _m = mockito::mock("POST", "/models/valid_model") + .with_status(200) + .with_body(r#"{"embedding": [1.0, 2.0, 3.0]}"#) + .create(); + + let text = "sample text".to_string(); + let model_name = "valid_model".to_string(); + let api_key = "valid_api_key".to_string(); + + let result = get_embedding(text, &model_name, api_key).await.unwrap(); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), vec![1.0, 2.0, 3.0]); + } + + #[tokio::test] + async fn test_invalid_api_key() { + let _m = mockito::mock("POST", "/models/valid_model") + .with_status(401) + .create(); + + let text = "sample text".to_string(); + let model_name = "valid_model".to_string(); + let api_key = "invalid_api_key".to_string(); + + let result = get_embedding(text, &model_name, api_key).await.unwrap(); + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_concurrent_requests() { + let mock = mockito::mock("POST", "/models/valid_model") + .with_status(200) + .with_body(r#"{"embedding": [1.0, 2.0, 3.0]}"#) + .expect(10) // Expect 10 calls + .create(); + + let handles: Vec<_> = (0..10).map(|_| { + let text = "sample text".to_string(); + let model_name = "valid_model".to_string(); + let api_key = "valid_api_key".to_string(); + + get_embedding(text, &model_name, api_key) + }).collect(); + + for handle in handles { + let result = handle.await.unwrap(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), vec![1.0, 2.0, 3.0]); + } + + mock.assert(); + } + + #[tokio::test] + async fn test_empty_text_input() { + let _m = mockito::mock("POST", "/models/valid_model") + .with_status(200) + .with_body(r#"{"embedding": []}"#) + .create(); + + let text = "".to_string(); + let model_name = "valid_model".to_string(); + let api_key = "valid_api_key".to_string(); + + let result = get_embedding(text, &model_name, api_key).await.unwrap(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Vec::::new()); + } + + #[tokio::test] + async fn test_invalid_model_name() { + let _m = mockito::mock("POST", "/models/invalid_model") + .with_status(404) + .create(); + + let text = "sample text".to_string(); + let model_name = "invalid_model".to_string(); + let api_key = "valid_api_key".to_string(); + + let result = get_embedding(text, &model_name, api_key).await.unwrap(); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_network_failure() { + let _m = mockito::mock("POST", "/models/valid_model") + .with_status(500) // Internal Server Error to simulate server-side failure + .create(); + + let text = "sample text".to_string(); + let model_name = "valid_model".to_string(); + let api_key = "valid_api_key".to_string(); + + let result = get_embedding(text, &model_name, api_key).await.unwrap(); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_different_embeddings() { + let mock1 = mockito::mock("POST", "/models/model1") + .with_status(200) + .with_body(r#"{"embedding": [1.0, 2.0]}"#) + .create(); + + let mock2 = mockito::mock("POST", "/models/model2") + .with_status(200) + .with_body(r#"{"embedding": [3.0, 4.0]}"#) + .create(); + + let text = "sample text".to_string(); + let model_names = vec!["model1", "model2"]; + let api_key = "valid_api_key".to_string(); + + for model_name in model_names { + let result = get_embedding(text.clone(), &model_name.to_string(), api_key.clone()).await.unwrap(); + assert!(result.is_ok()); + } + + mock1.assert(); + mock2.assert(); + } +} \ No newline at end of file diff --git a/src/vecdb/structs.rs b/src/vecdb/structs.rs new file mode 100644 index 000000000..feaeed2f8 --- /dev/null +++ b/src/vecdb/structs.rs @@ -0,0 +1,59 @@ +use std::fmt::Debug; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::SystemTime; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use tokio::sync::Mutex; + +#[async_trait] +pub trait VecdbSearch: Send { + async fn search( + &self, + query: String, + top_n: usize, + ) -> Result; +} + + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct VecDbStatus { + pub unprocessed_files_count: usize, + pub requests_made_since_start: usize, + pub db_size: usize, + pub db_cache_size: usize, +} + +pub type VecDbStatusRef = Arc>; + + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Record { + pub vector: Option>, + pub window_text: String, + pub window_text_hash: String, + pub file_path: PathBuf, + pub start_line: u64, + pub end_line: u64, + pub time_added: SystemTime, + pub time_last_used: SystemTime, + pub model_name: String, + pub used_counter: u64, + pub distance: f32, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)] +pub struct SplitResult { + pub file_path: PathBuf, + pub window_text: String, + pub window_text_hash: String, + pub start_line: u64, + pub end_line: u64, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct SearchResult { + pub query_text: String, + pub results: Vec, +} diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs new file mode 100644 index 000000000..3100172cc --- /dev/null +++ b/src/vecdb/vecdb.rs @@ -0,0 +1,95 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use async_trait::async_trait; +use tokio::sync::Mutex; +use tokio::task::JoinHandle; + +use crate::vecdb::handler::{VecDBHandler, VecDBHandlerRef}; +use crate::vecdb::req_client::get_embedding; +use crate::vecdb::vectorizer_service::FileVectorizerService; +use crate::vecdb::structs::{SearchResult, VecdbSearch, VecDbStatus}; + +#[derive(Debug)] +pub struct VecDb { + vecdb_handler: VecDBHandlerRef, + retriever_service: Arc>, + embedding_model_name: String, + cmdline: crate::global_context::CommandLine, +} + + +impl VecDb { + pub async fn init( + cache_dir: PathBuf, + cmdline: crate::global_context::CommandLine, + embedding_size: i32, + cooldown_secs: u64, + splitter_window_size: usize, + splitter_soft_limit: usize, + embedding_model_name: String, + ) -> Result { + let handler = match VecDBHandler::init(cache_dir, embedding_size).await { + Ok(res) => res, + Err(err) => { return Err(err) } + }; + let vecdb_handler = Arc::new(Mutex::new(handler)); + let retriever_service = Arc::new(Mutex::new(FileVectorizerService::new( + vecdb_handler.clone(), cooldown_secs, splitter_window_size, splitter_soft_limit, + embedding_model_name.clone(), cmdline.api_key.clone(), + ).await)); + + Ok(VecDb { + vecdb_handler, + retriever_service, + embedding_model_name, + cmdline, + }) + } + + pub async fn start_background_tasks(&self) -> Vec> { + return self.retriever_service.lock().await.start_background_tasks().await; + } + + pub async fn add_or_update_file(&mut self, file_path: PathBuf, force: bool) { + self.retriever_service.lock().await.process_file(file_path, force).await; + } + + pub async fn add_or_update_files(&self, file_paths: Vec, force: bool) { + self.retriever_service.lock().await.process_files(file_paths, force).await; + } + + pub async fn remove_file(&self, file_path: &PathBuf) { + self.vecdb_handler.lock().await.remove(file_path).await; + } + + pub async fn get_status(&self) -> Result { + self.retriever_service.lock().await.status().await + } +} + + +#[async_trait] +impl VecdbSearch for VecDb { + async fn search(&self, query: String, top_n: usize) -> Result { + let embedding = get_embedding( + query.clone(), &self.embedding_model_name, self.cmdline.api_key.clone(), + ).await.unwrap(); + match embedding { + Ok(vector) => { + let mut binding = self.vecdb_handler.lock().await; + let results = binding.search(vector, top_n).await.unwrap(); + binding.update_record_statistic(results.clone()).await; + Ok( + SearchResult { + query_text: query, + results: results, + } + ) + } + Err(_) => { + return Err("Failed to get embedding".to_string()); + } + } + } +} diff --git a/src/vecdb/vecdb_remote.rs b/src/vecdb/vecdb_remote.rs new file mode 100644 index 000000000..c63e44b35 --- /dev/null +++ b/src/vecdb/vecdb_remote.rs @@ -0,0 +1,46 @@ +use async_trait::async_trait; +use reqwest::header::CONTENT_TYPE; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderValue; +use serde_json::json; + +use crate::vecdb::structs::{SearchResult, VecdbSearch}; + +#[derive(Debug)] +pub struct VecDbRemote {} + +#[async_trait] +impl VecdbSearch for VecDbRemote { + async fn search( + &self, + query: String, + top_n: usize, + ) -> Result { + let url = "http://127.0.0.1:8008/v1/vdb-search".to_string(); + let mut headers = HeaderMap::new(); + // headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap()); + headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); + let body = json!({ + "text": query, + "top_n": top_n + }); + let res = reqwest::Client::new() + .post(&url) + .headers(headers) + .body(body.to_string()) + .send() + .await.map_err(|e| format!("Vecdb search HTTP error (1): {}", e))?; + + let body = res.text().await.map_err(|e| format!("Vecdb search HTTP error (2): {}", e))?; + // info!("Vecdb search result: {:?}", &body); + let result: Vec = serde_json::from_str(&body).map_err(|e| { + format!("vecdb JSON problem: {}", e) + })?; + if result.len() == 0 { + return Err("Vecdb search result is empty".to_string()); + } + let result0 = result[0].clone(); + // info!("Vecdb search result: {:?}", &result0); + Ok(result0) + } +} diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs new file mode 100644 index 000000000..db2e631c9 --- /dev/null +++ b/src/vecdb/vectorizer_service.rs @@ -0,0 +1,258 @@ +use std::collections::{HashMap, VecDeque}; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::SystemTime; + +use tokio::sync::Mutex; +use tokio::task::JoinHandle; +use tracing::info; + +use crate::vecdb::file_splitter::FileSplitter; +use crate::vecdb::handler::VecDBHandlerRef; +use crate::vecdb::req_client::get_embedding; +use crate::vecdb::structs::{Record, SplitResult, VecDbStatus, VecDbStatusRef}; + +#[derive(Debug)] +pub struct FileVectorizerService { + update_request_queue: Arc>>, + output_queue: Arc>>, + vecdb_handler: VecDBHandlerRef, + status: VecDbStatusRef, + cooldown_secs: u64, + splitter_window_size: usize, + splitter_soft_limit: usize, + embedding_model_name: String, + api_key: String, +} + +async fn cooldown_queue_thread( + update_request_queue: Arc>>, + out_queue: Arc>>, + status: VecDbStatusRef, + cooldown_secs: u64, +) { + let mut last_updated: HashMap = HashMap::new(); + loop { + let (path_maybe, unprocessed_files_count) = { + let mut queue_locked = update_request_queue.lock().await; + if !queue_locked.is_empty() { + (Some(queue_locked.pop_front().unwrap()), queue_locked.len()) + } else { + (None, 0) + } + }; + + if let Some(path) = path_maybe { + last_updated.insert(path, SystemTime::now()); + } + + let mut paths_to_process: Vec = Vec::new(); + for (path, time) in &last_updated { + if time.elapsed().unwrap().as_secs() > cooldown_secs { + paths_to_process.push(path.clone()); + } + } + for path in paths_to_process { + last_updated.remove(&path); + out_queue.lock().await.push_back(path); + } + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + } +} + + +async fn vectorize_thread( + queue: Arc>>, + vecdb_handler_ref: VecDBHandlerRef, + status: VecDbStatusRef, + splitter_window_size: usize, + splitter_soft_limit: usize, + embedding_model_name: String, + api_key: String, +) { + let file_splitter = FileSplitter::new(splitter_window_size, splitter_soft_limit); + + loop { + let (path_maybe, unprocessed_files_count) = { + let mut queue_locked = queue.lock().await; + if !queue_locked.is_empty() { + (Some(queue_locked.pop_front().unwrap()), queue_locked.len()) + } else { + (None, 0) + } + }; + status.lock().await.unprocessed_files_count = unprocessed_files_count; + let path = { + match path_maybe { + Some(path) => path, + None => { + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + continue; + } + } + }; + info!("Processing file: {}", path.display()); + + let split_data = match file_splitter.split(&path).await { + Ok(data) => data, + Err(_) => { continue } + }; + + let mut vecdb_handler = vecdb_handler_ref.lock().await; + let mut split_data_filtered: Vec = split_data + .iter() + .filter(|x| !vecdb_handler.contains(&x.window_text_hash)) + .cloned() // Clone to avoid borrowing issues + .collect(); + split_data_filtered = vecdb_handler.try_add_from_cache(split_data_filtered).await; + drop(vecdb_handler); + info!("Retrieving embeddings for {} chunks", split_data_filtered.len()); + + let join_handles: Vec<_> = split_data_filtered.iter().map( + |x| get_embedding(x.window_text.clone(), &embedding_model_name, api_key.clone()) + ).collect(); + status.lock().await.requests_made_since_start += join_handles.len(); + + let mut split_join_data: VecDeque<(SplitResult, JoinHandle, String>>)> + = split_data_filtered.into_iter() + .zip(join_handles.into_iter()) + .collect::>(); + + let mut records: Vec = Vec::new(); + while let Some((data_res, handle)) = split_join_data.pop_front() { + match handle.await { + Ok(Ok(result)) => { + let now = SystemTime::now(); + records.push( + Record { + vector: Some(result), + window_text: data_res.window_text, + window_text_hash: data_res.window_text_hash, + file_path: data_res.file_path, + start_line: data_res.start_line, + end_line: data_res.end_line, + time_added: now, + time_last_used: now, + model_name: embedding_model_name.clone(), + used_counter: 0, + distance: -1.0, + } + ); + } + Ok(Err(e)) => { + info!("Error retrieving embeddings for {}: {}", data_res.window_text, e); + } + Err(_) => { continue; } + } + } + match vecdb_handler_ref.lock().await.add_or_update(records, true).await { + Err(e) => { + info!("Error adding/updating records in VecDB: {}", e); + } + _ => {} + } + } +} + +async fn cleanup_thread(vecdb_handler: VecDBHandlerRef) { + loop { + { + let mut vecdb = vecdb_handler.lock().await; + let _ = vecdb.cleanup_old_records().await; + } + tokio::time::sleep(tokio::time::Duration::from_secs(2 * 3600)).await; + } +} + +impl FileVectorizerService { + pub async fn new( + vecdb_handler: VecDBHandlerRef, + cooldown_secs: u64, + splitter_window_size: usize, + splitter_soft_limit: usize, + embedding_model_name: String, + api_key: String, + ) -> Self { + let update_request_queue = Arc::new(Mutex::new(VecDeque::new())); + let output_queue = Arc::new(Mutex::new(VecDeque::new())); + let status = Arc::new(Mutex::new( + VecDbStatus { + unprocessed_files_count: 0, + requests_made_since_start: 0, + db_size: 0, + db_cache_size: 0, + } + )); + FileVectorizerService { + update_request_queue: update_request_queue.clone(), + output_queue: output_queue.clone(), + vecdb_handler: vecdb_handler.clone(), + status: status.clone(), + cooldown_secs, + splitter_window_size, + splitter_soft_limit, + embedding_model_name, + api_key, + } + } + + pub async fn start_background_tasks(&self) -> Vec> { + let cooldown_queue_join_handle = tokio::spawn( + cooldown_queue_thread( + self.update_request_queue.clone(), + self.output_queue.clone(), + self.status.clone(), + self.cooldown_secs, + ) + ); + + let retrieve_thread_handle = tokio::spawn( + vectorize_thread( + self.output_queue.clone(), + self.vecdb_handler.clone(), + self.status.clone(), + self.splitter_window_size, + self.splitter_soft_limit, + self.embedding_model_name.clone(), + self.api_key.clone(), + ) + ); + + let cleanup_thread_handle = tokio::spawn( + cleanup_thread( + self.vecdb_handler.clone() + ) + ); + + return vec![cooldown_queue_join_handle, retrieve_thread_handle, cleanup_thread_handle]; + } + + pub async fn process_file(&self, path: PathBuf, force: bool) { + if !force { + self.update_request_queue.lock().await.push_back(path); + } else { + self.output_queue.lock().await.push_back(path); + } + } + + pub async fn process_files(&self, paths: Vec, force: bool) { + if !force { + self.update_request_queue.lock().await.extend(paths); + } else { + self.output_queue.lock().await.extend(paths); + } + } + + pub async fn status(&self) -> Result { + let mut status = self.status.lock().await.clone(); + status.db_size = match self.vecdb_handler.lock().await.size().await { + Ok(res) => res, + Err(err) => return Err(err) + }; + status.db_cache_size = match self.vecdb_handler.lock().await.cache_size().await { + Ok(res) => res, + Err(err) => return Err(err) + }; + Ok(status) + } +} diff --git a/src/vecdb_search.rs b/src/vecdb_search.rs deleted file mode 100644 index bdad0a4ab..000000000 --- a/src/vecdb_search.rs +++ /dev/null @@ -1,133 +0,0 @@ -use crate::call_validation::{ChatMessage, ChatPost}; -// use reqwest::header::AUTHORIZATION; -use reqwest::header::CONTENT_TYPE; -use reqwest::header::HeaderMap; -use reqwest::header::HeaderValue; -use serde::{Deserialize, Serialize}; -use serde_json::json; - -use std::sync::Arc; -use tokio::sync::Mutex as AMutex; -use async_trait::async_trait; - - -#[derive(Serialize, Deserialize, Clone, Debug, Default)] -pub struct VecdbResultRec { - pub file_name: String, - pub text: String, - pub score: String, -} - -#[derive(Serialize, Deserialize, Clone, Debug, Default)] -pub struct VecdbResult { - pub results: Vec, -} - -// FIXME: bad idea -pub async fn embed_vecdb_results( - vecdb_search: Arc>>, - post: &mut ChatPost, - limit_examples_cnt: usize, -) { - let my_vdb = vecdb_search.clone(); - let latest_msg_cont = &post.messages.last().unwrap().content; - let mut vecdb_locked = my_vdb.lock().await; - let vdb_resp = vecdb_locked.search(&latest_msg_cont).await; - let vdb_cont = vecdb_resp_to_prompt(&vdb_resp, limit_examples_cnt); - if vdb_cont.len() > 0 { - post.messages = [ - &post.messages[..post.messages.len() -1], - &[ChatMessage { - role: "user".to_string(), - content: vdb_cont, - }], - &post.messages[post.messages.len() -1..], - ].concat(); - } -} - -// FIXME: bad idea -fn vecdb_resp_to_prompt( - resp: &Result, - limit_examples_cnt: usize, -) -> String { - let mut cont = "".to_string(); - match resp { - Ok(resp) => { - cont.push_str("CONTEXT:\n"); - for i in 0..limit_examples_cnt { - if i >= resp.results.len() { - break; - } - cont.push_str("FILENAME:\n"); - cont.push_str(resp.results[i].file_name.clone().as_str()); - cont.push_str("\nTEXT:"); - cont.push_str(resp.results[i].text.clone().as_str()); - cont.push_str("\n"); - } - cont.push_str("\nRefer to the context to answer my next question.\n"); - cont - } - Err(e) => { - format!("Vecdb error: {}", e); - cont - } - } -} - -#[async_trait] -pub trait VecdbSearch: Send { - async fn search( - &mut self, - query: &str, - ) -> Result; -} - -#[derive(Debug, Clone)] -pub struct VecdbSearchTest { -} - -impl VecdbSearchTest { - pub fn new() -> Self { - VecdbSearchTest { - } - } -} - -// unsafe impl Send for VecdbSearchTest {} - -#[async_trait] -impl VecdbSearch for VecdbSearchTest { - async fn search( - &mut self, - query: &str, - ) -> Result { - let url = "http://127.0.0.1:8008/v1/vdb-search".to_string(); - let mut headers = HeaderMap::new(); - // headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap()); - headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); - let body = json!({ - "texts": [query], - "account": "XXX", - "top_k": 3, - }); - let res = reqwest::Client::new() - .post(&url) - .headers(headers) - .body(body.to_string()) - .send() - .await.map_err(|e| format!("Vecdb search HTTP error (1): {}", e))?; - - let body = res.text().await.map_err(|e| format!("Vecdb search HTTP error (2): {}", e))?; - // info!("Vecdb search result: {:?}", &body); - let result: Vec = serde_json::from_str(&body).map_err(|e| { - format!("vecdb JSON problem: {}", e) - })?; - if result.len() == 0 { - return Err("Vecdb search result is empty".to_string()); - } - let result0 = result[0].clone(); - // info!("Vecdb search result: {:?}", &result0); - Ok(result0) - } -} From e3f774e69b7fb23d02b564dc02dfaebc7f97e41e Mon Sep 17 00:00:00 2001 From: Valeryi Date: Tue, 19 Dec 2023 21:52:16 +0000 Subject: [PATCH 02/54] WIP fix WIP fixed deps! changes requested by @olegklimov simplified get_embedding_openai_style and get_embedding_hf_style: removed attempts removed hardcoded hf_url; fixed some naming restoring order in args & naming vecdb is now dynamically reloading implemented changes request from @olegklimov minor minor added endpoint_embeddings_style init changelist: * added additional tries to embedding requests' * added init folders as a vecdb's method. Calling from vecdb_background_reload and LSP::initialize changelist: * creating new database if emb_size if different * if emb_size is equal, deleting embeddings creating for another model changelist: * removed remove_where_old_model * creating dedicated db_dir using key (model_name, embedding_size) vecdb is reloading only if caps changes. It checks for it every 30s bugfix Changelist: * chat_rag_functions are filled within vecdb init; * vecdb is called only when @workspace command is present in the latest message --- src/background_tasks.rs | 2 +- src/caps.rs | 18 ++ .../req_client.rs => fetch_embedding.rs} | 102 ++++---- src/forward_to_hf_endpoint.rs | 44 ++++ src/forward_to_openai_endpoint.rs | 60 +++++ src/global_context.rs | 22 +- src/lsp.rs | 9 +- src/main.rs | 8 +- src/restream.rs | 18 ++ src/scratchpad_abstract.rs | 2 + src/scratchpads/chat_generic.rs | 21 +- src/scratchpads/chat_llama2.rs | 20 +- src/scratchpads/chat_passthrough.rs | 19 +- src/scratchpads/chat_utils_limit_history.rs | 16 +- src/scratchpads/chat_utils_rag.rs | 173 ++++++++++--- src/scratchpads/completion_single_file_fim.rs | 3 + src/vecdb/file_watcher_service.rs | 2 +- src/vecdb/handler.rs | 10 +- src/vecdb/mod.rs | 2 - src/vecdb/vecdb.rs | 237 +++++++++++++++--- src/vecdb/vectorizer_service.rs | 126 ++++++---- 21 files changed, 683 insertions(+), 231 deletions(-) rename src/{vecdb/req_client.rs => fetch_embedding.rs} (67%) diff --git a/src/background_tasks.rs b/src/background_tasks.rs index 8f3edc3e1..ae0e3f045 100644 --- a/src/background_tasks.rs +++ b/src/background_tasks.rs @@ -45,6 +45,6 @@ pub fn start_background_tasks(global_context: Arc>) -> Ba BackgroundTasksHolder::new(vec![ tokio::spawn(basic_transmit::telemetry_background_task(global_context.clone())), tokio::spawn(snippets_transmit::tele_snip_background_task(global_context.clone())), - tokio::spawn(vecdb::file_watcher_service::file_watcher_task(global_context.clone())), + tokio::spawn(vecdb::vecdb::vecdb_background_reload(global_context.clone())), ]) } diff --git a/src/caps.rs b/src/caps.rs index 8d3470def..f3061a6e4 100644 --- a/src/caps.rs +++ b/src/caps.rs @@ -44,9 +44,19 @@ pub struct CodeAssistantCaps { #[serde(default)] pub code_chat_models: HashMap, pub code_chat_default_model: String, + #[serde(default)] + pub default_embeddings_model: String, + #[serde(default)] + pub endpoint_embeddings_template: String, + #[serde(default)] + pub endpoint_embeddings_style: String, + #[serde(default)] + pub size_embeddings: i32, pub running_models: Vec, #[serde(default)] pub caps_version: i64, // need to reload if it increases on server, that happens when server configuration changes + #[serde(default)] + pub chat_rag_functions: Vec, } #[derive(Debug, Deserialize)] @@ -292,6 +302,12 @@ const HF_DEFAULT_CAPS: &str = r#" "cloud_name": "Hugging Face", "endpoint_template": "https://api-inference.huggingface.co/models/$MODEL", "endpoint_style": "hf", + + "default_embeddings_model": "BAAI/bge-small-en-v1.5", + "endpoint_embeddings_template": "https://api-inference.huggingface.co/models/$MODEL", + "endpoint_embeddings_style": "hf", + "size_embeddings": 384, + "tokenizer_path_template": "https://huggingface.co/$MODEL/resolve/main/tokenizer.json", "tokenizer_rewrite_path": { "meta-llama/Llama-2-70b-chat-hf": "TheBloke/Llama-2-70B-fp16" @@ -364,6 +380,7 @@ pub async fn load_caps( r1.endpoint_chat_passthrough = relative_to_full_url(&caps_url, &r1.endpoint_chat_passthrough)?; r1.telemetry_basic_dest = relative_to_full_url(&caps_url, &r1.telemetry_basic_dest)?; r1.telemetry_corrected_snippets_dest = relative_to_full_url(&caps_url, &r1.telemetry_corrected_snippets_dest)?; + r1.endpoint_embeddings_template = relative_to_full_url(&caps_url, &r1.endpoint_embeddings_template)?; info!("caps {} completion models", r1.code_completion_models.len()); info!("caps default completion model: \"{}\"", r1.code_completion_default_model); info!("caps {} chat models", r1.code_chat_models.len()); @@ -371,6 +388,7 @@ pub async fn load_caps( Ok(Arc::new(StdRwLock::new(r1))) } + fn relative_to_full_url( caps_url: &String, maybe_relative_url: &str, diff --git a/src/vecdb/req_client.rs b/src/fetch_embedding.rs similarity index 67% rename from src/vecdb/req_client.rs rename to src/fetch_embedding.rs index 6624df831..163819c5e 100644 --- a/src/vecdb/req_client.rs +++ b/src/fetch_embedding.rs @@ -1,72 +1,58 @@ -use std::time::Duration; use reqwest; use serde::Serialize; -use tokio::task::JoinHandle; -use tokio::time::sleep; - -#[derive(Serialize)] -struct Payload { - pub inputs: String, -} +use tracing::error; +use crate::forward_to_hf_endpoint::get_embedding_hf_style; +use crate::forward_to_openai_endpoint::get_embedding_openai_style; -fn get_base_url() -> String { - #[cfg(test)] - { - // When running tests, return the mockito server URL - mockito::server_url() - } - #[cfg(not(test))] - { - // In a regular run, return the actual URL - "https://api-inference.huggingface.co".to_string() +pub async fn get_embedding( + endpoint_embeddings_style: &String, + model_name: &String, + endpoint_template: &String, + text: String, + api_key: &String, +) -> Result, String> { + match endpoint_embeddings_style.to_lowercase().as_str() { + "hf" => get_embedding_hf_style(text, endpoint_template, model_name, api_key).await, + "openai" => get_embedding_openai_style(text, endpoint_template, model_name, api_key).await, + _ => { + error!("Invalid endpoint_embeddings_style: {}", endpoint_embeddings_style); + Err("Invalid endpoint_embeddings_style".to_string()) + } } } -pub fn get_embedding( - text: String, + +// HF often returns 500 errors for no reason +pub async fn try_get_embedding( + endpoint_embeddings_style: &String, model_name: &String, - api_key: String, -) -> JoinHandle, String>> { - let url = format!("{}/models/{}", get_base_url(), model_name); - let client = reqwest::Client::new(); - let payload = Payload { inputs: text }; - - tokio::spawn(async move { - let mut attempts = 0; - let max_attempts = 3; - let delay = Duration::from_secs(5); - - while attempts < max_attempts { - let maybe_response = client.post(&url) - .bearer_auth("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".clone()) - .json(&payload) - .send() - .await; - - match maybe_response { - Ok(response) => { - if response.status().is_success() { - match response.json::>().await { - Ok(embedding) => return Ok(embedding), - Err(err) => return Err(format!("Failed to parse the response: {:?}", err)), - } - } else if response.status().is_server_error() { - // Retry in case of 5xx server errors - attempts += 1; - sleep(delay).await; - continue; - } else { - return Err(format!("Failed to get a response: {:?}", response.status())); - } - }, - Err(err) => return Err(format!("Failed to send a request: {:?}", err)) + endpoint_template: &String, + text: String, + api_key: &String, + max_retries: usize, +) -> Result, String> { + let sleep_on_failure_ms = 300; + let mut retries = 0; + loop { + retries += 1; + match get_embedding( + endpoint_embeddings_style, + model_name, + endpoint_template, + text.clone(), + api_key, + ).await { + Ok(embedding) => return Ok(embedding), + Err(e) => { + tokio::time::sleep(tokio::time::Duration::from_millis(sleep_on_failure_ms)).await; + if retries > max_retries { + return Err(e); + } } } - - Err("Exceeded maximum attempts to reach the server".to_string()) - }) + } } #[cfg(test)] diff --git a/src/forward_to_hf_endpoint.rs b/src/forward_to_hf_endpoint.rs index 1593d5d34..a72df2271 100644 --- a/src/forward_to_hf_endpoint.rs +++ b/src/forward_to_hf_endpoint.rs @@ -3,6 +3,7 @@ use reqwest::header::CONTENT_TYPE; use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use reqwest_eventsource::EventSource; +use serde::Serialize; use serde_json::json; use crate::call_validation::SamplingParameters; @@ -85,3 +86,46 @@ pub async fn forward_to_hf_style_endpoint_streaming( )?; Ok(event_source) } + + +#[derive(Serialize)] +struct EmbeddingsPayloadHF { + pub inputs: String, +} + + +pub async fn get_embedding_hf_style( + text: String, + endpoint_template: &String, + model_name: &String, + api_key: &String, +) -> Result, String> { + let client = reqwest::Client::new(); + let payload = EmbeddingsPayloadHF { inputs: text }; + let url = endpoint_template.clone().replace("$MODEL", &model_name); + let api_key_clone = api_key.clone(); + + let join_handle = tokio::task::spawn(async move { + let maybe_response = client + .post(&url) + .bearer_auth(api_key_clone.clone()) + .json(&payload) + .send() + .await; + + return match maybe_response { + Ok(response) => { + if response.status().is_success() { + match response.json::>().await { + Ok(embedding) => Ok(embedding), + Err(err) => Err(format!("Failed to parse the response: {:?}", err)), + } + } else { + Err(format!("Failed to get a response: {:?}", response.status())) + } + }, + Err(err) => Err(format!("Failed to send a request: {:?}", err)), + } + }); + join_handle.await.unwrap_or_else(|_| Err("Task join error".to_string())) +} diff --git a/src/forward_to_openai_endpoint.rs b/src/forward_to_openai_endpoint.rs index aae1a50de..7f3098413 100644 --- a/src/forward_to_openai_endpoint.rs +++ b/src/forward_to_openai_endpoint.rs @@ -3,6 +3,7 @@ use reqwest::header::CONTENT_TYPE; use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use reqwest_eventsource::EventSource; +use serde::Serialize; use serde_json::json; use crate::call_validation; use crate::call_validation::SamplingParameters; @@ -109,3 +110,62 @@ fn _passthrough_messages_to_json( let messages: Vec = serde_json::from_str(&messages_str).unwrap(); data["messages"] = serde_json::json!(messages); } + + +#[derive(Serialize)] +struct EmbeddingsPayloadOpenAI { + pub input: String, + pub model: String, +} + + +pub async fn get_embedding_openai_style( + text: String, + endpoint_template: &String, + model_name: &String, + api_key: &String, +) -> Result, String> { + let client = reqwest::Client::new(); + let payload = EmbeddingsPayloadOpenAI { + input: text, + model: model_name.clone(), + }; + let url = endpoint_template.clone(); + let api_key_clone = api_key.clone(); + + let join_handle = tokio::spawn(async move { + let maybe_response = client + .post(&url) + .bearer_auth(api_key_clone.clone()) + .json(&payload) + .send() + .await; + + return match maybe_response { + Ok(response) => { + if response.status().is_success() { + let response_json = response.json::().await; + + match response_json { + Ok(json) => match &json["data"][0]["embedding"] { + serde_json::Value::Array(embedding) => { + let embedding_values: Result, _> = + serde_json::from_value(serde_json::Value::Array(embedding.clone())); + embedding_values.map_err(|err| { + format!("Failed to parse the response: {:?}", err) + }) + } + _ => Err("Response is missing 'data[0].embedding' field or it's not an array".to_string()), + }, + Err(err) => Err(format!("Failed to parse the response: {:?}", err)), + } + } else { + Err(format!("Failed to get a response: {:?}", response.status())) + } + } + Err(err) => Err(format!("Failed to send a request: {:?}", err)), + } + }); + + join_handle.await.unwrap_or_else(|_| Err("Task join error".to_string())) +} diff --git a/src/global_context.rs b/src/global_context.rs index 1c5164576..8598e1de2 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -164,26 +164,6 @@ pub async fn create_global_context( http_client_builder = http_client_builder.danger_accept_invalid_certs(true) } let http_client = http_client_builder.build().unwrap(); - let mut vec_db: Option = None; - if cmdline.vecdb { - vec_db = match VecDb::init( - cache_dir.clone(), cmdline.clone(), - 384, 60, 512, 1024, - "BAAI/bge-small-en-v1.5".to_string(), - ).await { - Ok(res) => Some(res), - Err(err) => { - error!("Ooops database is broken! - Last error message: {} - You can report this issue here: - https://github.com/smallcloudai/refact-lsp/issues - Also you can run this to erase your db: - `rm -rf ~/.cache/refact/refact_vecdb_cache` - After that restart this LSP server or your IDE.", err); - None - } - }; - } let cx = GlobalContext { @@ -196,7 +176,7 @@ pub async fn create_global_context( tokenizer_map: HashMap::new(), completions_cache: Arc::new(StdRwLock::new(CompletionCache::new())), telemetry: Arc::new(StdRwLock::new(telemetry_structs::Storage::new())), - vec_db: Arc::new(AMutex::new(vec_db)), + vec_db: Arc::new(AMutex::new(None)), ask_shutdown_sender: Arc::new(Mutex::new(ask_shutdown_sender)), lsp_backend_document_state: LSPBackendDocumentState { document_map: Arc::new(ARwLock::new(HashMap::new())), diff --git a/src/lsp.rs b/src/lsp.rs index 86bbd0924..95bef38cd 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -201,14 +201,9 @@ impl LanguageServer for Backend { } if let Some(folders) = params.workspace_folders { - let files = retrieve_files_by_proj_folders( - folders.iter().map(|x| PathBuf::from(x.uri.path())).collect() - ).await; match *self.gcx.read().await.vec_db.lock().await { - Some(ref db) => db.add_or_update_files(files, true).await, - None => { - info!("LSP no vec_db"); - } + Some(ref db) => db.init_folders(folders).await, + None => {}, }; } diff --git a/src/main.rs b/src/main.rs index ebd1dcd73..b86a3d58e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,8 @@ use tokio::task::JoinHandle; use tracing::{error, info, Level}; use tracing_appender; +use std::sync::Arc; +use tokio::sync::Mutex as AMutex; use crate::background_tasks::start_background_tasks; use crate::lsp::spawn_lsp_task; @@ -25,6 +27,7 @@ mod http; mod background_tasks; mod receive_workspace_changes; mod vecdb; +mod fetch_embedding; #[tokio::main] @@ -61,11 +64,6 @@ async fn main() { } let mut background_tasks = start_background_tasks(gcx.clone()); - background_tasks.extend(match *gcx.read().await.vec_db.lock().await { - Some(ref db) => db.start_background_tasks().await, - None => vec![] - }); - let should_start_http = cmdline.http_port != 0; let should_start_lsp = (cmdline.lsp_port == 0 && cmdline.lsp_stdin_stdout == 1) || (cmdline.lsp_port != 0 && cmdline.lsp_stdin_stdout == 0); diff --git a/src/restream.rs b/src/restream.rs index 841b6887a..3c786cf7e 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -153,6 +153,24 @@ pub async fn scratchpad_interaction_stream( let mut save_url: String = String::new(); let permit = slowdown_arc.acquire().await; loop { + loop { + let value_maybe = scratch.response_spontaneous(); + if let Ok(value) = value_maybe { + if value == json!(null) { + break; + } + let value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap()); + info!("yield: {:?}", value_str); + yield Result::<_, String>::Ok(value_str); + } else { + let err_str = value_maybe.unwrap_err(); + error!("response_spontaneous error: {}", err_str); + let value_str = format!("data: {}\n\n", serde_json::to_string(&json!({"detail": err_str})).unwrap()); + yield Result::<_, String>::Ok(value_str); + } + break; + } + let event_source_maybe = if endpoint_style == "hf" { forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming( &mut save_url, diff --git a/src/scratchpad_abstract.rs b/src/scratchpad_abstract.rs index 109c31d6f..9d2ce9e08 100644 --- a/src/scratchpad_abstract.rs +++ b/src/scratchpad_abstract.rs @@ -31,6 +31,8 @@ pub trait ScratchpadAbstract: Send { stop_toks: bool, stop_length: bool, ) -> Result<(serde_json::Value, bool), String>; + + fn response_spontaneous(&mut self) -> Result; } diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index 46c3cf969..13c8109c5 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -11,7 +11,7 @@ use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::scratchpads::chat_utils_rag::embed_vecdb_results; +use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdb, HasVecdbResults}; use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; @@ -27,6 +27,7 @@ pub struct GenericChatScratchpad { pub keyword_asst: String, pub default_system_message: String, pub vecdb_search: Arc>>, + pub has_vecdb_results: HasVecdbResults, } impl GenericChatScratchpad { @@ -44,7 +45,8 @@ impl GenericChatScratchpad { keyword_user: "".to_string(), keyword_asst: "".to_string(), default_system_message: "".to_string(), - vecdb_search + vecdb_search, + has_vecdb_results: HasVecdbResults::new(), } } } @@ -84,11 +86,12 @@ impl ScratchpadAbstract for GenericChatScratchpad< context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - let augmented_msgs = match *self.vecdb_search.lock().await { - Some(ref db) => embed_vecdb_results(db, &self.post.messages, 6).await, - None => { self.post.messages.clone() } - }; - let limited_msgs: Vec = limit_messages_history(&self.t, &augmented_msgs, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; + match *self.vecdb_search.lock().await { + Some(ref db) => chat_functions_middleware(db, &mut self.post, 6, &mut self.has_vecdb_results).await, + None => {} + } + + let limited_msgs: Vec = limit_messages_history(&self.t, &self.post.messages, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 let mut prompt = "".to_string(); @@ -148,5 +151,9 @@ impl ScratchpadAbstract for GenericChatScratchpad< ) -> Result<(serde_json::Value, bool), String> { self.dd.response_streaming(delta, stop_toks) } + + fn response_spontaneous(&mut self) -> Result { + return self.has_vecdb_results.response_streaming(); + } } diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index 1732f7c9a..5f9f5597a 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -11,7 +11,7 @@ use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::scratchpads::chat_utils_rag::embed_vecdb_results; +use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdb, HasVecdbResults}; use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; @@ -26,6 +26,7 @@ pub struct ChatLlama2 { pub keyword_slash_s: String, pub default_system_message: String, pub vecdb_search: Arc>>, + pub has_vecdb_results: HasVecdbResults, } @@ -42,7 +43,8 @@ impl ChatLlama2 { keyword_s: "".to_string(), keyword_slash_s: "".to_string(), default_system_message: "".to_string(), - vecdb_search + vecdb_search, + has_vecdb_results: HasVecdbResults::new(), } } } @@ -70,11 +72,11 @@ impl ScratchpadAbstract for ChatLlama2 { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - let augmented_msgs = match *self.vecdb_search.lock().await { - Some(ref db) => embed_vecdb_results(db, &self.post.messages, 6).await, - None => { self.post.messages.clone() } - }; - let limited_msgs: Vec = limit_messages_history(&self.t, &augmented_msgs, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; + match *self.vecdb_search.lock().await { + Some(ref db) => chat_functions_middleware(db, &mut self.post, 6, &mut self.has_vecdb_results).await, + None => {} + } + let limited_msgs: Vec = limit_messages_history(&self.t, &self.post.messages, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // loosely adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 let mut prompt = "".to_string(); @@ -137,5 +139,9 @@ impl ScratchpadAbstract for ChatLlama2 { ) -> Result<(serde_json::Value, bool), String> { self.dd.response_streaming(delta, stop_toks) } + + fn response_spontaneous(&mut self) -> Result { + return self.has_vecdb_results.response_streaming(); + } } diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index 7b8dc972f..0d70ec230 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -7,7 +7,7 @@ use tracing::info; use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes; -use crate::scratchpads::chat_utils_rag::embed_vecdb_results; +use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdb, HasVecdbResults}; use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; @@ -19,6 +19,7 @@ pub struct ChatPassthrough { pub default_system_message: String, pub limit_bytes: usize, pub vecdb_search: Arc>>, + pub has_vecdb_results: HasVecdbResults, } const DEFAULT_LIMIT_BYTES: usize = 4096*6; @@ -33,6 +34,7 @@ impl ChatPassthrough { default_system_message: "".to_string(), limit_bytes: DEFAULT_LIMIT_BYTES, // one token translates to 3 bytes (not unicode chars) vecdb_search, + has_vecdb_results: HasVecdbResults::new(), } } } @@ -53,11 +55,11 @@ impl ScratchpadAbstract for ChatPassthrough { _context_size: usize, _sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - let augmented_msgs = match *self.vecdb_search.lock().await { - Some(ref db) => embed_vecdb_results(db, &self.post.messages, 6).await, - None => { self.post.messages.clone() } - }; - let limited_msgs: Vec = limit_messages_history_in_bytes(&augmented_msgs, self.limit_bytes, &self.default_system_message)?; + match *self.vecdb_search.lock().await { + Some(ref db) => chat_functions_middleware(db, &mut self.post, 6, &mut self.has_vecdb_results).await, + None => {} + } + let limited_msgs: Vec = limit_messages_history_in_bytes(&self.post.messages, self.limit_bytes, &self.default_system_message)?; info!("chat passthrough {} messages -> {} messages after applying limits and possibly adding the default system message", &limited_msgs.len(), &limited_msgs.len()); let mut filtered_msgs: Vec = Vec::::new(); for msg in &limited_msgs { @@ -68,7 +70,7 @@ impl ScratchpadAbstract for ChatPassthrough { for context_file in &vector_of_context_files { filtered_msgs.push(ChatMessage { role: "user".to_string(), - content: format!("{}:{}-{}\n```\n{}```", context_file.file_name, context_file.line1, context_file.line2, context_file.file_content), + content: format!("{}\n```\n{}```", context_file.file_name, context_file.file_content), }); } } @@ -123,4 +125,7 @@ impl ScratchpadAbstract for ChatPassthrough { }); Ok((ans, finished)) } + fn response_spontaneous(&mut self) -> Result { + return self.has_vecdb_results.response_streaming(); + } } diff --git a/src/scratchpads/chat_utils_limit_history.rs b/src/scratchpads/chat_utils_limit_history.rs index efd6a7314..6803031c6 100644 --- a/src/scratchpads/chat_utils_limit_history.rs +++ b/src/scratchpads/chat_utils_limit_history.rs @@ -7,7 +7,7 @@ pub fn limit_messages_history( messages: &Vec, max_new_tokens: usize, context_size: usize, - default_system_mesage: &String, + default_system_message: &String, ) -> Result, String> { let tokens_limit: i32 = context_size as i32 - max_new_tokens as i32; @@ -24,9 +24,9 @@ pub fn limit_messages_history( have_system = true; } } - let need_default_system_msg = !have_system && default_system_mesage.len() > 0; + let need_default_system_msg = !have_system && default_system_message.len() > 0; if need_default_system_msg { - let tcnt = t.count_tokens(default_system_mesage.as_str())? as i32; + let tcnt = t.count_tokens(default_system_message.as_str())? as i32; tokens_used += tcnt; } for i in (0..messages.len()).rev() { @@ -44,7 +44,7 @@ pub fn limit_messages_history( if need_default_system_msg { messages_out.insert(0, ChatMessage { role: "system".to_string(), - content: default_system_mesage.clone(), + content: default_system_message.clone(), }); } Ok(messages_out) @@ -54,7 +54,7 @@ pub fn limit_messages_history( pub fn limit_messages_history_in_bytes( messages: &Vec, bytes_limit: usize, - default_system_mesage: &String, + default_system_message: &String, ) -> Result, String> { let mut bytes_used: usize = 0; @@ -67,9 +67,9 @@ pub fn limit_messages_history_in_bytes( have_system = true; } } - let need_default_system_msg = !have_system && default_system_mesage.len() > 0; + let need_default_system_msg = !have_system && default_system_message.len() > 0; if need_default_system_msg { - bytes_used += default_system_mesage.as_bytes().len(); + bytes_used += default_system_message.as_bytes().len(); } for i in (0..messages.len()).rev() { let bytes = messages[i].content.len(); @@ -86,7 +86,7 @@ pub fn limit_messages_history_in_bytes( if need_default_system_msg { messages_out.insert(0, ChatMessage { role: "system".to_string(), - content: default_system_mesage.clone(), + content: default_system_message.clone(), }); } Ok(messages_out) diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs index 3e95f2a92..621eb1b5c 100644 --- a/src/scratchpads/chat_utils_rag.rs +++ b/src/scratchpads/chat_utils_rag.rs @@ -1,43 +1,156 @@ +use async_trait::async_trait; +use serde_json::json; +use tracing::info; + use crate::call_validation::{ChatMessage, ChatPost, ContextFile}; use crate::vecdb::structs::{SearchResult, VecdbSearch}; -pub async fn embed_vecdb_results( +pub async fn chat_functions_middleware( vecdb: &T, - msgs: &Vec, + post: &mut ChatPost, limit_examples_cnt: usize, -) -> Vec where T: VecdbSearch { - let latest_msg_cont = &msgs.last().unwrap().content; - let vdb_resp = vecdb.search(latest_msg_cont.clone(), limit_examples_cnt).await; - let vdb_cont = vecdb_resp_to_prompt(&vdb_resp); - if vdb_cont.is_ok() { - return [ - &msgs[..msgs.len() - 1], - &[ChatMessage { - role: "context_file".to_string(), - content: vdb_cont.unwrap(), - }], - &msgs[msgs.len() - 1..], - ].concat(); - } else { - return msgs.clone(); + has_vecdb: &mut dyn HasVecdb, +) where T: VecdbSearch { + let latest_msg_cont = &post.messages.last().unwrap().content; + if latest_msg_cont.starts_with("@workspace") { + embed_vecdb_results(vecdb, post, limit_examples_cnt, has_vecdb).await; } } -fn vecdb_resp_to_prompt( + +async fn embed_vecdb_results( + vecdb: &T, + post: &mut ChatPost, + limit_examples_cnt: usize, + has_vecdb: &mut dyn HasVecdb, +) where T: VecdbSearch { + let latest_msg_cont = &post.messages.last().unwrap().content; + let vdb_resp = vecdb.search(latest_msg_cont.clone(), limit_examples_cnt).await; + + has_vecdb.add2messages( + vdb_resp, + &mut post.messages, + ).await; +} + +fn vecdb_resp_to_json( resp: &Result -) -> serde_json::Result { - let context_files: Vec = match resp { +) -> serde_json::Result { + let mut context_files: Vec = match resp { Ok(search_res) => { - search_res.results.iter().map( - |x| ContextFile { - file_name: x.file_path.to_str().unwrap().to_string(), - file_content: x.window_text.clone(), - line1: x.start_line as i32, - line2: x.end_line as i32, - } - ).collect() + search_res.results.iter().map(|x| ContextFile { + file_name: x.file_path.to_str().unwrap().to_string(), + file_content: x.window_text.clone(), + line1: x.start_line as i32, + line2: x.end_line as i32, + }).collect() } - Err(_) => vec![] + Err(_) => vec![], }; - serde_json::to_string(&context_files) + + context_files.dedup_by(|a, b| { + a.file_name == b.file_name && a.file_content == b.file_content + }); + + context_files.iter_mut().for_each(|file| { + file.file_name = file.file_name + .rsplit('/') + .next() + .unwrap_or(&file.file_name) + .to_string(); + }); + + serde_json::to_value(&context_files) +} + +fn vecdb_resp_to_prompt( + resp_mb: &Result +) -> String { + let mut cont = "".to_string(); + + if resp_mb.is_err() { + info!("VECDB ERR"); + return cont + } + let resp = resp_mb.as_ref().unwrap(); + let mut results = resp.results.clone(); + results.dedup_by(|a, b| a.file_path == b.file_path && a.window_text == b.window_text); + + cont.push_str("CONTEXT:\n"); + for res in results.iter() { + cont.push_str("FILENAME:\n"); + cont.push_str(res.file_path.clone().to_str().unwrap_or_else( || "")); + cont.push_str("\nTEXT:"); + cont.push_str(res.window_text.clone().as_str()); + cont.push_str("\n"); + } + cont.push_str("\nRefer to the context to answer my next question.\n"); + info!("VECDB prompt:\n{}", cont); + cont +} + + +pub struct HasVecdbResults { + pub was_sent: bool, + pub in_json: serde_json::Value, +} + +impl HasVecdbResults { + pub fn new() -> Self { + HasVecdbResults { + was_sent: false, + in_json: json!(null) + } + } +} + +#[async_trait] +pub trait HasVecdb: Send { + async fn add2messages( + &mut self, + vdb_result_mb: Result, + messages: &mut Vec, + ); + fn response_streaming(&mut self) -> Result; +} + +#[async_trait] +impl HasVecdb for HasVecdbResults { + async fn add2messages( + &mut self, + result_mb: Result, + messages: &mut Vec, + ) { + // if messages.len() > 1 { + // return; + // } + + *messages = [ + &messages[..messages.len() -1], + &[ChatMessage { + role: "user".to_string(), + content: vecdb_resp_to_prompt(&result_mb), + }], + &messages[messages.len() -1..], + ].concat(); + + self.in_json = vecdb_resp_to_json(&result_mb).unwrap_or_else(|_| json!(null)); + } + + fn response_streaming(&mut self) -> Result { + if self.was_sent == true || self.in_json.is_null() { + return Ok(json!(null)); + } + self.was_sent = true; + return Ok(json!({ + "choices": [{ + "delta": { + "content": self.in_json.clone(), + "role": "context_file" + }, + "finish_reason": serde_json::Value::Null, + "index": 0 + }], + })); + } } diff --git a/src/scratchpads/completion_single_file_fim.rs b/src/scratchpads/completion_single_file_fim.rs index da81fbbd4..7ba20c21c 100644 --- a/src/scratchpads/completion_single_file_fim.rs +++ b/src/scratchpads/completion_single_file_fim.rs @@ -285,6 +285,9 @@ impl ScratchpadAbstract for SingleFileFIM { }); Ok((ans, finished)) } + fn response_spontaneous(&mut self) -> Result { + return Err("".to_string()); + } } fn get_context_near_cursor(text: &Rope, line_pos: usize, max_lines_count: usize) -> String { diff --git a/src/vecdb/file_watcher_service.rs b/src/vecdb/file_watcher_service.rs index d7df459b2..ba739a373 100644 --- a/src/vecdb/file_watcher_service.rs +++ b/src/vecdb/file_watcher_service.rs @@ -59,7 +59,7 @@ pub async fn file_watcher_task( let (mut watcher, mut rx) = make_async_watcher().expect("Failed to make file watcher"); let maybe_path = global_context.read().await.cmdline.files_set_path.clone(); if maybe_path.is_empty() { - info!("file watcher: no files to watch"); + info!("files_set_path is empty: no files to watch"); return; } let path = PathBuf::from(maybe_path); diff --git a/src/vecdb/handler.rs b/src/vecdb/handler.rs index 38c2c49a9..f44d02455 100644 --- a/src/vecdb/handler.rs +++ b/src/vecdb/handler.rs @@ -58,8 +58,13 @@ const ONE_MONTH: i32 = 30 * 24 * 3600; const MIN_LIKES: i32 = 3; impl VecDBHandler { - pub async fn init(cache_dir: PathBuf, embedding_size: i32) -> Result { - let cache_dir_str = match cache_dir.join("refact_vecdb_cache").to_str() { + pub async fn init(cache_dir: &PathBuf, model_name: &String, embedding_size: i32) -> Result { + let cache_dir_str = match cache_dir.join("refact_vecdb_cache") + .join(format!("model_{}_esize_{}", + model_name.replace("/", "_"), + embedding_size + )).to_str() { + Some(dir) => dir.to_string(), None => { return Err(format!("{:?}", "Cache directory is not a valid path")); @@ -327,6 +332,7 @@ impl VecDBHandler { Some(res) => res }; + // valerii: In documentation I found no way to preprocess strings to prevent SQL injections match self.cache_table.delete( format!("(file_path = \"{}\")", file_path_str).as_str() // TODO: Prevent a possible sql injection here ).await { diff --git a/src/vecdb/mod.rs b/src/vecdb/mod.rs index bc409f285..8286f73a3 100644 --- a/src/vecdb/mod.rs +++ b/src/vecdb/mod.rs @@ -1,8 +1,6 @@ mod file_splitter; mod handler; -mod req_client; mod vectorizer_service; - pub mod file_watcher_service; pub mod file_filter; pub mod structs; diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 3100172cc..b2ac6b6c6 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -2,52 +2,213 @@ use std::path::PathBuf; use std::sync::Arc; use async_trait::async_trait; -use tokio::sync::Mutex; +use tracing::info; +use tokio::sync::Mutex as AMutex; use tokio::task::JoinHandle; +use crate::global_context::{CommandLine, GlobalContext}; +use tokio::sync::RwLock as ARwLock; +use tower_lsp::lsp_types::WorkspaceFolder; +use tracing::error; +use crate::background_tasks::BackgroundTasksHolder; +use crate::fetch_embedding; +use crate::vecdb; +use crate::vecdb::{file_filter}; use crate::vecdb::handler::{VecDBHandler, VecDBHandlerRef}; -use crate::vecdb::req_client::get_embedding; use crate::vecdb::vectorizer_service::FileVectorizerService; use crate::vecdb::structs::{SearchResult, VecdbSearch, VecDbStatus}; + #[derive(Debug)] pub struct VecDb { vecdb_handler: VecDBHandlerRef, - retriever_service: Arc>, - embedding_model_name: String, - cmdline: crate::global_context::CommandLine, + retriever_service: Arc>, + cmdline: CommandLine, + + model_name: String, + endpoint_template: String, + endpoint_embeddings_style: String, +} + + +pub async fn create_vecdb( + default_embeddings_model: String, + endpoint_embeddings_template: String, + endpoint_embeddings_style: String, + size_embeddings: i32, + + cmdline: &CommandLine, + cache_dir: &PathBuf, +) -> Option { + let vec_db = match VecDb::init( + &cache_dir, cmdline.clone(), + size_embeddings, 60, 512, 1024, + default_embeddings_model.clone(), + endpoint_embeddings_template.clone(), + endpoint_embeddings_style.clone(), + ).await { + Ok(res) => Some(res), + Err(err) => { + error!("Ooops database is broken! + Last error message: {} + You can report this issue here: + https://github.com/smallcloudai/refact-lsp/issues + Also, you can run this to erase your db: + `rm -rf ~/.cache/refact/refact_vecdb_cache` + After that restart this LSP server or your IDE.", err); + None + } + }; + vec_db } +pub async fn vecdb_background_reload( + global_context: Arc>, +) { + let mut background_tasks = BackgroundTasksHolder::new(vec![]); + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; + + let (caps_mb, cache_dir, cmdline) = { + let gcx_locked = global_context.read().await; + let caps_mb = gcx_locked.caps.clone(); + let cache_dir = gcx_locked.cache_dir.clone(); + (caps_mb, &cache_dir.clone(), &gcx_locked.cmdline.clone()) + }; + + if caps_mb.is_none() || !cmdline.vecdb { + continue; + } + + let ( + default_embeddings_model, + endpoint_embeddings_template, + endpoint_embeddings_style, + size_embeddings, + ) = { + let caps = caps_mb.unwrap(); + let caps_locked = caps.read().unwrap(); + ( + caps_locked.default_embeddings_model.clone(), + caps_locked.endpoint_embeddings_template.clone(), + caps_locked.endpoint_embeddings_style.clone(), + caps_locked.size_embeddings.clone(), + ) + }; + + if default_embeddings_model.is_empty() || endpoint_embeddings_template.is_empty() { + error!("vecd launch failed: default_embeddings_model.is_empty() || endpoint_embeddings_template.is_empty()"); + continue; + } + + + match *global_context.write().await.vec_db.lock().await { + None => {} + Some(ref db) => { + if db.model_name == default_embeddings_model && + db.endpoint_template == endpoint_embeddings_template && + db.endpoint_embeddings_style == endpoint_embeddings_style { + continue; + } + } + } + + info!("attempting to launch vecdb"); + + background_tasks.abort().await; + background_tasks = BackgroundTasksHolder::new(vec![]); + + let vecdb_mb = create_vecdb( + default_embeddings_model.clone(), + endpoint_embeddings_template, + endpoint_embeddings_style, + size_embeddings, + + cmdline, + cache_dir + ).await; + + if vecdb_mb.is_none() { + continue; + } + + { + let mut gcx_locked = global_context.write().await; + + if let Some(caps) = &mut gcx_locked.caps { + caps.write().unwrap().chat_rag_functions = vec![ + "@workspace".to_string() + ] + }; + + gcx_locked.vec_db = Arc::new(AMutex::new(vecdb_mb)); + info!("VECDB is launched successfully"); + + background_tasks.extend(match *gcx_locked.vec_db.lock().await { + Some(ref db) => { + let mut tasks = db.start_background_tasks().await; + tasks.push( + tokio::spawn(vecdb::file_watcher_service::file_watcher_task(global_context.clone())) + ); + tasks + } + None => vec![] + }); + { + if let Some(folders) = gcx_locked.lsp_backend_document_state.workspace_folders.clone().read().await.clone() { + let mut vec_db_lock = gcx_locked.vec_db.lock().await; + if let Some(ref mut db) = *vec_db_lock { + db.init_folders(folders).await; + } + } + } + } + } +} impl VecDb { pub async fn init( - cache_dir: PathBuf, - cmdline: crate::global_context::CommandLine, + cache_dir: &PathBuf, + cmdline: CommandLine, embedding_size: i32, cooldown_secs: u64, splitter_window_size: usize, splitter_soft_limit: usize, - embedding_model_name: String, + + model_name: String, + endpoint_template: String, + endpoint_embeddings_style: String, ) -> Result { - let handler = match VecDBHandler::init(cache_dir, embedding_size).await { + let handler = match VecDBHandler::init(cache_dir, &model_name, embedding_size).await { Ok(res) => res, Err(err) => { return Err(err) } }; - let vecdb_handler = Arc::new(Mutex::new(handler)); - let retriever_service = Arc::new(Mutex::new(FileVectorizerService::new( - vecdb_handler.clone(), cooldown_secs, splitter_window_size, splitter_soft_limit, - embedding_model_name.clone(), cmdline.api_key.clone(), + let vecdb_handler = Arc::new(AMutex::new(handler)); + let retriever_service = Arc::new(AMutex::new(FileVectorizerService::new( + vecdb_handler.clone(), + cooldown_secs, + splitter_window_size, + splitter_soft_limit, + + model_name.clone(), + cmdline.api_key.clone(), + endpoint_embeddings_style.clone(), + endpoint_template.clone(), ).await)); Ok(VecDb { vecdb_handler, retriever_service, - embedding_model_name, - cmdline, + cmdline: cmdline.clone(), + + model_name, + endpoint_template, + endpoint_embeddings_style, }) } pub async fn start_background_tasks(&self) -> Vec> { + info!("vecdb: start_background_tasks"); return self.retriever_service.lock().await.start_background_tasks().await; } @@ -66,30 +227,40 @@ impl VecDb { pub async fn get_status(&self) -> Result { self.retriever_service.lock().await.status().await } + + pub async fn init_folders(&self, folders: Vec) { + let files = file_filter::retrieve_files_by_proj_folders( + folders.iter().map(|x| PathBuf::from(x.uri.path())).collect() + ).await; + self.add_or_update_files(files, true).await; + info!("vecdb: init_folders complete"); + } } #[async_trait] impl VecdbSearch for VecDb { async fn search(&self, query: String, top_n: usize) -> Result { - let embedding = get_embedding( - query.clone(), &self.embedding_model_name, self.cmdline.api_key.clone(), - ).await.unwrap(); - match embedding { - Ok(vector) => { - let mut binding = self.vecdb_handler.lock().await; - let results = binding.search(vector, top_n).await.unwrap(); - binding.update_record_statistic(results.clone()).await; - Ok( - SearchResult { - query_text: query, - results: results, - } - ) - } - Err(_) => { - return Err("Failed to get embedding".to_string()); - } + let embedding_mb = fetch_embedding::try_get_embedding( + &self.endpoint_embeddings_style, + &self.model_name, + &self.endpoint_template, + query.clone(), + &self.cmdline.api_key, + 3 + ).await; + if embedding_mb.is_err() { + return Err("Failed to get embedding".to_string()); } + let mut binding = self.vecdb_handler.lock().await; + + let results = binding.search(embedding_mb.unwrap(), top_n).await.unwrap(); + binding.update_record_statistic(results.clone()).await; + Ok( + SearchResult { + query_text: query, + results, + } + ) } } diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index db2e631c9..b5a4419ed 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -3,13 +3,13 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::SystemTime; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, Semaphore}; use tokio::task::JoinHandle; use tracing::info; use crate::vecdb::file_splitter::FileSplitter; use crate::vecdb::handler::VecDBHandlerRef; -use crate::vecdb::req_client::get_embedding; +use crate::fetch_embedding::try_get_embedding; use crate::vecdb::structs::{Record, SplitResult, VecDbStatus, VecDbStatusRef}; #[derive(Debug)] @@ -21,8 +21,11 @@ pub struct FileVectorizerService { cooldown_secs: u64, splitter_window_size: usize, splitter_soft_limit: usize, - embedding_model_name: String, + + model_name: String, api_key: String, + endpoint_embeddings_style: String, + endpoint_template: String, } async fn cooldown_queue_thread( @@ -67,10 +70,16 @@ async fn vectorize_thread( status: VecDbStatusRef, splitter_window_size: usize, splitter_soft_limit: usize, - embedding_model_name: String, + + model_name: String, api_key: String, + endpoint_embeddings_style: String, + endpoint_template: String, + + max_concurrent_tasks: usize, ) { let file_splitter = FileSplitter::new(splitter_window_size, splitter_soft_limit); + let semaphore = Arc::new(Semaphore::new(max_concurrent_tasks)); loop { let (path_maybe, unprocessed_files_count) = { @@ -108,41 +117,63 @@ async fn vectorize_thread( drop(vecdb_handler); info!("Retrieving embeddings for {} chunks", split_data_filtered.len()); - let join_handles: Vec<_> = split_data_filtered.iter().map( - |x| get_embedding(x.window_text.clone(), &embedding_model_name, api_key.clone()) - ).collect(); - status.lock().await.requests_made_since_start += join_handles.len(); - - let mut split_join_data: VecDeque<(SplitResult, JoinHandle, String>>)> - = split_data_filtered.into_iter() - .zip(join_handles.into_iter()) - .collect::>(); - - let mut records: Vec = Vec::new(); - while let Some((data_res, handle)) = split_join_data.pop_front() { - match handle.await { - Ok(Ok(result)) => { - let now = SystemTime::now(); - records.push( - Record { - vector: Some(result), - window_text: data_res.window_text, - window_text_hash: data_res.window_text_hash, - file_path: data_res.file_path, - start_line: data_res.start_line, - end_line: data_res.end_line, - time_added: now, - time_last_used: now, - model_name: embedding_model_name.clone(), - used_counter: 0, - distance: -1.0, - } - ); - } - Ok(Err(e)) => { - info!("Error retrieving embeddings for {}: {}", data_res.window_text, e); + let join_handles: Vec<_> = split_data_filtered.into_iter().map(|x| { + let model_name_clone = model_name.clone(); + let api_key_clone = api_key.clone(); + let endpoint_embeddings_style_clone = endpoint_embeddings_style.clone(); + let endpoint_template_clone = endpoint_template.clone(); + + let semaphore_clone = Arc::clone(&semaphore); + tokio::spawn(async move { + let _permit = match semaphore_clone.acquire().await { + Ok(permit) => permit, + Err(_) => { + return None; + } + }; + + let result = try_get_embedding( + &endpoint_embeddings_style_clone, + &model_name_clone, + &endpoint_template_clone, + x.window_text.clone(), + &api_key_clone, + 3, + ).await; + + drop(_permit); + Some((x, result)) + }) + }).collect(); + + let mut records = vec![]; + + for handle in join_handles { + if let Some((data_res, result_mb)) = handle.await.unwrap() { + match result_mb { + Ok(result) => { + let now = SystemTime::now(); + + records.push( + Record { + vector: Some(result), + window_text: data_res.window_text, + window_text_hash: data_res.window_text_hash, + file_path: data_res.file_path, + start_line: data_res.start_line, + end_line: data_res.end_line, + time_added: SystemTime::now(), + model_name: model_name.clone(), + distance: -1.0, + used_counter: 0, + time_last_used: now, + } + ); + } + Err(e) => { + info!("Error retrieving embeddings for {}: {}", data_res.file_path.to_str().unwrap(), e); + } } - Err(_) => { continue; } } } match vecdb_handler_ref.lock().await.add_or_update(records, true).await { @@ -170,8 +201,11 @@ impl FileVectorizerService { cooldown_secs: u64, splitter_window_size: usize, splitter_soft_limit: usize, - embedding_model_name: String, + + model_name: String, api_key: String, + endpoint_embeddings_style: String, + endpoint_template: String, ) -> Self { let update_request_queue = Arc::new(Mutex::new(VecDeque::new())); let output_queue = Arc::new(Mutex::new(VecDeque::new())); @@ -191,8 +225,11 @@ impl FileVectorizerService { cooldown_secs, splitter_window_size, splitter_soft_limit, - embedding_model_name, + + model_name, api_key, + endpoint_embeddings_style, + endpoint_template, } } @@ -213,11 +250,16 @@ impl FileVectorizerService { self.status.clone(), self.splitter_window_size, self.splitter_soft_limit, - self.embedding_model_name.clone(), + + self.model_name.clone(), self.api_key.clone(), + self.endpoint_embeddings_style.clone(), + self.endpoint_template.clone(), + + 4, ) ); - + let cleanup_thread_handle = tokio::spawn( cleanup_thread( self.vecdb_handler.clone() From 155d4d109703213c3a48e5ea28593964758ab280 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Tue, 16 Jan 2024 20:57:42 +0000 Subject: [PATCH 03/54] smallfix --- src/http/routers/v1/lsp_like_handlers.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/http/routers/v1/lsp_like_handlers.rs b/src/http/routers/v1/lsp_like_handlers.rs index e2390f9fb..4b9ffeaad 100644 --- a/src/http/routers/v1/lsp_like_handlers.rs +++ b/src/http/routers/v1/lsp_like_handlers.rs @@ -24,10 +24,10 @@ struct PostDocument { pub async fn handle_v1_lsp_initialize( - Extension(_global_context): Extension, + Extension(global_context): Extension, body_bytes: hyper::body::Bytes, ) -> Result, ScratchError> { - let _post = serde_json::from_slice::(&body_bytes).map_err(|e| { + let post = serde_json::from_slice::(&body_bytes).map_err(|e| { ScratchError::new(StatusCode::BAD_REQUEST, format!("JSON problem: {}", e)) })?; From f203d9a9c94937f6224ba48e8a4bc0687afe4827 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Wed, 17 Jan 2024 12:33:02 +0000 Subject: [PATCH 04/54] changelist: * removed tokio::spawn from get_embedding_openai_style and get_embedding_hf_style * added try_load_caps_quickly_if_not_present to vecdb_background_reload * added test search to vecdb_background_reload --- src/forward_to_hf_endpoint.rs | 40 ++++++++++----------- src/forward_to_openai_endpoint.rs | 58 ++++++++++++++----------------- src/vecdb/vecdb.rs | 29 +++++++++++----- 3 files changed, 65 insertions(+), 62 deletions(-) diff --git a/src/forward_to_hf_endpoint.rs b/src/forward_to_hf_endpoint.rs index a72df2271..4b14ca96a 100644 --- a/src/forward_to_hf_endpoint.rs +++ b/src/forward_to_hf_endpoint.rs @@ -103,29 +103,25 @@ pub async fn get_embedding_hf_style( let client = reqwest::Client::new(); let payload = EmbeddingsPayloadHF { inputs: text }; let url = endpoint_template.clone().replace("$MODEL", &model_name); - let api_key_clone = api_key.clone(); - let join_handle = tokio::task::spawn(async move { - let maybe_response = client - .post(&url) - .bearer_auth(api_key_clone.clone()) - .json(&payload) - .send() - .await; + let maybe_response = client + .post(&url) + .bearer_auth(api_key.clone()) + .json(&payload) + .send() + .await; - return match maybe_response { - Ok(response) => { - if response.status().is_success() { - match response.json::>().await { - Ok(embedding) => Ok(embedding), - Err(err) => Err(format!("Failed to parse the response: {:?}", err)), - } - } else { - Err(format!("Failed to get a response: {:?}", response.status())) + match maybe_response { + Ok(response) => { + if response.status().is_success() { + match response.json::>().await { + Ok(embedding) => Ok(embedding), + Err(err) => Err(format!("Failed to parse the response: {:?}", err)), } - }, - Err(err) => Err(format!("Failed to send a request: {:?}", err)), - } - }); - join_handle.await.unwrap_or_else(|_| Err("Task join error".to_string())) + } else { + Err(format!("Failed to get a response: {:?}", response.status())) + } + }, + Err(err) => Err(format!("Failed to send a request: {:?}", err)), + } } diff --git a/src/forward_to_openai_endpoint.rs b/src/forward_to_openai_endpoint.rs index 7f3098413..cb8391636 100644 --- a/src/forward_to_openai_endpoint.rs +++ b/src/forward_to_openai_endpoint.rs @@ -130,42 +130,36 @@ pub async fn get_embedding_openai_style( input: text, model: model_name.clone(), }; - let url = endpoint_template.clone(); - let api_key_clone = api_key.clone(); - let join_handle = tokio::spawn(async move { - let maybe_response = client - .post(&url) - .bearer_auth(api_key_clone.clone()) - .json(&payload) - .send() - .await; + let maybe_response = client + .post(endpoint_template) + .bearer_auth(api_key) + .json(&payload) + .send() + .await; - return match maybe_response { - Ok(response) => { - if response.status().is_success() { - let response_json = response.json::().await; + match maybe_response { + Ok(response) => { + if response.status().is_success() { + let response_json = response.json::().await; - match response_json { - Ok(json) => match &json["data"][0]["embedding"] { - serde_json::Value::Array(embedding) => { - let embedding_values: Result, _> = - serde_json::from_value(serde_json::Value::Array(embedding.clone())); - embedding_values.map_err(|err| { - format!("Failed to parse the response: {:?}", err) - }) - } - _ => Err("Response is missing 'data[0].embedding' field or it's not an array".to_string()), - }, - Err(err) => Err(format!("Failed to parse the response: {:?}", err)), - } - } else { - Err(format!("Failed to get a response: {:?}", response.status())) + match response_json { + Ok(json) => match &json["data"][0]["embedding"] { + serde_json::Value::Array(embedding) => { + let embedding_values: Result, _> = + serde_json::from_value(serde_json::Value::Array(embedding.clone())); + embedding_values.map_err(|err| { + format!("Failed to parse the response: {:?}", err) + }) + } + _ => Err("Response is missing 'data[0].embedding' field or it's not an array".to_string()), + }, + Err(err) => Err(format!("Failed to parse the response: {:?}", err)), } + } else { + Err(format!("Failed to get a response: {:?}", response.status())) } - Err(err) => Err(format!("Failed to send a request: {:?}", err)), } - }); - - join_handle.await.unwrap_or_else(|_| Err("Task join error".to_string())) + Err(err) => Err(format!("Failed to send a request: {:?}", err)), + } } diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index b2ac6b6c6..6c255a480 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -67,16 +67,17 @@ pub async fn vecdb_background_reload( ) { let mut background_tasks = BackgroundTasksHolder::new(vec![]); loop { - tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; - let (caps_mb, cache_dir, cmdline) = { + let (cache_dir, cmdline) = { let gcx_locked = global_context.read().await; - let caps_mb = gcx_locked.caps.clone(); let cache_dir = gcx_locked.cache_dir.clone(); - (caps_mb, &cache_dir.clone(), &gcx_locked.cmdline.clone()) + (&cache_dir.clone(), &gcx_locked.cmdline.clone()) }; - if caps_mb.is_none() || !cmdline.vecdb { + let caps_mb = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 10).await; + + if caps_mb.is_err() || !cmdline.vecdb { continue; } @@ -113,7 +114,7 @@ pub async fn vecdb_background_reload( } } - info!("attempting to launch vecdb"); + info!("vecdb: attempting to launch"); background_tasks.abort().await; background_tasks = BackgroundTasksHolder::new(vec![]); @@ -131,6 +132,18 @@ pub async fn vecdb_background_reload( if vecdb_mb.is_none() { continue; } + let vecdb = vecdb_mb.unwrap(); + + let search_result = vecdb.search("".to_string(), 3).await; + match search_result { + Ok(_) => { + info!("vecdb: test search complete") + } + Err(_) => { + error!("vecdb: test search failed"); + continue; + } + } { let mut gcx_locked = global_context.write().await; @@ -141,8 +154,8 @@ pub async fn vecdb_background_reload( ] }; - gcx_locked.vec_db = Arc::new(AMutex::new(vecdb_mb)); - info!("VECDB is launched successfully"); + gcx_locked.vec_db = Arc::new(AMutex::new(Some(vecdb))); + info!("vecdb is launched successfully"); background_tasks.extend(match *gcx_locked.vec_db.lock().await { Some(ref db) => { From 91a0093dd15b240b49f24955c51d3b58efa446ec Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Wed, 17 Jan 2024 13:53:20 +0100 Subject: [PATCH 05/54] simplify forward_to_openai_endpoint --- src/forward_to_openai_endpoint.rs | 47 ++++++++++++++----------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/src/forward_to_openai_endpoint.rs b/src/forward_to_openai_endpoint.rs index cb8391636..c374e8b13 100644 --- a/src/forward_to_openai_endpoint.rs +++ b/src/forward_to_openai_endpoint.rs @@ -130,36 +130,31 @@ pub async fn get_embedding_openai_style( input: text, model: model_name.clone(), }; - - let maybe_response = client - .post(endpoint_template) - .bearer_auth(api_key) + let url = endpoint_template.clone(); + let api_key_clone = api_key.clone(); + let response = client + .post(&url) + .bearer_auth(api_key_clone.clone()) .json(&payload) .send() - .await; + .await + .map_err(|e| format!("Failed to send a request: {:?}", e))?; + + if !response.status().is_success() { + info!("get_embedding_openai_style: {:?}", response); + return Err(format!("get_embedding_openai_style: bad status: {:?}", response.status())) + } - match maybe_response { - Ok(response) => { - if response.status().is_success() { - let response_json = response.json::().await; + let json = response.json::() + .await + .map_err(|err| format!("get_embedding_openai_style: failed to parse the response: {:?}", err))?; - match response_json { - Ok(json) => match &json["data"][0]["embedding"] { - serde_json::Value::Array(embedding) => { - let embedding_values: Result, _> = - serde_json::from_value(serde_json::Value::Array(embedding.clone())); - embedding_values.map_err(|err| { - format!("Failed to parse the response: {:?}", err) - }) - } - _ => Err("Response is missing 'data[0].embedding' field or it's not an array".to_string()), - }, - Err(err) => Err(format!("Failed to parse the response: {:?}", err)), - } - } else { - Err(format!("Failed to get a response: {:?}", response.status())) - } + // info!("get_embedding_openai_style: {:?}", json); + match &json["data"][0]["embedding"] { + serde_json::Value::Array(embedding) => { + serde_json::from_value(serde_json::Value::Array(embedding.clone())) + .map_err(|err| { format!("Failed to parse the response: {:?}", err) }) } - Err(err) => Err(format!("Failed to send a request: {:?}", err)), + _ => Err("Response is missing 'data[0].embedding' field or it's not an array".to_string()), } } From 62956fceb07cdb5a3083742b5101d50aabef6b02 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Wed, 17 Jan 2024 15:29:26 +0100 Subject: [PATCH 06/54] remove caps reload every 10 seconds --- src/vecdb/vecdb.rs | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 6c255a480..c18c90cb7 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -75,15 +75,15 @@ pub async fn vecdb_background_reload( (&cache_dir.clone(), &gcx_locked.cmdline.clone()) }; - let caps_mb = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 10).await; + let caps_mb = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0).await; if caps_mb.is_err() || !cmdline.vecdb { continue; } let ( - default_embeddings_model, - endpoint_embeddings_template, + default_embeddings_model, + endpoint_embeddings_template, endpoint_embeddings_style, size_embeddings, ) = { @@ -124,7 +124,7 @@ pub async fn vecdb_background_reload( endpoint_embeddings_template, endpoint_embeddings_style, size_embeddings, - + cmdline, cache_dir ).await; @@ -147,13 +147,6 @@ pub async fn vecdb_background_reload( { let mut gcx_locked = global_context.write().await; - - if let Some(caps) = &mut gcx_locked.caps { - caps.write().unwrap().chat_rag_functions = vec![ - "@workspace".to_string() - ] - }; - gcx_locked.vec_db = Arc::new(AMutex::new(Some(vecdb))); info!("vecdb is launched successfully"); From 95a81f236740332f26750a404c2352d5d43ff38a Mon Sep 17 00:00:00 2001 From: Valeryi Date: Wed, 17 Jan 2024 12:41:54 +0000 Subject: [PATCH 07/54] fixed some warnings --- src/background_tasks.rs | 2 +- src/fetch_embedding.rs | 2 -- src/lsp.rs | 2 +- src/vecdb/handler.rs | 10 +++++----- src/vecdb/vecdb.rs | 2 +- src/vecdb/vectorizer_service.rs | 4 ++-- 6 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/background_tasks.rs b/src/background_tasks.rs index ae0e3f045..7558e41e8 100644 --- a/src/background_tasks.rs +++ b/src/background_tasks.rs @@ -5,7 +5,7 @@ use std::vec; use tokio::sync::RwLock as ARwLock; use tokio::task::JoinHandle; -use crate::{global_context, vecdb}; +use crate::vecdb; use crate::global_context::GlobalContext; use crate::snippets_transmit; use crate::telemetry::basic_transmit; diff --git a/src/fetch_embedding.rs b/src/fetch_embedding.rs index 163819c5e..b50c17a43 100644 --- a/src/fetch_embedding.rs +++ b/src/fetch_embedding.rs @@ -1,5 +1,3 @@ -use reqwest; -use serde::Serialize; use tracing::error; use crate::forward_to_hf_endpoint::get_embedding_hf_style; diff --git a/src/lsp.rs b/src/lsp.rs index 95bef38cd..d290d65ae 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -19,7 +19,7 @@ use crate::http::routers::v1::code_completion::handle_v1_code_completion; use crate::receive_workspace_changes; use crate::telemetry; use crate::telemetry::snippets_collection; -use crate::vecdb::file_filter::{is_valid_file, retrieve_files_by_proj_folders}; +use crate::vecdb::file_filter::is_valid_file; const VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/src/vecdb/handler.rs b/src/vecdb/handler.rs index f44d02455..9a398e37c 100644 --- a/src/vecdb/handler.rs +++ b/src/vecdb/handler.rs @@ -34,7 +34,7 @@ impl Debug for VecDBHandler { pub struct VecDBHandler { cache_database: Database, - data_database_temp_dir: TempDir, + _data_database_temp_dir: TempDir, cache_table: Table, data_table: Table, schema: SchemaRef, @@ -119,7 +119,7 @@ impl VecDBHandler { Ok(VecDBHandler { cache_database, - data_database_temp_dir, + _data_database_temp_dir: data_database_temp_dir, schema, cache_table, data_table, @@ -170,7 +170,7 @@ impl VecDBHandler { (records, hashes.into_iter().collect()) } - pub async fn get_records_from_data(&mut self, hashes: Vec) -> (Vec, Vec) { + pub async fn _get_records_from_data(&mut self, hashes: Vec) -> (Vec, Vec) { self.get_records(self.data_table.clone(), hashes).await } pub async fn get_records_from_cache(&mut self, hashes: Vec) -> (Vec, Vec) { @@ -192,10 +192,10 @@ impl VecDBHandler { } } - pub async fn get_record_from_data(&mut self, hash: String) -> vectordb::error::Result { + pub async fn _get_record_from_data(&mut self, hash: String) -> vectordb::error::Result { self.get_record(self.data_table.clone(), hash).await } - pub async fn get_record_from_cache(&mut self, hash: String) -> vectordb::error::Result { + pub async fn _get_record_from_cache(&mut self, hash: String) -> vectordb::error::Result { self.get_record(self.cache_table.clone(), hash).await } diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index c18c90cb7..320c926d4 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -67,7 +67,7 @@ pub async fn vecdb_background_reload( ) { let mut background_tasks = BackgroundTasksHolder::new(vec![]); loop { - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; let (cache_dir, cmdline) = { let gcx_locked = global_context.read().await; diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index b5a4419ed..34d5b2277 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -31,12 +31,12 @@ pub struct FileVectorizerService { async fn cooldown_queue_thread( update_request_queue: Arc>>, out_queue: Arc>>, - status: VecDbStatusRef, + _status: VecDbStatusRef, cooldown_secs: u64, ) { let mut last_updated: HashMap = HashMap::new(); loop { - let (path_maybe, unprocessed_files_count) = { + let (path_maybe, _unprocessed_files_count) = { let mut queue_locked = update_request_queue.lock().await; if !queue_locked.is_empty() { (Some(queue_locked.pop_front().unwrap()), queue_locked.len()) From e5a6afae7e04dfb1d66736d6f3376dabb055be36 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Wed, 17 Jan 2024 14:35:52 +0000 Subject: [PATCH 08/54] added /v1/vdb-caps --- src/caps.rs | 2 -- src/http/routers/v1.rs | 3 ++- src/http/routers/v1/vecdb.rs | 22 ++++++++++++++++++++++ src/vecdb/vecdb.rs | 17 +++++++++++++++-- 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/caps.rs b/src/caps.rs index f3061a6e4..c950019a4 100644 --- a/src/caps.rs +++ b/src/caps.rs @@ -55,8 +55,6 @@ pub struct CodeAssistantCaps { pub running_models: Vec, #[serde(default)] pub caps_version: i64, // need to reload if it increases on server, that happens when server configuration changes - #[serde(default)] - pub chat_rag_functions: Vec, } #[derive(Debug, Deserialize)] diff --git a/src/http/routers/v1.rs b/src/http/routers/v1.rs index 4d0f1a88d..afa547481 100644 --- a/src/http/routers/v1.rs +++ b/src/http/routers/v1.rs @@ -20,7 +20,7 @@ use crate::http::routers::v1::telemetry_network::handle_v1_telemetry_network; use crate::http::routers::v1::lsp_like_handlers::handle_v1_lsp_initialize; use crate::http::routers::v1::lsp_like_handlers::handle_v1_lsp_did_change; use crate::http::utils::telemetry_wrapper; -use crate::http::routers::v1::vecdb::{handle_v1_vecdb_search, handle_v1_vecdb_status}; +use crate::http::routers::v1::vecdb::{handle_v1_vecdb_search, handle_v1_vecdb_status, handle_v1_vecdb_caps}; pub mod code_completion; pub mod chat; @@ -43,6 +43,7 @@ pub fn make_v1_router() -> Router { .route("/vdb-search", telemetry_get!(handle_v1_vecdb_search)) .route("/vdb-status", telemetry_get!(handle_v1_vecdb_status)) + .route("/vdb-caps", telemetry_get!(handle_v1_vecdb_caps)) .route("/lsp-initialize", telemetry_post!(handle_v1_lsp_initialize)) .route("/lsp-did-changed", telemetry_post!(handle_v1_lsp_did_change)) diff --git a/src/http/routers/v1/vecdb.rs b/src/http/routers/v1/vecdb.rs index 6f38c9cc0..136c20efe 100644 --- a/src/http/routers/v1/vecdb.rs +++ b/src/http/routers/v1/vecdb.rs @@ -64,3 +64,25 @@ pub async fn handle_v1_vecdb_status( .body(Body::from(json!(status).to_string())) .unwrap()) } + +pub async fn handle_v1_vecdb_caps( + Extension(global_context): Extension, + _: hyper::body::Bytes, +) -> Result, ScratchError> { + let caps = { + let cx_locked = global_context.read().await; + let db = cx_locked.vec_db.lock().await; + if let Some(ref db) = *db { + db.caps().await + } else { + return Err(ScratchError::new( + StatusCode::INTERNAL_SERVER_ERROR, + "Vector db is not available".to_string() + )); + } + }; + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(json!(caps).to_string())) + .unwrap()) +} \ No newline at end of file diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 320c926d4..f85b9dff7 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use std::sync::Arc; use async_trait::async_trait; +use serde::Serialize; use tracing::info; use tokio::sync::Mutex as AMutex; use tokio::task::JoinHandle; @@ -31,6 +32,11 @@ pub struct VecDb { } +#[derive(Debug, Serialize)] +pub struct VecDbCaps { + chat_rag_functions: Vec, +} + pub async fn create_vecdb( default_embeddings_model: String, endpoint_embeddings_template: String, @@ -98,7 +104,7 @@ pub async fn vecdb_background_reload( }; if default_embeddings_model.is_empty() || endpoint_embeddings_template.is_empty() { - error!("vecd launch failed: default_embeddings_model.is_empty() || endpoint_embeddings_template.is_empty()"); + error!("vecdb launch failed: default_embeddings_model.is_empty() || endpoint_embeddings_template.is_empty()"); continue; } @@ -137,7 +143,7 @@ pub async fn vecdb_background_reload( let search_result = vecdb.search("".to_string(), 3).await; match search_result { Ok(_) => { - info!("vecdb: test search complete") + info!("vecdb: test search complete"); } Err(_) => { error!("vecdb: test search failed"); @@ -147,6 +153,7 @@ pub async fn vecdb_background_reload( { let mut gcx_locked = global_context.write().await; + gcx_locked.vec_db = Arc::new(AMutex::new(Some(vecdb))); info!("vecdb is launched successfully"); @@ -241,6 +248,12 @@ impl VecDb { self.add_or_update_files(files, true).await; info!("vecdb: init_folders complete"); } + + pub async fn caps(&self) -> VecDbCaps { + VecDbCaps { + chat_rag_functions: vec!["@workspace".to_string()], + } + } } From 60579f38eec362ccdc5b45da4b477ceb1d8579f0 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Wed, 17 Jan 2024 16:57:41 +0000 Subject: [PATCH 09/54] chat_rag_functions -> functions --- src/vecdb/vecdb.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index f85b9dff7..e23dbf6d4 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -34,7 +34,7 @@ pub struct VecDb { #[derive(Debug, Serialize)] pub struct VecDbCaps { - chat_rag_functions: Vec, + functions: Vec, } pub async fn create_vecdb( @@ -251,7 +251,7 @@ impl VecDb { pub async fn caps(&self) -> VecDbCaps { VecDbCaps { - chat_rag_functions: vec!["@workspace".to_string()], + functions: vec!["@workspace".to_string()], } } } From c076c4fc311ef6846445b26918ff3045f42287d3 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Wed, 17 Jan 2024 19:35:39 +0000 Subject: [PATCH 10/54] simplified vecdb_background_reload --- src/vecdb/vecdb.rs | 213 +++++++++++++++++++++++---------------------- 1 file changed, 109 insertions(+), 104 deletions(-) diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index e23dbf6d4..87766e03a 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -37,21 +37,46 @@ pub struct VecDbCaps { functions: Vec, } -pub async fn create_vecdb( +struct VecDbParams { default_embeddings_model: String, endpoint_embeddings_template: String, endpoint_embeddings_style: String, size_embeddings: i32, +} + +async fn vecdb_test_request( + vecdb: &VecDb +) -> Result<(), String> { + let search_result = vecdb.search("".to_string(), 3).await; + match search_result { + Ok(_) => { + Ok(()) + } + Err(_) => { + error!("vecdb: test search failed"); + Err("vecdb: test search failed".to_string()) + } + } + +} - cmdline: &CommandLine, - cache_dir: &PathBuf, -) -> Option { - let vec_db = match VecDb::init( +async fn create_vecdb( + global_context: Arc>, + background_tasks: &mut BackgroundTasksHolder, + vdb_params: VecDbParams, +) -> Result<(), String> { + info!("vecdb: attempting to launch"); + + let (cache_dir, cmdline) = { + let gcx_locked = global_context.read().await; + (gcx_locked.cache_dir.clone(), gcx_locked.cmdline.clone()) + }; + let vec_db_mb = match VecDb::init( &cache_dir, cmdline.clone(), - size_embeddings, 60, 512, 1024, - default_embeddings_model.clone(), - endpoint_embeddings_template.clone(), - endpoint_embeddings_style.clone(), + vdb_params.size_embeddings, 60, 512, 1024, + vdb_params.default_embeddings_model, + vdb_params.endpoint_embeddings_template, + vdb_params.endpoint_embeddings_style, ).await { Ok(res) => Some(res), Err(err) => { @@ -62,120 +87,100 @@ pub async fn create_vecdb( Also, you can run this to erase your db: `rm -rf ~/.cache/refact/refact_vecdb_cache` After that restart this LSP server or your IDE.", err); - None + return Err(err); } }; - vec_db -} - -pub async fn vecdb_background_reload( - global_context: Arc>, -) { - let mut background_tasks = BackgroundTasksHolder::new(vec![]); - loop { - tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; + let vec_db = vec_db_mb.unwrap(); - let (cache_dir, cmdline) = { - let gcx_locked = global_context.read().await; - let cache_dir = gcx_locked.cache_dir.clone(); - (&cache_dir.clone(), &gcx_locked.cmdline.clone()) - }; + match vecdb_test_request(&vec_db).await { + Ok(_) => {}, + Err(s) => {return Err(s);} + } + info!("vecdb: test request complete"); - let caps_mb = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0).await; + { + let mut gcx_locked = global_context.write().await; - if caps_mb.is_err() || !cmdline.vecdb { - continue; + if let Some(folders) = gcx_locked.lsp_backend_document_state.workspace_folders.clone().read().await.clone() { + let mut vec_db_lock = gcx_locked.vec_db.lock().await; + if let Some(ref mut db) = *vec_db_lock { + db.init_folders(folders).await; + } } + let mut tasks = vec_db.start_background_tasks().await; + tasks.extend(vec![tokio::spawn(vecdb::file_watcher_service::file_watcher_task(global_context.clone()))]); + background_tasks.extend(tasks); - let ( - default_embeddings_model, - endpoint_embeddings_template, - endpoint_embeddings_style, - size_embeddings, - ) = { - let caps = caps_mb.unwrap(); - let caps_locked = caps.read().unwrap(); - ( - caps_locked.default_embeddings_model.clone(), - caps_locked.endpoint_embeddings_template.clone(), - caps_locked.endpoint_embeddings_style.clone(), - caps_locked.size_embeddings.clone(), - ) - }; + gcx_locked.vec_db = Arc::new(AMutex::new(Some(vec_db))); - if default_embeddings_model.is_empty() || endpoint_embeddings_template.is_empty() { - error!("vecdb launch failed: default_embeddings_model.is_empty() || endpoint_embeddings_template.is_empty()"); - continue; - } + } + info!("vecdb: launch complete"); + Ok(()) +} +async fn proceed_vecdb_reload( + global_context: Arc>, +) -> (bool, Option) { + let caps = match crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0).await { + Ok(caps) => caps, + Err(_) => { return (false, None) } + }; - match *global_context.write().await.vec_db.lock().await { - None => {} - Some(ref db) => { - if db.model_name == default_embeddings_model && - db.endpoint_template == endpoint_embeddings_template && - db.endpoint_embeddings_style == endpoint_embeddings_style { - continue; - } - } + let vdb_params = { + let caps_locked = caps.read().unwrap(); + VecDbParams { + default_embeddings_model: caps_locked.default_embeddings_model.clone(), + endpoint_embeddings_template: caps_locked.endpoint_embeddings_template.clone(), + endpoint_embeddings_style: caps_locked.endpoint_embeddings_style.clone(), + size_embeddings: caps_locked.size_embeddings.clone(), } + }; - info!("vecdb: attempting to launch"); - - background_tasks.abort().await; - background_tasks = BackgroundTasksHolder::new(vec![]); - - let vecdb_mb = create_vecdb( - default_embeddings_model.clone(), - endpoint_embeddings_template, - endpoint_embeddings_style, - size_embeddings, - - cmdline, - cache_dir - ).await; + if vdb_params.default_embeddings_model.is_empty() || vdb_params.endpoint_embeddings_template.is_empty() { + error!("vecdb launch failed: default_embeddings_model.is_empty() || endpoint_embeddings_template.is_empty()"); + return (false, None); + } - if vecdb_mb.is_none() { - continue; - } - let vecdb = vecdb_mb.unwrap(); - let search_result = vecdb.search("".to_string(), 3).await; - match search_result { - Ok(_) => { - info!("vecdb: test search complete"); - } - Err(_) => { - error!("vecdb: test search failed"); - continue; + match *global_context.write().await.vec_db.lock().await { + None => {} + Some(ref db) => { + if db.model_name == vdb_params.default_embeddings_model && + db.endpoint_template == vdb_params.endpoint_embeddings_template && + db.endpoint_embeddings_style == vdb_params.endpoint_embeddings_style { + return (false, None); } } + } + + return (true, Some(vdb_params)); +} + - { - let mut gcx_locked = global_context.write().await; - - gcx_locked.vec_db = Arc::new(AMutex::new(Some(vecdb))); - info!("vecdb is launched successfully"); - - background_tasks.extend(match *gcx_locked.vec_db.lock().await { - Some(ref db) => { - let mut tasks = db.start_background_tasks().await; - tasks.push( - tokio::spawn(vecdb::file_watcher_service::file_watcher_task(global_context.clone())) - ); - tasks - } - None => vec![] - }); - { - if let Some(folders) = gcx_locked.lsp_backend_document_state.workspace_folders.clone().read().await.clone() { - let mut vec_db_lock = gcx_locked.vec_db.lock().await; - if let Some(ref mut db) = *vec_db_lock { - db.init_folders(folders).await; - } - } +pub async fn vecdb_background_reload( + global_context: Arc>, +) { + let cmd_line = global_context.read().await.cmdline.clone(); + if !cmd_line.vecdb { + return; + } + let mut background_tasks = BackgroundTasksHolder::new(vec![]); + loop { + let (proceed, vdb_params_mb) = proceed_vecdb_reload(global_context.clone()).await; + if proceed || vdb_params_mb.is_some() { + background_tasks.abort().await; + background_tasks = BackgroundTasksHolder::new(vec![]); + + match create_vecdb( + global_context.clone(), + &mut background_tasks, + vdb_params_mb.unwrap(), + ).await{ + Ok(_) => {} + Err(err) => {error!("vecdb: reload failed: {}", err);} } } + tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; } } From 8fbf578c01c98a515bf27c12783ef56f4c4c0a8e Mon Sep 17 00:00:00 2001 From: Valeryi Date: Thu, 18 Jan 2024 09:01:18 +0000 Subject: [PATCH 11/54] minor --- src/vecdb/vecdb.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 87766e03a..1c111b092 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -167,7 +167,7 @@ pub async fn vecdb_background_reload( let mut background_tasks = BackgroundTasksHolder::new(vec![]); loop { let (proceed, vdb_params_mb) = proceed_vecdb_reload(global_context.clone()).await; - if proceed || vdb_params_mb.is_some() { + if proceed && vdb_params_mb.is_some() { background_tasks.abort().await; background_tasks = BackgroundTasksHolder::new(vec![]); From 580d84783f435d789b5b726b8865bd9cb8505098 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 18 Jan 2024 11:25:35 +0100 Subject: [PATCH 12/54] rag logs --- src/vecdb/vectorizer_service.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index 34d5b2277..819ee3b34 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -100,7 +100,6 @@ async fn vectorize_thread( } } }; - info!("Processing file: {}", path.display()); let split_data = match file_splitter.split(&path).await { Ok(data) => data, @@ -115,8 +114,11 @@ async fn vectorize_thread( .collect(); split_data_filtered = vecdb_handler.try_add_from_cache(split_data_filtered).await; drop(vecdb_handler); - info!("Retrieving embeddings for {} chunks", split_data_filtered.len()); + let last_30_chars: String = path.display().to_string().chars().rev().take(30).collect::().chars().rev().collect(); + info!("...{} embeddings todo/total {}/{}", last_30_chars, split_data_filtered.len(), split_data.len()); + + // TODO: replace with a batched call? let join_handles: Vec<_> = split_data_filtered.into_iter().map(|x| { let model_name_clone = model_name.clone(); let api_key_clone = api_key.clone(); From 1e1ec729011a7d3523d53f8a3bc0f381651b2ae3 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 18 Jan 2024 11:37:58 +0100 Subject: [PATCH 13/54] rename proceed_vecdb_reload -> do_i_need_to_reload_vecdb --- src/vecdb/vecdb.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 1c111b092..96778a94c 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -118,7 +118,7 @@ async fn create_vecdb( Ok(()) } -async fn proceed_vecdb_reload( +async fn do_i_need_to_reload_vecdb( global_context: Arc>, ) -> (bool, Option) { let caps = match crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0).await { @@ -166,21 +166,23 @@ pub async fn vecdb_background_reload( } let mut background_tasks = BackgroundTasksHolder::new(vec![]); loop { - let (proceed, vdb_params_mb) = proceed_vecdb_reload(global_context.clone()).await; - if proceed && vdb_params_mb.is_some() { + let (need_reload, vdb_params_mb) = do_i_need_to_reload_vecdb(global_context.clone()).await; + if need_reload && vdb_params_mb.is_some() { background_tasks.abort().await; background_tasks = BackgroundTasksHolder::new(vec![]); - match create_vecdb( global_context.clone(), &mut background_tasks, vdb_params_mb.unwrap(), - ).await{ + ).await { Ok(_) => {} - Err(err) => {error!("vecdb: reload failed: {}", err);} + Err(err) => { + error!("vecdb: init failed: {}", err); + // global_context.vec_db stays None, the rest of the system continues working + } } } - tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; + tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; } } From 56db6f88e02c65a83fddbfc9717fefe44f0c7d6c Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 18 Jan 2024 12:05:46 +0100 Subject: [PATCH 14/54] improve /v1/vdb-search --- src/http/routers/v1.rs | 2 +- src/http/routers/v1/vecdb.rs | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/http/routers/v1.rs b/src/http/routers/v1.rs index afa547481..1d74cb60a 100644 --- a/src/http/routers/v1.rs +++ b/src/http/routers/v1.rs @@ -41,7 +41,7 @@ pub fn make_v1_router() -> Router { .route("/caps", telemetry_get!(handle_v1_caps)) .route("/graceful-shutdown", telemetry_get!(handle_v1_graceful_shutdown)) - .route("/vdb-search", telemetry_get!(handle_v1_vecdb_search)) + .route("/vdb-search", telemetry_post!(handle_v1_vecdb_search)) .route("/vdb-status", telemetry_get!(handle_v1_vecdb_status)) .route("/vdb-caps", telemetry_get!(handle_v1_vecdb_caps)) diff --git a/src/http/routers/v1/vecdb.rs b/src/http/routers/v1/vecdb.rs index 136c20efe..f17cc9905 100644 --- a/src/http/routers/v1/vecdb.rs +++ b/src/http/routers/v1/vecdb.rs @@ -7,6 +7,7 @@ use serde_json::json; use crate::custom_error::ScratchError; use crate::global_context::SharedGlobalContext; use crate::vecdb::structs::VecdbSearch; +use tracing::info; #[derive(Serialize, Deserialize, Clone)] struct VecDBPost { @@ -34,9 +35,12 @@ pub async fn handle_v1_vecdb_search( match search_res { Ok(search_res) => { + let json_string = serde_json::to_string_pretty(&search_res).map_err(|e| { + ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("JSON serialization problem: {}", e)) + })?; Ok(Response::builder() .status(StatusCode::OK) - .body(Body::from(json!(search_res).to_string())) + .body(Body::from(json_string)) .unwrap()) } Err(e) => { From b2028bfaaad6309e177e467130562374f3fe45f9 Mon Sep 17 00:00:00 2001 From: JegernOUTT Date: Thu, 18 Jan 2024 22:53:41 +1030 Subject: [PATCH 15/54] add vecdb_forced_path fix bug with `add_from_cache` --- src/global_context.rs | 2 ++ src/vecdb/handler.rs | 4 ++++ src/vecdb/vecdb.rs | 6 +++++- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/global_context.rs b/src/global_context.rs index 8598e1de2..2e241f07c 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -51,6 +51,8 @@ pub struct CommandLine { pub vecdb: bool, #[structopt(long, short = "f", default_value = "", help = "The path to jsonl file which contains filtered source files")] pub files_set_path: String, + #[structopt(long, default_value = "", help = "Vecdb forced path")] + pub vecdb_forced_path: String, } impl CommandLine { diff --git a/src/vecdb/handler.rs b/src/vecdb/handler.rs index 9a398e37c..2e77f692e 100644 --- a/src/vecdb/handler.rs +++ b/src/vecdb/handler.rs @@ -200,6 +200,10 @@ impl VecDBHandler { } pub async fn try_add_from_cache(&mut self, data: Vec) -> Vec { + if data.is_empty() { + return vec![]; + } + let hashes = data.iter().map(|x| x.window_text_hash.clone()).collect(); let (found_records, left_hashes) = self.get_records_from_cache(hashes).await; let left_results: Vec = diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 96778a94c..b646d654c 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -71,8 +71,12 @@ async fn create_vecdb( let gcx_locked = global_context.read().await; (gcx_locked.cache_dir.clone(), gcx_locked.cmdline.clone()) }; + let base_dir: PathBuf = match cmdline.vecdb_forced_path.as_str() { + "" => cache_dir, + path => PathBuf::from(path), + }; let vec_db_mb = match VecDb::init( - &cache_dir, cmdline.clone(), + &base_dir, cmdline.clone(), vdb_params.size_embeddings, 60, 512, 1024, vdb_params.default_embeddings_model, vdb_params.endpoint_embeddings_template, From 323d13efe4e8eea1bcc03bfe0f8ba5d8aec525fb Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sat, 20 Jan 2024 16:24:11 +0100 Subject: [PATCH 16/54] get rid of VecDBHandlerRef, don't eat caps error in vecdb --- src/global_context.rs | 3 ++- src/main.rs | 4 +--- src/vecdb/handler.rs | 3 +-- src/vecdb/vecdb.rs | 15 ++++++++++----- src/vecdb/vectorizer_service.rs | 29 +++++++++++++++-------------- 5 files changed, 29 insertions(+), 25 deletions(-) diff --git a/src/global_context.rs b/src/global_context.rs index 2e241f07c..cac582779 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -1,5 +1,5 @@ use std::collections::hash_map::DefaultHasher; -use tracing::info; +use tracing::{error, info}; use std::collections::HashMap; use std::future::Future; use std::hash::Hasher; @@ -130,6 +130,7 @@ pub async fn try_load_caps_quickly_if_not_present( Ok(caps) }, Err(e) => { + error!("load caps failed: \"{}\"", e); return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("server is not reachable: {}", e))); } } diff --git a/src/main.rs b/src/main.rs index b86a3d58e..cbf9bf78d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,8 @@ use std::io::Write; use tokio::task::JoinHandle; -use tracing::{error, info, Level}; +use tracing::info; use tracing_appender; -use std::sync::Arc; -use tokio::sync::Mutex as AMutex; use crate::background_tasks::start_background_tasks; use crate::lsp::spawn_lsp_task; diff --git a/src/vecdb/handler.rs b/src/vecdb/handler.rs index 2e77f692e..109e31827 100644 --- a/src/vecdb/handler.rs +++ b/src/vecdb/handler.rs @@ -17,14 +17,13 @@ use futures_util::TryStreamExt; use lance::dataset::{WriteMode, WriteParams}; use log::info; use tempfile::{tempdir, TempDir}; -use tokio::sync::Mutex; +use tokio::sync::Mutex as AMutex; use tracing::error; use vectordb::database::Database; use vectordb::table::Table; use crate::vecdb::structs::{Record, SplitResult}; -pub type VecDBHandlerRef = Arc>; impl Debug for VecDBHandler { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index b646d654c..303387a0a 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -14,15 +14,14 @@ use crate::background_tasks::BackgroundTasksHolder; use crate::fetch_embedding; use crate::vecdb; -use crate::vecdb::{file_filter}; -use crate::vecdb::handler::{VecDBHandler, VecDBHandlerRef}; +use crate::vecdb::handler::VecDBHandler; use crate::vecdb::vectorizer_service::FileVectorizerService; use crate::vecdb::structs::{SearchResult, VecdbSearch, VecDbStatus}; #[derive(Debug)] pub struct VecDb { - vecdb_handler: VecDBHandlerRef, + vecdb_handler: Arc>, retriever_service: Arc>, cmdline: CommandLine, @@ -37,6 +36,7 @@ pub struct VecDbCaps { functions: Vec, } +#[derive(Debug)] struct VecDbParams { default_embeddings_model: String, endpoint_embeddings_template: String, @@ -127,7 +127,11 @@ async fn do_i_need_to_reload_vecdb( ) -> (bool, Option) { let caps = match crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0).await { Ok(caps) => caps, - Err(_) => { return (false, None) } + Err(e) => { + // This branch makes caps error disappear, unless we print it right here: + info!("vecdb: no caps, will not start or reload vecdb, the error was: {}", e); + return (false, None) + } }; let vdb_params = { @@ -151,7 +155,8 @@ async fn do_i_need_to_reload_vecdb( Some(ref db) => { if db.model_name == vdb_params.default_embeddings_model && db.endpoint_template == vdb_params.endpoint_embeddings_template && - db.endpoint_embeddings_style == vdb_params.endpoint_embeddings_style { + db.endpoint_embeddings_style == vdb_params.endpoint_embeddings_style + { return (false, None); } } diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index 819ee3b34..d1565df89 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -3,20 +3,21 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::SystemTime; -use tokio::sync::{Mutex, Semaphore}; +use tokio::sync::Mutex as AMutex; +use tokio::sync::Semaphore; use tokio::task::JoinHandle; use tracing::info; use crate::vecdb::file_splitter::FileSplitter; -use crate::vecdb::handler::VecDBHandlerRef; +use crate::vecdb::handler::VecDBHandler; use crate::fetch_embedding::try_get_embedding; use crate::vecdb::structs::{Record, SplitResult, VecDbStatus, VecDbStatusRef}; #[derive(Debug)] pub struct FileVectorizerService { - update_request_queue: Arc>>, - output_queue: Arc>>, - vecdb_handler: VecDBHandlerRef, + update_request_queue: Arc>>, + output_queue: Arc>>, + vecdb_handler: Arc>, status: VecDbStatusRef, cooldown_secs: u64, splitter_window_size: usize, @@ -29,8 +30,8 @@ pub struct FileVectorizerService { } async fn cooldown_queue_thread( - update_request_queue: Arc>>, - out_queue: Arc>>, + update_request_queue: Arc>>, + out_queue: Arc>>, _status: VecDbStatusRef, cooldown_secs: u64, ) { @@ -65,8 +66,8 @@ async fn cooldown_queue_thread( async fn vectorize_thread( - queue: Arc>>, - vecdb_handler_ref: VecDBHandlerRef, + queue: Arc>>, + vecdb_handler_ref: Arc>, status: VecDbStatusRef, splitter_window_size: usize, splitter_soft_limit: usize, @@ -187,7 +188,7 @@ async fn vectorize_thread( } } -async fn cleanup_thread(vecdb_handler: VecDBHandlerRef) { +async fn cleanup_thread(vecdb_handler: Arc>) { loop { { let mut vecdb = vecdb_handler.lock().await; @@ -199,7 +200,7 @@ async fn cleanup_thread(vecdb_handler: VecDBHandlerRef) { impl FileVectorizerService { pub async fn new( - vecdb_handler: VecDBHandlerRef, + vecdb_handler: Arc>, cooldown_secs: u64, splitter_window_size: usize, splitter_soft_limit: usize, @@ -209,9 +210,9 @@ impl FileVectorizerService { endpoint_embeddings_style: String, endpoint_template: String, ) -> Self { - let update_request_queue = Arc::new(Mutex::new(VecDeque::new())); - let output_queue = Arc::new(Mutex::new(VecDeque::new())); - let status = Arc::new(Mutex::new( + let update_request_queue = Arc::new(AMutex::new(VecDeque::new())); + let output_queue = Arc::new(AMutex::new(VecDeque::new())); + let status = Arc::new(AMutex::new( VecDbStatus { unprocessed_files_count: 0, requests_made_since_start: 0, From bb277f996b8a841155eac68e9e8d2864868bbb8d Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 08:06:06 +0100 Subject: [PATCH 17/54] oops --- src/vecdb/vecdb.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 303387a0a..0bc871439 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -14,6 +14,7 @@ use crate::background_tasks::BackgroundTasksHolder; use crate::fetch_embedding; use crate::vecdb; +use crate::vecdb::file_filter; use crate::vecdb::handler::VecDBHandler; use crate::vecdb::vectorizer_service::FileVectorizerService; use crate::vecdb::structs::{SearchResult, VecdbSearch, VecDbStatus}; From 4c3b2d9e8f9a64b744817a3bee7b150c23303621 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 08:29:14 +0100 Subject: [PATCH 18/54] warnings --- src/http/routers/v1/code_completion.rs | 4 +--- src/http/routers/v1/vecdb.rs | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/http/routers/v1/code_completion.rs b/src/http/routers/v1/code_completion.rs index 06615da05..e2c4de0dc 100644 --- a/src/http/routers/v1/code_completion.rs +++ b/src/http/routers/v1/code_completion.rs @@ -4,10 +4,8 @@ use tokio::sync::RwLock as ARwLock; use axum::Extension; use axum::response::Result; -use futures_util::future::ok; use hyper::{Body, Response, StatusCode}; -use ropey::Rope; -use tracing::{error, info}; +use tracing::info; use crate::call_validation::{CodeCompletionPost, validate_post}; use crate::caps; diff --git a/src/http/routers/v1/vecdb.rs b/src/http/routers/v1/vecdb.rs index f17cc9905..9a1193030 100644 --- a/src/http/routers/v1/vecdb.rs +++ b/src/http/routers/v1/vecdb.rs @@ -7,7 +7,6 @@ use serde_json::json; use crate::custom_error::ScratchError; use crate::global_context::SharedGlobalContext; use crate::vecdb::structs::VecdbSearch; -use tracing::info; #[derive(Serialize, Deserialize, Clone)] struct VecDBPost { From 4a5f0bd93f328c3696a4d3086d82f0c07180f9f5 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 08:29:34 +0100 Subject: [PATCH 19/54] speed explanation --- src/vecdb/handler.rs | 12 ++++++++---- src/vecdb/vecdb.rs | 13 ++++++++++--- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/vecdb/handler.rs b/src/vecdb/handler.rs index 109e31827..d16b017d8 100644 --- a/src/vecdb/handler.rs +++ b/src/vecdb/handler.rs @@ -17,7 +17,6 @@ use futures_util::TryStreamExt; use lance::dataset::{WriteMode, WriteParams}; use log::info; use tempfile::{tempdir, TempDir}; -use tokio::sync::Mutex as AMutex; use tracing::error; use vectordb::database::Database; use vectordb::table::Table; @@ -425,7 +424,8 @@ impl VecDBHandler { }).collect() } - pub async fn search(&mut self, embedding: Vec, top_n: usize) -> vectordb::error::Result> { + pub async fn search(&mut self, embedding: Vec, top_n: usize) -> vectordb::error::Result> + { let query = self.data_table.clone() .search(Some(Float32Array::from(embedding.clone()))) .limit(top_n) @@ -438,7 +438,9 @@ impl VecDBHandler { VecDBHandler::parse_table_iter(record_batch, false, Some(&embedding)) } - pub async fn update_record_statistic(&mut self, records: Vec) { + pub async fn update_record_statistic(&mut self, records: Vec) + { + // TODO: very slow, 0.8s for db_size=1368, maybe make one update call? let now = SystemTime::now(); for record in records { for mut table in vec![self.data_table.clone(), self.cache_table.clone()] { @@ -451,7 +453,9 @@ impl VecDBHandler { self.checkout().await; } } - pub async fn cleanup_old_records(&mut self) -> Result<(), String> { + + pub async fn cleanup_old_records(&mut self) -> Result<(), String> + { let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap(); let q = format!("{} - time_last_used > {TWO_WEEKS} AND used_counter < {MIN_LIKES}", now.as_secs()); self.cache_table.delete(&*q).await.expect("could not delete old records"); diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 0bc871439..89f3a2c6f 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -277,6 +277,7 @@ impl VecDb { #[async_trait] impl VecdbSearch for VecDb { async fn search(&self, query: String, top_n: usize) -> Result { + let t0 = std::time::Instant::now(); let embedding_mb = fetch_embedding::try_get_embedding( &self.endpoint_embeddings_style, &self.model_name, @@ -288,10 +289,16 @@ impl VecdbSearch for VecDb { if embedding_mb.is_err() { return Err("Failed to get embedding".to_string()); } - let mut binding = self.vecdb_handler.lock().await; + info!("search query {:?}, it took {:.3}s to vectorize the query", query, t0.elapsed().as_secs_f64()); - let results = binding.search(embedding_mb.unwrap(), top_n).await.unwrap(); - binding.update_record_statistic(results.clone()).await; + let mut handler_locked = self.vecdb_handler.lock().await; + let t1 = std::time::Instant::now(); + let results = handler_locked.search(embedding_mb.unwrap(), top_n).await.unwrap(); + info!("search itself {:.3}s", t1.elapsed().as_secs_f64()); + + let t2 = std::time::Instant::now(); + handler_locked.update_record_statistic(results.clone()).await; + info!("update_record_statistic {:.3}s", t2.elapsed().as_secs_f64()); Ok( SearchResult { query_text: query, From bc568665b09adcad483b4ecc270030e7b9bc65b7 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 10:34:03 +0100 Subject: [PATCH 20/54] VecdbConstants, logs --- src/http/routers/v1/vecdb.rs | 3 +- src/vecdb/structs.rs | 11 ++++ src/vecdb/vecdb.rs | 113 ++++++++++++-------------------- src/vecdb/vectorizer_service.rs | 86 +++++++++++------------- 4 files changed, 94 insertions(+), 119 deletions(-) diff --git a/src/http/routers/v1/vecdb.rs b/src/http/routers/v1/vecdb.rs index 9a1193030..d6f669d6f 100644 --- a/src/http/routers/v1/vecdb.rs +++ b/src/http/routers/v1/vecdb.rs @@ -61,10 +61,9 @@ pub async fn handle_v1_vecdb_status( )); } }; - Ok(Response::builder() .status(StatusCode::OK) - .body(Body::from(json!(status).to_string())) + .body(Body::from(serde_json::to_string_pretty(&status).unwrap())) .unwrap()) } diff --git a/src/vecdb/structs.rs b/src/vecdb/structs.rs index feaeed2f8..320db6922 100644 --- a/src/vecdb/structs.rs +++ b/src/vecdb/structs.rs @@ -16,6 +16,17 @@ pub trait VecdbSearch: Send { ) -> Result; } +#[derive(Debug, Clone)] +pub struct VecdbConstants { + // constant in a sense it cannot be changed without creating a new db + pub model_name: String, + pub embedding_size: i32, + pub endpoint_embeddings_template: String, + pub endpoint_embeddings_style: String, + pub cooldown_secs: u64, + pub splitter_window_size: usize, + pub splitter_soft_limit: usize, +} #[derive(Debug, Serialize, Deserialize, Clone)] pub struct VecDbStatus { diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 89f3a2c6f..fdb0ec3cd 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -1,15 +1,15 @@ use std::path::PathBuf; use std::sync::Arc; +use std::sync::RwLock as StdRwLock; +use tokio::sync::RwLock as ARwLock; +use tokio::sync::Mutex as AMutex; +use tracing::{info, error}; use async_trait::async_trait; use serde::Serialize; -use tracing::info; -use tokio::sync::Mutex as AMutex; use tokio::task::JoinHandle; use crate::global_context::{CommandLine, GlobalContext}; -use tokio::sync::RwLock as ARwLock; use tower_lsp::lsp_types::WorkspaceFolder; -use tracing::error; use crate::background_tasks::BackgroundTasksHolder; use crate::fetch_embedding; @@ -17,18 +17,29 @@ use crate::vecdb; use crate::vecdb::file_filter; use crate::vecdb::handler::VecDBHandler; use crate::vecdb::vectorizer_service::FileVectorizerService; -use crate::vecdb::structs::{SearchResult, VecdbSearch, VecDbStatus}; - +use crate::vecdb::structs::{SearchResult, VecdbSearch, VecDbStatus, VecdbConstants}; + +fn vecdb_constants( + caps: Arc>, +) -> VecdbConstants { + let caps_locked = caps.read().unwrap(); + VecdbConstants { + model_name: caps_locked.default_embeddings_model.clone(), + embedding_size: caps_locked.size_embeddings.clone(), + endpoint_embeddings_template: caps_locked.endpoint_embeddings_template.clone(), + endpoint_embeddings_style: caps_locked.endpoint_embeddings_style.clone(), + cooldown_secs: 20, + splitter_window_size: 512, + splitter_soft_limit: 1024, + } +} #[derive(Debug)] pub struct VecDb { vecdb_handler: Arc>, retriever_service: Arc>, cmdline: CommandLine, - - model_name: String, - endpoint_template: String, - endpoint_embeddings_style: String, + constants: VecdbConstants, } @@ -37,18 +48,10 @@ pub struct VecDbCaps { functions: Vec, } -#[derive(Debug)] -struct VecDbParams { - default_embeddings_model: String, - endpoint_embeddings_template: String, - endpoint_embeddings_style: String, - size_embeddings: i32, -} - async fn vecdb_test_request( vecdb: &VecDb ) -> Result<(), String> { - let search_result = vecdb.search("".to_string(), 3).await; + let search_result = vecdb.search("test query".to_string(), 3).await; match search_result { Ok(_) => { Ok(()) @@ -64,7 +67,7 @@ async fn vecdb_test_request( async fn create_vecdb( global_context: Arc>, background_tasks: &mut BackgroundTasksHolder, - vdb_params: VecDbParams, + constants: VecdbConstants, ) -> Result<(), String> { info!("vecdb: attempting to launch"); @@ -77,11 +80,9 @@ async fn create_vecdb( path => PathBuf::from(path), }; let vec_db_mb = match VecDb::init( - &base_dir, cmdline.clone(), - vdb_params.size_embeddings, 60, 512, 1024, - vdb_params.default_embeddings_model, - vdb_params.endpoint_embeddings_template, - vdb_params.endpoint_embeddings_style, + &base_dir, + cmdline.clone(), + constants, ).await { Ok(res) => Some(res), Err(err) => { @@ -119,13 +120,12 @@ async fn create_vecdb( gcx_locked.vec_db = Arc::new(AMutex::new(Some(vec_db))); } - info!("vecdb: launch complete"); Ok(()) } async fn do_i_need_to_reload_vecdb( global_context: Arc>, -) -> (bool, Option) { +) -> (bool, Option) { let caps = match crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0).await { Ok(caps) => caps, Err(e) => { @@ -134,36 +134,26 @@ async fn do_i_need_to_reload_vecdb( return (false, None) } }; + let consts = vecdb_constants(caps); - let vdb_params = { - let caps_locked = caps.read().unwrap(); - VecDbParams { - default_embeddings_model: caps_locked.default_embeddings_model.clone(), - endpoint_embeddings_template: caps_locked.endpoint_embeddings_template.clone(), - endpoint_embeddings_style: caps_locked.endpoint_embeddings_style.clone(), - size_embeddings: caps_locked.size_embeddings.clone(), - } - }; - - if vdb_params.default_embeddings_model.is_empty() || vdb_params.endpoint_embeddings_template.is_empty() { + if consts.model_name.is_empty() || consts.endpoint_embeddings_template.is_empty() { error!("vecdb launch failed: default_embeddings_model.is_empty() || endpoint_embeddings_template.is_empty()"); return (false, None); } - match *global_context.write().await.vec_db.lock().await { None => {} Some(ref db) => { - if db.model_name == vdb_params.default_embeddings_model && - db.endpoint_template == vdb_params.endpoint_embeddings_template && - db.endpoint_embeddings_style == vdb_params.endpoint_embeddings_style + if db.constants.model_name == consts.model_name && + db.constants.endpoint_embeddings_template == consts.endpoint_embeddings_template && + db.constants.endpoint_embeddings_style == consts.endpoint_embeddings_style { return (false, None); } } } - return (true, Some(vdb_params)); + return (true, Some(consts)); } @@ -176,14 +166,14 @@ pub async fn vecdb_background_reload( } let mut background_tasks = BackgroundTasksHolder::new(vec![]); loop { - let (need_reload, vdb_params_mb) = do_i_need_to_reload_vecdb(global_context.clone()).await; - if need_reload && vdb_params_mb.is_some() { + let (need_reload, consts) = do_i_need_to_reload_vecdb(global_context.clone()).await; + if need_reload && consts.is_some() { background_tasks.abort().await; background_tasks = BackgroundTasksHolder::new(vec![]); match create_vecdb( global_context.clone(), &mut background_tasks, - vdb_params_mb.unwrap(), + consts.unwrap(), ).await { Ok(_) => {} Err(err) => { @@ -200,40 +190,23 @@ impl VecDb { pub async fn init( cache_dir: &PathBuf, cmdline: CommandLine, - embedding_size: i32, - cooldown_secs: u64, - splitter_window_size: usize, - splitter_soft_limit: usize, - - model_name: String, - endpoint_template: String, - endpoint_embeddings_style: String, + constants: VecdbConstants, ) -> Result { - let handler = match VecDBHandler::init(cache_dir, &model_name, embedding_size).await { + let handler = match VecDBHandler::init(cache_dir, &constants.model_name, constants.embedding_size).await { Ok(res) => res, Err(err) => { return Err(err) } }; let vecdb_handler = Arc::new(AMutex::new(handler)); let retriever_service = Arc::new(AMutex::new(FileVectorizerService::new( vecdb_handler.clone(), - cooldown_secs, - splitter_window_size, - splitter_soft_limit, - - model_name.clone(), + constants.clone(), cmdline.api_key.clone(), - endpoint_embeddings_style.clone(), - endpoint_template.clone(), ).await)); - Ok(VecDb { vecdb_handler, retriever_service, cmdline: cmdline.clone(), - - model_name, - endpoint_template, - endpoint_embeddings_style, + constants: constants.clone(), }) } @@ -279,9 +252,9 @@ impl VecdbSearch for VecDb { async fn search(&self, query: String, top_n: usize) -> Result { let t0 = std::time::Instant::now(); let embedding_mb = fetch_embedding::try_get_embedding( - &self.endpoint_embeddings_style, - &self.model_name, - &self.endpoint_template, + &self.constants.endpoint_embeddings_style, + &self.constants.model_name, + &self.constants.endpoint_embeddings_template, query.clone(), &self.cmdline.api_key, 3 diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index d1565df89..0d5b5027b 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -1,8 +1,9 @@ use std::collections::{HashMap, VecDeque}; use std::path::PathBuf; -use std::sync::Arc; use std::time::SystemTime; - +use std::ops::Div; +use std::io::Write; +use std::sync::Arc; use tokio::sync::Mutex as AMutex; use tokio::sync::Semaphore; use tokio::task::JoinHandle; @@ -11,7 +12,7 @@ use tracing::info; use crate::vecdb::file_splitter::FileSplitter; use crate::vecdb::handler::VecDBHandler; use crate::fetch_embedding::try_get_embedding; -use crate::vecdb::structs::{Record, SplitResult, VecDbStatus, VecDbStatusRef}; +use crate::vecdb::structs::{Record, SplitResult, VecDbStatus, VecDbStatusRef, VecdbConstants}; #[derive(Debug)] pub struct FileVectorizerService { @@ -19,14 +20,8 @@ pub struct FileVectorizerService { output_queue: Arc>>, vecdb_handler: Arc>, status: VecDbStatusRef, - cooldown_secs: u64, - splitter_window_size: usize, - splitter_soft_limit: usize, - - model_name: String, + constants: VecdbConstants, api_key: String, - endpoint_embeddings_style: String, - endpoint_template: String, } async fn cooldown_queue_thread( @@ -51,16 +46,24 @@ async fn cooldown_queue_thread( } let mut paths_to_process: Vec = Vec::new(); + let mut stat_too_new = 0; + let mut stat_proceed = 0; for (path, time) in &last_updated { if time.elapsed().unwrap().as_secs() > cooldown_secs { paths_to_process.push(path.clone()); + stat_proceed += 1; + } else { + stat_too_new += 1; } } + if stat_proceed > 0 || stat_too_new > 0 { + info!("cooldown_queue_thread: {} files to process, {} files too new", stat_proceed, stat_too_new); + } for path in paths_to_process { last_updated.remove(&path); out_queue.lock().await.push_back(path); } - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; } } @@ -69,18 +72,14 @@ async fn vectorize_thread( queue: Arc>>, vecdb_handler_ref: Arc>, status: VecDbStatusRef, - splitter_window_size: usize, - splitter_soft_limit: usize, - - model_name: String, + constants: VecdbConstants, api_key: String, - endpoint_embeddings_style: String, - endpoint_template: String, - max_concurrent_tasks: usize, ) { - let file_splitter = FileSplitter::new(splitter_window_size, splitter_soft_limit); + let file_splitter = FileSplitter::new(constants.splitter_window_size, constants.splitter_soft_limit); let semaphore = Arc::new(Semaphore::new(max_concurrent_tasks)); + let mut reported_unprocessed: usize = 0; + let mut reported_vecdb_complete: bool = false; loop { let (path_maybe, unprocessed_files_count) = { @@ -91,12 +90,22 @@ async fn vectorize_thread( (None, 0) } }; + if (unprocessed_files_count + 99).div(100) != (reported_unprocessed + 99).div(100) { + info!("have {} unprocessed files", unprocessed_files_count); + reported_unprocessed = unprocessed_files_count; + } status.lock().await.unprocessed_files_count = unprocessed_files_count; + reported_vecdb_complete &= unprocessed_files_count==0; let path = { match path_maybe { Some(path) => path, None => { - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + // No files left to process + if !reported_vecdb_complete { + write!(std::io::stderr(), "VECDB COMPLETE\n").unwrap(); + reported_vecdb_complete = true; + } + tokio::time::sleep(tokio::time::Duration::from_millis(5000)).await; continue; } } @@ -121,10 +130,10 @@ async fn vectorize_thread( // TODO: replace with a batched call? let join_handles: Vec<_> = split_data_filtered.into_iter().map(|x| { - let model_name_clone = model_name.clone(); + let model_name_clone = constants.model_name.clone(); let api_key_clone = api_key.clone(); - let endpoint_embeddings_style_clone = endpoint_embeddings_style.clone(); - let endpoint_template_clone = endpoint_template.clone(); + let endpoint_embeddings_style_clone = constants.endpoint_embeddings_style.clone(); + let endpoint_template_clone = constants.endpoint_embeddings_template.clone(); let semaphore_clone = Arc::clone(&semaphore); tokio::spawn(async move { @@ -166,7 +175,7 @@ async fn vectorize_thread( start_line: data_res.start_line, end_line: data_res.end_line, time_added: SystemTime::now(), - model_name: model_name.clone(), + model_name: constants.model_name.clone(), distance: -1.0, used_counter: 0, time_last_used: now, @@ -201,14 +210,8 @@ async fn cleanup_thread(vecdb_handler: Arc>) { impl FileVectorizerService { pub async fn new( vecdb_handler: Arc>, - cooldown_secs: u64, - splitter_window_size: usize, - splitter_soft_limit: usize, - - model_name: String, + constants: VecdbConstants, api_key: String, - endpoint_embeddings_style: String, - endpoint_template: String, ) -> Self { let update_request_queue = Arc::new(AMutex::new(VecDeque::new())); let output_queue = Arc::new(AMutex::new(VecDeque::new())); @@ -225,24 +228,19 @@ impl FileVectorizerService { output_queue: output_queue.clone(), vecdb_handler: vecdb_handler.clone(), status: status.clone(), - cooldown_secs, - splitter_window_size, - splitter_soft_limit, - - model_name, + constants, api_key, - endpoint_embeddings_style, - endpoint_template, } } - pub async fn start_background_tasks(&self) -> Vec> { + pub async fn start_background_tasks(&self) -> Vec> + { let cooldown_queue_join_handle = tokio::spawn( cooldown_queue_thread( self.update_request_queue.clone(), self.output_queue.clone(), self.status.clone(), - self.cooldown_secs, + self.constants.cooldown_secs, ) ); @@ -251,14 +249,8 @@ impl FileVectorizerService { self.output_queue.clone(), self.vecdb_handler.clone(), self.status.clone(), - self.splitter_window_size, - self.splitter_soft_limit, - - self.model_name.clone(), + self.constants.clone(), self.api_key.clone(), - self.endpoint_embeddings_style.clone(), - self.endpoint_template.clone(), - 4, ) ); From 0bba226e199cacb12b63c52f2b51c1a34a374823 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 10:46:21 +0100 Subject: [PATCH 21/54] log files found --- src/vecdb/vecdb.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index fdb0ec3cd..2ed54e814 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -268,7 +268,10 @@ impl VecdbSearch for VecDb { let t1 = std::time::Instant::now(); let results = handler_locked.search(embedding_mb.unwrap(), top_n).await.unwrap(); info!("search itself {:.3}s", t1.elapsed().as_secs_f64()); - + for rec in results.iter() { + let last_30_chars: String = rec.file_path.display().to_string().chars().rev().take(30).collect::().chars().rev().collect(); + info!("found ...{}:{}-{}, distance: {:.3}", last_30_chars, rec.start_line, rec.end_line, rec.distance); + } let t2 = std::time::Instant::now(); handler_locked.update_record_statistic(results.clone()).await; info!("update_record_statistic {:.3}s", t2.elapsed().as_secs_f64()); From 093715356115426342abd51fb9472d6b3d1ea8e3 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 11:30:48 +0100 Subject: [PATCH 22/54] vecdb file queue debug info --- src/vecdb/file_watcher_service.rs | 8 ++++---- src/vecdb/vecdb.rs | 5 ++++- src/vecdb/vectorizer_service.rs | 11 ++++++++--- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/vecdb/file_watcher_service.rs b/src/vecdb/file_watcher_service.rs index ba739a373..7edb36d2f 100644 --- a/src/vecdb/file_watcher_service.rs +++ b/src/vecdb/file_watcher_service.rs @@ -5,7 +5,7 @@ use futures::{ channel::mpsc::{channel, Receiver}, SinkExt, StreamExt, }; -use log::info; +use log::{info, error}; use notify::{Config, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; use serde_json::Value; use tokio::fs::File; @@ -64,7 +64,7 @@ pub async fn file_watcher_task( } let path = PathBuf::from(maybe_path); let load_data = || async { - let filenames_data = match parse_jsonl(&path).await { + let filenames_vec = match parse_jsonl(&path).await { Ok(data) => data, Err(_) => { info!("invalid jsonl file: {:?}", path); @@ -72,13 +72,13 @@ pub async fn file_watcher_task( } }; match *global_context.read().await.vec_db.lock().await { - Some(ref mut db) => db.add_or_update_files(filenames_data, true).await, + Some(ref mut db) => db.add_or_update_files(filenames_vec, true).await, None => {} }; }; if watcher.watch(path.as_ref(), RecursiveMode::Recursive).is_err() { - info!("file watcher: {:?} is already watching", path); + error!("file watcher {:?} failed to start watching", path); return; } load_data().await; diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 2ed54e814..2d73e3448 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -113,6 +113,9 @@ async fn create_vecdb( db.init_folders(folders).await; } } + // FIXME: + // Race between file_watcher_task that puts files into the queue and + // vectorizer_service that can report empty queue as "vectorization finished" let mut tasks = vec_db.start_background_tasks().await; tasks.extend(vec![tokio::spawn(vecdb::file_watcher_service::file_watcher_task(global_context.clone()))]); background_tasks.extend(tasks); @@ -270,7 +273,7 @@ impl VecdbSearch for VecDb { info!("search itself {:.3}s", t1.elapsed().as_secs_f64()); for rec in results.iter() { let last_30_chars: String = rec.file_path.display().to_string().chars().rev().take(30).collect::().chars().rev().collect(); - info!("found ...{}:{}-{}, distance: {:.3}", last_30_chars, rec.start_line, rec.end_line, rec.distance); + info!("distance {:.3}, found ...{}:{}-{}, ", rec.distance, last_30_chars, rec.start_line, rec.end_line); } let t2 = std::time::Instant::now(); handler_locked.update_record_statistic(results.clone()).await; diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index 0d5b5027b..a1e932932 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -30,12 +30,14 @@ async fn cooldown_queue_thread( _status: VecDbStatusRef, cooldown_secs: u64, ) { + // This function delays vectorization of a file, until mtime is at least cooldown_secs old. let mut last_updated: HashMap = HashMap::new(); loop { let (path_maybe, _unprocessed_files_count) = { let mut queue_locked = update_request_queue.lock().await; + let queue_len = queue_locked.len(); if !queue_locked.is_empty() { - (Some(queue_locked.pop_front().unwrap()), queue_locked.len()) + (Some(queue_locked.pop_front().unwrap()), queue_len) } else { (None, 0) } @@ -84,8 +86,9 @@ async fn vectorize_thread( loop { let (path_maybe, unprocessed_files_count) = { let mut queue_locked = queue.lock().await; - if !queue_locked.is_empty() { - (Some(queue_locked.pop_front().unwrap()), queue_locked.len()) + let queue_len = queue_locked.len(); + if queue_len > 0 { + (Some(queue_locked.pop_front().unwrap()), queue_len) } else { (None, 0) } @@ -265,6 +268,7 @@ impl FileVectorizerService { } pub async fn process_file(&self, path: PathBuf, force: bool) { + info!("adding single file"); if !force { self.update_request_queue.lock().await.push_back(path); } else { @@ -273,6 +277,7 @@ impl FileVectorizerService { } pub async fn process_files(&self, paths: Vec, force: bool) { + info!("adding {} files", paths.len()); if !force { self.update_request_queue.lock().await.extend(paths); } else { From 74c9257f710b9eba64d05a8bbda878abb062ce29 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 11:35:33 +0100 Subject: [PATCH 23/54] VECDB COMPLETE in logs --- src/vecdb/vectorizer_service.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index a1e932932..0d62c670d 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -105,7 +105,9 @@ async fn vectorize_thread( None => { // No files left to process if !reported_vecdb_complete { + // you can see "VECDB COMPLETE" sometimes faster vs info! logs, because it's stderr, not sure why that happens, but a bug write!(std::io::stderr(), "VECDB COMPLETE\n").unwrap(); + info!("VECDB COMPLETE"); // just to make sure reported_vecdb_complete = true; } tokio::time::sleep(tokio::time::Duration::from_millis(5000)).await; From 27a76bebe7119263b6e9802e6356463dc7f015ab Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 11:37:38 +0100 Subject: [PATCH 24/54] VECDB COMPLETE resolved --- src/vecdb/vectorizer_service.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index 0d62c670d..e8c1b95c3 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -105,9 +105,8 @@ async fn vectorize_thread( None => { // No files left to process if !reported_vecdb_complete { - // you can see "VECDB COMPLETE" sometimes faster vs info! logs, because it's stderr, not sure why that happens, but a bug write!(std::io::stderr(), "VECDB COMPLETE\n").unwrap(); - info!("VECDB COMPLETE"); // just to make sure + info!("VECDB COMPLETE"); // you can see "VECDB COMPLETE" sometimes faster vs logs reported_vecdb_complete = true; } tokio::time::sleep(tokio::time::Duration::from_millis(5000)).await; From 2c12c95850027832d68c9b0ae2aeab8253a9b43d Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 12:26:31 +0100 Subject: [PATCH 25/54] maybe will delete old logs --- src/main.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/main.rs b/src/main.rs index cbf9bf78d..98a92588d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,5 @@ use std::io::Write; use tokio::task::JoinHandle; - use tracing::info; use tracing_appender; @@ -37,11 +36,12 @@ async fn main() { tracing_appender::non_blocking(std::io::stderr()) } else { write!(std::io::stderr(), "This rust binary keeps logs as files, rotated daily. Try\ntail -f {}/logs/\nor use --logs-stderr for debugging.\n\n", cache_dir.display()).unwrap(); - tracing_appender::non_blocking(tracing_appender::rolling::RollingFileAppender::new( - tracing_appender::rolling::Rotation::DAILY, - cache_dir.join("logs"), - "rustbinary", - )) + tracing_appender::non_blocking(tracing_appender::rolling::RollingFileAppender::builder() + .rotation(tracing_appender::rolling::Rotation::DAILY) + .filename_prefix("rustbinary") + .max_log_files(30) + .build(cache_dir.join("logs")).unwrap() + ) }; let _tracing = tracing_subscriber::fmt() .with_max_level(if cmdline.verbose {Level::DEBUG} else {Level::INFO}) From ec918847f1e6145258714fd8c1d861fed88e21b7 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 12:27:05 +0100 Subject: [PATCH 26/54] handle SIGUSR1, move signals from http to global --- src/global_context.rs | 52 +++++++++++++++++++++++++++++++++++++++++-- src/http.rs | 32 +------------------------- 2 files changed, 51 insertions(+), 33 deletions(-) diff --git a/src/global_context.rs b/src/global_context.rs index cac582779..b21351115 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -8,9 +8,15 @@ use std::sync::{Arc, Mutex}; use std::sync::RwLock as StdRwLock; use tokio::sync::{AcquireError, Mutex as AMutex, Semaphore, SemaphorePermit}; use tokio::sync::RwLock as ARwLock; +use std::io::Write; +use tokio::signal; + use tokenizers::Tokenizer; use structopt::StructOpt; -use std::io::Write; +use hyper::StatusCode; +use tower_lsp::lsp_types::WorkspaceFolder; + +use crate::custom_error::ScratchError; use async_trait::async_trait; use crate::caps::CodeAssistantCaps; use crate::completion_cache::CompletionCache; @@ -157,6 +163,49 @@ pub async fn look_for_piggyback_fields( } } +pub async fn block_until_signal(ask_shutdown_receiver: std::sync::mpsc::Receiver) { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let sigterm = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + #[cfg(not(unix))] + let sigterm = std::future::pending::<()>(); + + #[cfg(unix)] + let sigusr1 = async { + signal::unix::signal(signal::unix::SignalKind::user_defined1()) + .expect("failed to install signal handler") + .recv() + .await; + }; + #[cfg(not(unix))] + let sigusr1 = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => { + info!("SIGINT signal received"); + }, + _ = sigterm => { + info!("SIGTERM signal received"); + }, + _ = sigusr1 => { + info!("SIGUSR1 signal received"); + }, + _ = tokio::task::spawn_blocking(move || ask_shutdown_receiver.recv()) => { + info!("graceful shutdown to store telemetry"); + } + } +} + pub async fn create_global_context( cache_dir: PathBuf, ) -> (Arc>, std::sync::mpsc::Receiver, CommandLine) { @@ -168,7 +217,6 @@ pub async fn create_global_context( } let http_client = http_client_builder.build().unwrap(); - let cx = GlobalContext { cmdline: cmdline.clone(), http_client, diff --git a/src/http.rs b/src/http.rs index 9cc078ad8..31dcb5cb9 100644 --- a/src/http.rs +++ b/src/http.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use axum::{Extension, http::{StatusCode, Uri}, response::IntoResponse}; use hyper::Server; -use tokio::signal; use tokio::sync::RwLock as ARwLock; use tokio::task::JoinHandle; use tracing::{error, info}; @@ -19,35 +18,6 @@ async fn handler_404(path: Uri) -> impl IntoResponse { } -pub async fn shutdown_signal(ask_shutdown_receiver: std::sync::mpsc::Receiver) { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - - tokio::select! { - _ = ctrl_c => { - info!("SIGINT signal received"); - }, - _ = terminate => {}, - _ = tokio::task::spawn_blocking(move || ask_shutdown_receiver.recv()) => { - info!("graceful shutdown to store telemetry"); - } - } -} - pub async fn start_server( global_context: Arc>, ask_shutdown_receiver: std::sync::mpsc::Receiver, @@ -69,7 +39,7 @@ pub async fn start_server( let router = make_refact_http_server().layer(Extension(global_context.clone())); let server = builder .serve(router.into_make_service()) - .with_graceful_shutdown(shutdown_signal(ask_shutdown_receiver)); + .with_graceful_shutdown(crate::global_context::block_until_signal(ask_shutdown_receiver)); let resp = server.await.map_err(|e| format!("HTTP server error: {}", e)); if let Err(e) = resp { error!("server error: {}", e); From 4bf10997002c1b41bc598de89749b4818e281ec2 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Sun, 21 Jan 2024 12:33:56 +0100 Subject: [PATCH 27/54] no HTTP port by default --- src/global_context.rs | 2 +- src/main.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/global_context.rs b/src/global_context.rs index b21351115..ea58446e1 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -37,7 +37,7 @@ pub struct CommandLine { pub address_url: String, #[structopt(long, short="k", default_value="", help="The API key to authenticate your requests, will appear in HTTP requests this binary makes.")] pub api_key: String, - #[structopt(long, short="p", default_value="8001", help="Bind 127.0.0.1: to listen for HTTP requests, such as /v1/code-completion, /v1/chat, /v1/caps.")] + #[structopt(long, short="p", default_value="0", help="Bind 127.0.0.1: to listen for HTTP requests, such as /v1/code-completion, /v1/chat, /v1/caps.")] pub http_port: u16, #[structopt(long, default_value="", help="End-user client version, such as version of VS Code plugin.")] pub enduser_client_version: String, diff --git a/src/main.rs b/src/main.rs index 98a92588d..d3743a3df 100644 --- a/src/main.rs +++ b/src/main.rs @@ -72,6 +72,7 @@ async fn main() { } if should_start_lsp { if main_handle.is_none() { + // FIXME: this ignores crate::global_context::block_until_signal , important because now we have a database to corrupt main_handle = spawn_lsp_task(gcx.clone(), cmdline.clone()).await; } else { background_tasks.push_back(spawn_lsp_task(gcx.clone(), cmdline.clone()).await.unwrap()) From 3704cf7c797332d7e4a51257712342725c2cbb2b Mon Sep 17 00:00:00 2001 From: JegernOUTT Date: Mon, 29 Jan 2024 19:16:53 +1030 Subject: [PATCH 28/54] - sqlite instead of lancedb in the cache vecdb - removed unused tests - cached client for embedding requests - polishing the code - index building for vecdb (should be tested more) --- Cargo.toml | 5 + src/fetch_embedding.rs | 142 +-------- src/forward_to_hf_endpoint.rs | 19 +- src/forward_to_openai_endpoint.rs | 25 +- src/global_context.rs | 1 - src/http/routers/v1/vecdb.rs | 7 +- src/vecdb/handler.rs | 493 ++++++++++++++++++------------ src/vecdb/vecdb.rs | 9 +- src/vecdb/vectorizer_service.rs | 25 +- 9 files changed, 368 insertions(+), 358 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e3e00a641..ebd0cd7b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,13 +44,18 @@ arrow-array = "47.0.0" arrow-schema= "47.0.0" async_once= "0.2.6" async-process = "2.0.1" +itertools = "0.11.0" lance = "=0.9.0" +lance-linalg = "=0.9.0" +lance-index = "=0.9.0" log = "0.4.20" md5 = "0.7" mockito = "0.28.0" notify = "6.1.1" +rusqlite = { version = "0.30.0", features = ["bundled"] } tempfile = "3.8.1" time = "0.3.30" +tokio-rusqlite = "0.5.0" vectordb = "=0.4.0" walkdir = "2.3" which = "5.0.0" diff --git a/src/fetch_embedding.rs b/src/fetch_embedding.rs index b50c17a43..06af007b8 100644 --- a/src/fetch_embedding.rs +++ b/src/fetch_embedding.rs @@ -1,10 +1,13 @@ +use std::sync::Arc; + +use tokio::sync::Mutex as AMutex; use tracing::error; use crate::forward_to_hf_endpoint::get_embedding_hf_style; use crate::forward_to_openai_endpoint::get_embedding_openai_style; - pub async fn get_embedding( + client: Arc>, endpoint_embeddings_style: &String, model_name: &String, endpoint_template: &String, @@ -12,8 +15,8 @@ pub async fn get_embedding( api_key: &String, ) -> Result, String> { match endpoint_embeddings_style.to_lowercase().as_str() { - "hf" => get_embedding_hf_style(text, endpoint_template, model_name, api_key).await, - "openai" => get_embedding_openai_style(text, endpoint_template, model_name, api_key).await, + "hf" => get_embedding_hf_style(client, text, endpoint_template, model_name, api_key).await, + "openai" => get_embedding_openai_style(client, text, endpoint_template, model_name, api_key).await, _ => { error!("Invalid endpoint_embeddings_style: {}", endpoint_embeddings_style); Err("Invalid endpoint_embeddings_style".to_string()) @@ -24,6 +27,7 @@ pub async fn get_embedding( // HF often returns 500 errors for no reason pub async fn try_get_embedding( + client: Arc>, endpoint_embeddings_style: &String, model_name: &String, endpoint_template: &String, @@ -36,6 +40,7 @@ pub async fn try_get_embedding( loop { retries += 1; match get_embedding( + client.clone(), endpoint_embeddings_style, model_name, endpoint_template, @@ -52,134 +57,3 @@ pub async fn try_get_embedding( } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_valid_request() { - let _m = mockito::mock("POST", "/models/valid_model") - .with_status(200) - .with_body(r#"{"embedding": [1.0, 2.0, 3.0]}"#) - .create(); - - let text = "sample text".to_string(); - let model_name = "valid_model".to_string(); - let api_key = "valid_api_key".to_string(); - - let result = get_embedding(text, &model_name, api_key).await.unwrap(); - - assert!(result.is_ok()); - assert_eq!(result.unwrap(), vec![1.0, 2.0, 3.0]); - } - - #[tokio::test] - async fn test_invalid_api_key() { - let _m = mockito::mock("POST", "/models/valid_model") - .with_status(401) - .create(); - - let text = "sample text".to_string(); - let model_name = "valid_model".to_string(); - let api_key = "invalid_api_key".to_string(); - - let result = get_embedding(text, &model_name, api_key).await.unwrap(); - - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_concurrent_requests() { - let mock = mockito::mock("POST", "/models/valid_model") - .with_status(200) - .with_body(r#"{"embedding": [1.0, 2.0, 3.0]}"#) - .expect(10) // Expect 10 calls - .create(); - - let handles: Vec<_> = (0..10).map(|_| { - let text = "sample text".to_string(); - let model_name = "valid_model".to_string(); - let api_key = "valid_api_key".to_string(); - - get_embedding(text, &model_name, api_key) - }).collect(); - - for handle in handles { - let result = handle.await.unwrap(); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), vec![1.0, 2.0, 3.0]); - } - - mock.assert(); - } - - #[tokio::test] - async fn test_empty_text_input() { - let _m = mockito::mock("POST", "/models/valid_model") - .with_status(200) - .with_body(r#"{"embedding": []}"#) - .create(); - - let text = "".to_string(); - let model_name = "valid_model".to_string(); - let api_key = "valid_api_key".to_string(); - - let result = get_embedding(text, &model_name, api_key).await.unwrap(); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), Vec::::new()); - } - - #[tokio::test] - async fn test_invalid_model_name() { - let _m = mockito::mock("POST", "/models/invalid_model") - .with_status(404) - .create(); - - let text = "sample text".to_string(); - let model_name = "invalid_model".to_string(); - let api_key = "valid_api_key".to_string(); - - let result = get_embedding(text, &model_name, api_key).await.unwrap(); - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_network_failure() { - let _m = mockito::mock("POST", "/models/valid_model") - .with_status(500) // Internal Server Error to simulate server-side failure - .create(); - - let text = "sample text".to_string(); - let model_name = "valid_model".to_string(); - let api_key = "valid_api_key".to_string(); - - let result = get_embedding(text, &model_name, api_key).await.unwrap(); - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_different_embeddings() { - let mock1 = mockito::mock("POST", "/models/model1") - .with_status(200) - .with_body(r#"{"embedding": [1.0, 2.0]}"#) - .create(); - - let mock2 = mockito::mock("POST", "/models/model2") - .with_status(200) - .with_body(r#"{"embedding": [3.0, 4.0]}"#) - .create(); - - let text = "sample text".to_string(); - let model_names = vec!["model1", "model2"]; - let api_key = "valid_api_key".to_string(); - - for model_name in model_names { - let result = get_embedding(text.clone(), &model_name.to_string(), api_key.clone()).await.unwrap(); - assert!(result.is_ok()); - } - - mock1.assert(); - mock2.assert(); - } -} \ No newline at end of file diff --git a/src/forward_to_hf_endpoint.rs b/src/forward_to_hf_endpoint.rs index 4b14ca96a..191f88926 100644 --- a/src/forward_to_hf_endpoint.rs +++ b/src/forward_to_hf_endpoint.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use reqwest::header::AUTHORIZATION; use reqwest::header::CONTENT_TYPE; use reqwest::header::HeaderMap; @@ -5,6 +7,8 @@ use reqwest::header::HeaderValue; use reqwest_eventsource::EventSource; use serde::Serialize; use serde_json::json; +use tokio::sync::Mutex as AMutex; + use crate::call_validation::SamplingParameters; // Idea: use USER_AGENT @@ -48,7 +52,10 @@ pub async fn forward_to_hf_style_endpoint( if status_code != 200 { return Err(format!("{} status={} text {}", url, status_code, response_txt)); } - Ok(serde_json::from_str(&response_txt).unwrap()) // FIXME: unwrap + Ok(match serde_json::from_str(&response_txt) { + Ok(json) => json, + Err(e) => return Err(format!("{}: {}", url, e)), + }) } @@ -79,8 +86,8 @@ pub async fn forward_to_hf_style_endpoint_streaming( }); let builder = client.post(&url) - .headers(headers) - .body(data.to_string()); + .headers(headers) + .body(data.to_string()); let event_source: EventSource = EventSource::new(builder).map_err(|e| format!("can't stream from {}: {}", url, e) )?; @@ -95,16 +102,16 @@ struct EmbeddingsPayloadHF { pub async fn get_embedding_hf_style( + client: Arc>, text: String, endpoint_template: &String, model_name: &String, api_key: &String, ) -> Result, String> { - let client = reqwest::Client::new(); let payload = EmbeddingsPayloadHF { inputs: text }; let url = endpoint_template.clone().replace("$MODEL", &model_name); - let maybe_response = client + let maybe_response = client.lock().await .post(&url) .bearer_auth(api_key.clone()) .json(&payload) @@ -121,7 +128,7 @@ pub async fn get_embedding_hf_style( } else { Err(format!("Failed to get a response: {:?}", response.status())) } - }, + } Err(err) => Err(format!("Failed to send a request: {:?}", err)), } } diff --git a/src/forward_to_openai_endpoint.rs b/src/forward_to_openai_endpoint.rs index c374e8b13..407b69735 100644 --- a/src/forward_to_openai_endpoint.rs +++ b/src/forward_to_openai_endpoint.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use reqwest::header::AUTHORIZATION; use reqwest::header::CONTENT_TYPE; use reqwest::header::HeaderMap; @@ -5,10 +7,11 @@ use reqwest::header::HeaderValue; use reqwest_eventsource::EventSource; use serde::Serialize; use serde_json::json; -use crate::call_validation; -use crate::call_validation::SamplingParameters; +use tokio::sync::Mutex as AMutex; use tracing::info; +use crate::call_validation; +use crate::call_validation::SamplingParameters; pub async fn forward_to_openai_style_endpoint( save_url: &mut String, @@ -43,10 +46,10 @@ pub async fn forward_to_openai_style_endpoint( } // When cancelling requests, coroutine ususally gets aborted here on the following line. let req = client.post(&url) - .headers(headers) - .body(data.to_string()) - .send() - .await; + .headers(headers) + .body(data.to_string()) + .send() + .await; let resp = req.map_err(|e| format!("{}", e))?; let status_code = resp.status().as_u16(); let response_txt = resp.text().await.map_err(|e| @@ -93,8 +96,8 @@ pub async fn forward_to_openai_style_endpoint_streaming( data["prompt"] = serde_json::Value::String(prompt.to_string()); } let builder = client.post(&url) - .headers(headers) - .body(data.to_string()); + .headers(headers) + .body(data.to_string()); let event_source: EventSource = EventSource::new(builder).map_err(|e| format!("can't stream from {}: {}", url, e) )?; @@ -120,19 +123,19 @@ struct EmbeddingsPayloadOpenAI { pub async fn get_embedding_openai_style( + client: Arc>, text: String, endpoint_template: &String, model_name: &String, api_key: &String, ) -> Result, String> { - let client = reqwest::Client::new(); let payload = EmbeddingsPayloadOpenAI { input: text, model: model_name.clone(), }; let url = endpoint_template.clone(); let api_key_clone = api_key.clone(); - let response = client + let response = client.lock().await .post(&url) .bearer_auth(api_key_clone.clone()) .json(&payload) @@ -142,7 +145,7 @@ pub async fn get_embedding_openai_style( if !response.status().is_success() { info!("get_embedding_openai_style: {:?}", response); - return Err(format!("get_embedding_openai_style: bad status: {:?}", response.status())) + return Err(format!("get_embedding_openai_style: bad status: {:?}", response.status())); } let json = response.json::() diff --git a/src/global_context.rs b/src/global_context.rs index ea58446e1..0d54ced39 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -77,7 +77,6 @@ pub struct LSPBackendDocumentState { pub workspace_folders: Arc>>>, } -// #[derive(Debug)] pub struct GlobalContext { pub cmdline: CommandLine, pub http_client: reqwest::Client, diff --git a/src/http/routers/v1/vecdb.rs b/src/http/routers/v1/vecdb.rs index d6f669d6f..73f2e1935 100644 --- a/src/http/routers/v1/vecdb.rs +++ b/src/http/routers/v1/vecdb.rs @@ -54,7 +54,12 @@ pub async fn handle_v1_vecdb_status( ) -> Result, ScratchError> { let cx_locked = global_context.read().await; let status = match *cx_locked.vec_db.lock().await { - Some(ref db) => db.get_status().await, + Some(ref db) => match db.get_status().await { + Ok(status) => status, + Err(err) => { + return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, err)); + } + }, None => { return Err(ScratchError::new( StatusCode::INTERNAL_SERVER_ERROR, "Vector db is not available".to_string() diff --git a/src/vecdb/handler.rs b/src/vecdb/handler.rs index d16b017d8..c278c787f 100644 --- a/src/vecdb/handler.rs +++ b/src/vecdb/handler.rs @@ -1,4 +1,5 @@ use std::any::Any; +use std::cmp::min; use std::collections::HashSet; use std::fmt::{Debug, Formatter}; use std::path::PathBuf; @@ -13,12 +14,21 @@ use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIter use arrow_array::cast::{as_fixed_size_list_array, as_primitive_array, as_string_array}; use arrow_array::types::{Float32Type, UInt64Type}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use itertools::Itertools; use futures_util::TryStreamExt; use lance::dataset::{WriteMode, WriteParams}; +use lance_linalg::distance::MetricType; +use lance_index::vector::ivf::IvfBuildParams; +use lance::index::vector::pq::PQBuildParams; use log::info; +use rusqlite::{OpenFlags, params, Result}; use tempfile::{tempdir, TempDir}; +use tokio::fs; +use tokio::sync::Mutex as AMutex; +use tokio_rusqlite::Connection; use tracing::error; use vectordb::database::Database; +use vectordb::index::vector::IvfPQIndexBuilder; use vectordb::table::Table; use crate::vecdb::structs::{Record, SplitResult}; @@ -31,9 +41,8 @@ impl Debug for VecDBHandler { } pub struct VecDBHandler { - cache_database: Database, + cache_database: Arc>, _data_database_temp_dir: TempDir, - cache_table: Table, data_table: Table, schema: SchemaRef, data_table_hashes: HashSet, @@ -58,7 +67,7 @@ const MIN_LIKES: i32 = 3; impl VecDBHandler { pub async fn init(cache_dir: &PathBuf, model_name: &String, embedding_size: i32) -> Result { let cache_dir_str = match cache_dir.join("refact_vecdb_cache") - .join(format!("model_{}_esize_{}", + .join(format!("model_{}_esize_{}.sqlite", model_name.replace("/", "_"), embedding_size )).to_str() { @@ -77,8 +86,18 @@ impl VecDBHandler { None => return Err(format!("{:?}", "Temp directory is not a valid path")), }; - let cache_database = match Database::connect(cache_dir_str.as_str()).await { - Ok(db) => db, + if !cache_dir.join("refact_vecdb_cache").exists() { + match fs::create_dir_all(cache_dir.join("refact_vecdb_cache")).await { + Ok(_) => {} + Err(e) => return Err(format!("{:?}", e)), + } + } + let cache_database = match Connection::open_with_flags( + cache_dir_str, OpenFlags::SQLITE_OPEN_READ_WRITE + | OpenFlags::SQLITE_OPEN_CREATE + | OpenFlags::SQLITE_OPEN_NO_MUTEX + | OpenFlags::SQLITE_OPEN_URI).await { + Ok(db) => Arc::new(AMutex::new(db)), Err(err) => return Err(format!("{:?}", err)) }; let temp_database = match Database::connect(data_database_temp_dir_str).await { @@ -99,16 +118,31 @@ impl VecDBHandler { Field::new("model_name", DataType::Utf8, true), Field::new("used_counter", DataType::UInt64, true), ])); - let cache_table = match cache_database.open_table("data").await { - Ok(table) => { table } - Err(_) => { - let batches_iter = RecordBatchIterator::new(vec![].into_iter().map(Ok), schema.clone()); - match cache_database.create_table("data", batches_iter, Option::from(WriteParams::default())).await { - Ok(table) => table, - Err(err) => return Err(format!("{:?}", err)) - } - } - }; + match cache_database.lock().await.call(|conn| { + conn.execute( + "CREATE TABLE IF NOT EXISTS data ( + vector BLOB, + window_text TEXT NOT NULL, + window_text_hash TEXT NOT NULL, + file_path TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + time_added INTEGER NOT NULL, + time_last_used INTEGER NOT NULL, + model_name TEXT NOT NULL, + used_counter INTEGER NOT NULL + )", [], + )?; + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_window_text_hash ON data (window_text_hash)", + [], + )?; + Ok(()) + }).await { + Ok(_) => {} + Err(err) => return Err(format!("{:?}", err)) + } + let batches_iter = RecordBatchIterator::new(vec![].into_iter().map(Ok), schema.clone()); let data_table = match temp_database.create_table("data", batches_iter, Option::from(WriteParams::default())).await { Ok(table) => table, @@ -119,7 +153,6 @@ impl VecDBHandler { cache_database, _data_database_temp_dir: data_database_temp_dir, schema, - cache_table, data_table, data_table_hashes: HashSet::new(), embedding_size, @@ -129,72 +162,211 @@ impl VecDBHandler { async fn checkout(&mut self) { match self.data_table.checkout_latest().await { Ok(table) => { self.data_table = table } - Err(err) => error!("Error while checking out data table: {:?}", err) + Err(err) => error!("Error while checking out the data table: {:?}", err) } - match self.cache_table.checkout_latest().await { - Ok(table) => { self.cache_table = table } - Err(err) => error!("Error while checking out data table: {:?}", err) + match self.cache_database.lock().await.call(|connection| { + connection.cache_flush()?; + Ok({}) + }).await { + Ok(_) => {} + Err(err) => error!("Error while flushing cache: {:?}", err) } } - pub async fn size(&self) -> Result { - match self.data_table.count_rows().await { - Ok(size) => Ok(size), - Err(err) => Err(format!("{:?}", err)) + async fn get_records_from_cache(&mut self, hashes: Vec) -> Result<(Vec, Vec), String> { + let mut hashes_set: HashSet = HashSet::from_iter(hashes.iter().cloned()); + let placeholders: String = hashes.iter().map(|_| "?").collect::>().join(","); + let query = format!("SELECT * FROM data WHERE window_text_hash IN ({})", placeholders); + + let records = match self.cache_database.lock().await.call(move |connection| { + let mut statement = connection.prepare(&query)?; + let params = rusqlite::params_from_iter(hashes.iter()); + let records = statement.query_map(params, |row| { + let vector_blob: Vec = row.get(0)?; + let vector: Vec = vector_blob + .chunks_exact(4) + .map(|b| f32::from_ne_bytes(b.try_into().unwrap())) + .collect(); + + let file_path_str: String = row.get(3)?; + let file_path = PathBuf::from(file_path_str); + + let time_added_timestamp: i64 = row.get(6)?; + let time_added = SystemTime::UNIX_EPOCH + Duration::from_secs(time_added_timestamp as u64); + + let time_last_used_timestamp: i64 = row.get(7)?; + let time_last_used = SystemTime::UNIX_EPOCH + Duration::from_secs(time_last_used_timestamp as u64); + + Ok(Record { + vector: Some(vector), + window_text: row.get(1)?, + window_text_hash: row.get(2)?, + file_path: file_path, + start_line: row.get(4)?, + end_line: row.get(5)?, + time_added: time_added, + time_last_used: time_last_used, + model_name: row.get(8)?, + used_counter: row.get(9)?, + distance: -1.0, + }) + })? + .filter_map(|row| row.ok()) + .collect::>(); + Ok(records) + }).await { + Ok(records) => records, + Err(err) => return Err(format!("{:?}", err)) + }; + + for r in &records { + hashes_set.remove(&r.window_text_hash); } + Ok((records, hashes_set.iter().map(|x| x.clone()).collect())) } - pub async fn cache_size(&self) -> Result { - match self.cache_table.count_rows().await { - Ok(size) => Ok(size), + async fn insert_records_to_cache(&mut self, records: Vec) -> Result<(), String> { + match self.cache_database.lock().await.call(|connection| { + let transaction = connection.transaction()?; + for record in records { + let time_added = record.time_added.duration_since( + SystemTime::UNIX_EPOCH + ).unwrap_or(Duration::ZERO) + .as_secs(); + + let time_last_used = record.time_last_used.duration_since( + SystemTime::UNIX_EPOCH + ).unwrap_or(Duration::ZERO) + .as_secs(); + + let vector_as_bytes: Vec = record.vector.expect( + "An attempt to push vector-less data to cache DB" + ).iter() + .flat_map(|&num| num.to_ne_bytes()) + .collect(); + + match transaction.execute( + "INSERT INTO data (vector, window_text, window_text_hash, \ + file_path, start_line, end_line, time_added, \ + time_last_used, model_name, used_counter) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", + rusqlite::params![ + vector_as_bytes, + record.window_text, + record.window_text_hash, + record.file_path.to_str(), + record.start_line, + record.end_line, + time_added as i64, + time_last_used as i64, + record.model_name, + record.used_counter, + ], + ) { + Ok(_) => {} + Err(err) => { + info!("Error while inserting record to cache: {:?}", err); + continue; + } + } + } + match transaction.commit() { + Ok(_) => Ok(()), + Err(err) => Err(err.into()) + } + }).await { + Ok(_) => Ok(()), Err(err) => Err(format!("{:?}", err)) } } - async fn get_records(&mut self, table: Table, _hashes: Vec) -> (Vec, Vec) { - let mut hashes: HashSet = HashSet::from_iter(_hashes); - let q = hashes.iter().map(|x| format!("'{}'", x)).collect::>().join(", "); - let records = table - .filter(format!("window_text_hash in ({})", q)) - .execute() - .await.unwrap() - .try_collect::>() - .await.unwrap(); - let record_batch = concat_batches(&self.schema, &records).unwrap(); - let records = VecDBHandler::parse_table_iter(record_batch, true, None).unwrap(); - for r in &records { - hashes.remove(&r.window_text_hash); + async fn remove_records_from_cache(&mut self, file_path: String) -> Result<(), String> { + match self.cache_database.lock().await.call(move |connection| { + match connection.execute( + "DELETE FROM data WHERE file_path = ?1", + params![file_path], + ) { + Ok(_) => Ok(()), + Err(err) => Err(err.into()) + } + }).await { + Ok(_) => Ok(()), + Err(err) => Err(format!("{:?}", err)) } - (records, hashes.into_iter().collect()) } - pub async fn _get_records_from_data(&mut self, hashes: Vec) -> (Vec, Vec) { - self.get_records(self.data_table.clone(), hashes).await - } - pub async fn get_records_from_cache(&mut self, hashes: Vec) -> (Vec, Vec) { - self.get_records(self.cache_table.clone(), hashes).await + async fn update_cache_records(&mut self, records: Vec) -> Result<(), String> { + let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_secs(); + match self.cache_database.lock().await.call(move |connection| { + let transaction = connection.transaction()?; + for record in records { + match transaction.execute( + "UPDATE data SET time_last_used = ?, used_counter = ? WHERE window_text_hash = ?", + params![ + now, + record.used_counter, + record.window_text_hash, + ], + ) { + Ok(_) => {} + Err(_) => { + continue; + } + }; + } + match transaction.commit() { + Ok(_) => Ok(()), + Err(err) => Err(err.into()) + } + }).await { + Ok(_) => Ok(()), + Err(err) => Err(format!("{:?}", err)) + } } - async fn get_record(&mut self, table: Table, hash: String) -> vectordb::error::Result { - let records = table - .filter(format!("window_text_hash == '{}'", hash)) - .execute() - .await? - .try_collect::>() - .await?; - let record_batch = concat_batches(&self.schema, &records)?; - let records = VecDBHandler::parse_table_iter(record_batch, true, None)?; - match records.get(0) { - Some(x) => Ok(x.clone()), - None => Err(vectordb::error::Error::Lance { message: format!("No record found for hash: {}", hash) }) + async fn delete_old_records_from_cache(&mut self) -> Result<(), String> { + let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_secs(); + + match self.cache_database.lock().await.call(move |connection| { + let transaction = connection.transaction()?; + + transaction.execute( + "DELETE FROM data WHERE (?1 - time_last_used > ?2) AND (used_counter < ?3)", + params![now, TWO_WEEKS, MIN_LIKES], + )?; + + transaction.execute( + "DELETE FROM data WHERE (?1 - time_last_used > ?2)", + params![now, ONE_MONTH], + )?; + + transaction.commit()?; + Ok({}) + }).await { + Ok(_) => Ok(()), + Err(err) => Err(format!("{:?}", err)) } } - pub async fn _get_record_from_data(&mut self, hash: String) -> vectordb::error::Result { - self.get_record(self.data_table.clone(), hash).await + pub async fn size(&self) -> Result { + match self.data_table.count_rows().await { + Ok(size) => Ok(size), + Err(err) => Err(format!("{:?}", err)) + } } - pub async fn _get_record_from_cache(&mut self, hash: String) -> vectordb::error::Result { - self.get_record(self.cache_table.clone(), hash).await + + pub async fn cache_size(&self) -> Result { + self.cache_database.lock().await.call(move |connection| { + let mut stmt = connection.prepare("SELECT COUNT(*) FROM data")?; + let count: usize = stmt.query_row([], |row| row.get(0))?; + Ok(count) + }).await + .map_err(|e| { + e.to_string() + }) } pub async fn try_add_from_cache(&mut self, data: Vec) -> Vec { @@ -203,7 +375,13 @@ impl VecDBHandler { } let hashes = data.iter().map(|x| x.window_text_hash.clone()).collect(); - let (found_records, left_hashes) = self.get_records_from_cache(hashes).await; + let (found_records, left_hashes) = match self.get_records_from_cache(hashes).await { + Ok(records) => records, + Err(err) => { + info!("Error while getting values from cache: {:?}", err); + return vec![]; + } + }; let left_results: Vec = data.into_iter().filter(|x| left_hashes.contains(&x.window_text_hash)).collect(); @@ -280,7 +458,7 @@ impl VecDBHandler { )], self.schema.clone(), ); - let cache_batches_iter = RecordBatchIterator::new( + RecordBatchIterator::new( vec![RecordBatch::try_new( self.schema.clone(), vec![ @@ -300,13 +478,7 @@ impl VecDBHandler { ); if add_to_cache { - let cache_res = self.cache_table.add( - cache_batches_iter, Option::from(WriteParams { - mode: WriteMode::Append, - ..Default::default() - }), - ); - match cache_res.await { + match self.insert_records_to_cache(records).await { Ok(_) => {} Err(err) => return Err(format!("{:?}", err)) }; @@ -334,25 +506,43 @@ impl VecDBHandler { Some(res) => res }; - // valerii: In documentation I found no way to preprocess strings to prevent SQL injections - match self.cache_table.delete( - format!("(file_path = \"{}\")", file_path_str).as_str() // TODO: Prevent a possible sql injection here - ).await { + match self.remove_records_from_cache(file_path_str.to_string()).await { Ok(_) => {} Err(err) => { - info!("Error while deleting from cache: {:?}", err); + info!("Error while deleting from cache table: {:?}", err); } } + // valerii: In documentation I found no way to preprocess strings to prevent SQL injections match self.data_table.delete( format!("(file_path = \"{}\")", file_path_str).as_str() // TODO: Prevent a possible sql injection here ).await { Ok(_) => {} Err(err) => { - info!("Error while deleting from cache: {:?}", err); + info!("Error while deleting from data table: {:?}", err); } } } + pub async fn create_index(&mut self) -> vectordb::error::Result<()> { + let size = self.size().await.unwrap_or(0); + if size == 0 { + return Err(vectordb::error::Error::Lance { + message: "The vector database is empty".to_string(), + }.into()); + } + self.data_table.create_index( + IvfPQIndexBuilder::default() + .column("vector".to_owned()) + .index_name("index".to_owned()) + .metric_type(MetricType::Cosine) + .ivf_params(IvfBuildParams { + num_partitions: min(size, 512), + ..IvfBuildParams::default() + }) + .replace(true) + ).await + } + pub fn contains(&self, hash: &str) -> bool { self.data_table_hashes.contains(hash) } @@ -424,9 +614,14 @@ impl VecDBHandler { }).collect() } - pub async fn search(&mut self, embedding: Vec, top_n: usize) -> vectordb::error::Result> - { - let query = self.data_table.clone() + pub async fn search( + &mut self, + embedding: Vec, + top_n: usize + ) -> vectordb::error::Result> { + let query = self + .data_table + .clone() .search(Some(Float32Array::from(embedding.clone()))) .limit(top_n) .use_index(true) @@ -435,133 +630,43 @@ impl VecDBHandler { .try_collect::>() .await?; let record_batch = concat_batches(&self.schema, &query)?; - VecDBHandler::parse_table_iter(record_batch, false, Some(&embedding)) + match VecDBHandler::parse_table_iter(record_batch, false, Some(&embedding)) { + Ok(records) => { + let filtered: Vec = records + .into_iter() + .sorted_unstable_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .collect(); + Ok(filtered) + } + Err(err) => Err(err), + } } - pub async fn update_record_statistic(&mut self, records: Vec) - { - // TODO: very slow, 0.8s for db_size=1368, maybe make one update call? - let now = SystemTime::now(); - for record in records { - for mut table in vec![self.data_table.clone(), self.cache_table.clone()] { - let _ = table.update(Some(format!("window_text_hash == '{}'", record.window_text_hash.clone()).as_str()), - vec![ - ("used_counter", &(&record.used_counter + 1).to_string()), - ("time_last_used", &*now.elapsed().unwrap().as_secs().to_string()), - ]).await.unwrap(); + pub async fn update_record_statistic(&mut self, records: Vec) { + match self.update_cache_records(records).await { + Ok(_) => {} + Err(err) => { + info!("Error while deleting from data table: {:?}", err); } - self.checkout().await; } } - pub async fn cleanup_old_records(&mut self) -> Result<(), String> - { + pub async fn cleanup_old_records(&mut self) -> Result<(), String> { + info!("VECDB: Cleaning up old records"); + let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap(); let q = format!("{} - time_last_used > {TWO_WEEKS} AND used_counter < {MIN_LIKES}", now.as_secs()); - self.cache_table.delete(&*q).await.expect("could not delete old records"); self.data_table.delete(&*q).await.expect("could not delete old records"); - self.checkout().await; let q = format!("{} - time_last_used > {ONE_MONTH}", now.as_secs()); - self.cache_table.delete(&*q).await.expect("could not delete old records"); self.data_table.delete(&*q).await.expect("could not delete old records"); + + self.delete_old_records_from_cache().await.expect("could not delete old records"); self.checkout().await; Ok(()) } } - -#[cfg(test)] -mod tests { - use std::time::SystemTime; - - use tempfile::tempdir; - use tokio; - - use super::*; - - #[tokio::test] - async fn test_init() { - let temp_dir = tempdir().unwrap(); - let embedding_size = 2; - let mut handler = VecDBHandler::init( - temp_dir.path().to_path_buf(), - embedding_size, - ).await; - assert_eq!(handler.size().await, 0); - } - - #[tokio::test] - async fn test_add_or_update() { - let temp_dir = tempdir().unwrap(); - let embedding_size = 2; - let mut handler = VecDBHandler::init( - temp_dir.path().to_path_buf(), - embedding_size, - ).await; - let expected_size = 1; - - // Prepare a sample record - let records = vec![ - Record { - vector: Some(vec![1.0, 2.0]), // Example values - window_text: "sample text".to_string(), - window_text_hash: "hash1".to_string(), - file_path: PathBuf::from("/path/to/file"), - start_line: 1, - end_line: 2, - time_added: SystemTime::now(), - time_last_used: SystemTime::now(), - model_name: "model1".to_string(), - used_counter: 0, - distance: 1.0, - }, - ]; - - // Call add_or_update - handler.add_or_update(records, true).await.unwrap(); - - // Validate the records - assert_eq!(handler.size().await, expected_size); - } - - #[tokio::test] - async fn test_search() { - let temp_dir = tempdir().unwrap(); - let embedding_size = 4; - let mut handler = VecDBHandler::init( - temp_dir.path().to_path_buf(), - embedding_size, - ).await; - let top_n = 1; - - let time_added = SystemTime::now(); - let records = vec![ - Record { - vector: Some(vec![1.0, 2.0, 3.0, 4.0]), - window_text: "test text".to_string(), - window_text_hash: "hash2".to_string(), - file_path: PathBuf::from("/path/to/another/file"), - start_line: 3, - end_line: 4, - time_added: time_added, - time_last_used: time_added, - model_name: "model2".to_string(), - used_counter: 0, - distance: 1.0, - }, - ]; - handler.add_or_update(records, true).await.unwrap(); - - let query_embedding = vec![1.0, 2.0, 3.0, 4.0]; - let results = handler.search(query_embedding, top_n).await.unwrap(); - - assert!(!results.is_empty()); - assert_eq!(results[0].window_text, "test text"); - assert_eq!(results[0].window_text_hash, "hash2"); - assert_eq!(results[0].file_path, PathBuf::from("/path/to/another/file")); - assert_eq!(results[0].start_line, 3); - assert_eq!(results[0].end_line, 4); - assert_eq!(results[0].model_name, "model2"); - assert_eq!(results[0].distance, 1.0); - } -} diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 2d73e3448..2b47fdded 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -34,8 +34,8 @@ fn vecdb_constants( } } -#[derive(Debug)] pub struct VecDb { + vecdb_emb_client: Arc>, vecdb_handler: Arc>, retriever_service: Arc>, cmdline: CommandLine, @@ -61,7 +61,6 @@ async fn vecdb_test_request( Err("vecdb: test search failed".to_string()) } } - } async fn create_vecdb( @@ -206,6 +205,7 @@ impl VecDb { cmdline.api_key.clone(), ).await)); Ok(VecDb { + vecdb_emb_client: Arc::new(AMutex::new(reqwest::Client::new())), vecdb_handler, retriever_service, cmdline: cmdline.clone(), @@ -215,7 +215,7 @@ impl VecDb { pub async fn start_background_tasks(&self) -> Vec> { info!("vecdb: start_background_tasks"); - return self.retriever_service.lock().await.start_background_tasks().await; + return self.retriever_service.lock().await.start_background_tasks(self.vecdb_emb_client.clone()).await; } pub async fn add_or_update_file(&mut self, file_path: PathBuf, force: bool) { @@ -255,12 +255,13 @@ impl VecdbSearch for VecDb { async fn search(&self, query: String, top_n: usize) -> Result { let t0 = std::time::Instant::now(); let embedding_mb = fetch_embedding::try_get_embedding( + self.vecdb_emb_client.clone(), &self.constants.endpoint_embeddings_style, &self.constants.model_name, &self.constants.endpoint_embeddings_template, query.clone(), &self.cmdline.api_key, - 3 + 5 ).await; if embedding_mb.is_err() { return Err("Failed to get embedding".to_string()); diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index e8c1b95c3..16ef87db9 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -71,6 +71,7 @@ async fn cooldown_queue_thread( async fn vectorize_thread( + client: Arc>, queue: Arc>>, vecdb_handler_ref: Arc>, status: VecDbStatusRef, @@ -105,9 +106,14 @@ async fn vectorize_thread( None => { // No files left to process if !reported_vecdb_complete { + reported_vecdb_complete = true; + info!("VECDB Creating index"); + match vecdb_handler_ref.lock().await.create_index().await { + Ok(_) => info!("VECDB CREATED INDEX"), + Err(err) => info!("VECDB Error creating index: {}", err) + } write!(std::io::stderr(), "VECDB COMPLETE\n").unwrap(); info!("VECDB COMPLETE"); // you can see "VECDB COMPLETE" sometimes faster vs logs - reported_vecdb_complete = true; } tokio::time::sleep(tokio::time::Duration::from_millis(5000)).await; continue; @@ -138,6 +144,8 @@ async fn vectorize_thread( let api_key_clone = api_key.clone(); let endpoint_embeddings_style_clone = constants.endpoint_embeddings_style.clone(); let endpoint_template_clone = constants.endpoint_embeddings_template.clone(); + let status_clone = Arc::clone(&status); + let client_clone = Arc::clone(&client); let semaphore_clone = Arc::clone(&semaphore); tokio::spawn(async move { @@ -149,13 +157,15 @@ async fn vectorize_thread( }; let result = try_get_embedding( + client_clone, &endpoint_embeddings_style_clone, &model_name_clone, &endpoint_template_clone, x.window_text.clone(), &api_key_clone, - 3, + 1, ).await; + status_clone.lock().await.requests_made_since_start += 1; drop(_permit); Some((x, result)) @@ -169,7 +179,6 @@ async fn vectorize_thread( match result_mb { Ok(result) => { let now = SystemTime::now(); - records.push( Record { vector: Some(result), @@ -178,7 +187,7 @@ async fn vectorize_thread( file_path: data_res.file_path, start_line: data_res.start_line, end_line: data_res.end_line, - time_added: SystemTime::now(), + time_added: now, model_name: constants.model_name.clone(), distance: -1.0, used_counter: 0, @@ -188,6 +197,7 @@ async fn vectorize_thread( } Err(e) => { info!("Error retrieving embeddings for {}: {}", data_res.file_path.to_str().unwrap(), e); + queue.lock().await.push_back(data_res.file_path); // push it back again } } } @@ -206,6 +216,7 @@ async fn cleanup_thread(vecdb_handler: Arc>) { { let mut vecdb = vecdb_handler.lock().await; let _ = vecdb.cleanup_old_records().await; + let _ = vecdb.create_index().await; } tokio::time::sleep(tokio::time::Duration::from_secs(2 * 3600)).await; } @@ -237,8 +248,7 @@ impl FileVectorizerService { } } - pub async fn start_background_tasks(&self) -> Vec> - { + pub async fn start_background_tasks(&self, vecdb_client: Arc>) -> Vec> { let cooldown_queue_join_handle = tokio::spawn( cooldown_queue_thread( self.update_request_queue.clone(), @@ -250,6 +260,7 @@ impl FileVectorizerService { let retrieve_thread_handle = tokio::spawn( vectorize_thread( + vecdb_client.clone(), self.output_queue.clone(), self.vecdb_handler.clone(), self.status.clone(), @@ -294,7 +305,7 @@ impl FileVectorizerService { }; status.db_cache_size = match self.vecdb_handler.lock().await.cache_size().await { Ok(res) => res, - Err(err) => return Err(err) + Err(err) => return Err(err.to_string()) }; Ok(status) } From 28852dd193b8f643472d131b1cc8a0f528a801ba Mon Sep 17 00:00:00 2001 From: Valerii Date: Tue, 30 Jan 2024 16:31:46 +0000 Subject: [PATCH 29/54] @ commands (#65) major: added @ commands enhanced logic of update_indexed_file_paths changelist: * update_indexed_file_paths ran in a wrong place * changed proximity metric to jaro-winkler * added is_cmd_executable param * added command_completion_error macro * various added at-command-preview implemented file command preview big one comments, renames fatman deadlock fix implemented format requirements just cleaning... minor renames and TODOs minor improvements format fixes & minor * misc --- Cargo.toml | 1 + src/at_commands/at_commands.rs | 15 ++ src/at_commands/at_file.rs | 89 +++++++++ src/at_commands/at_params.rs | 62 +++++++ src/at_commands/at_workspace.rs | 94 ++++++++++ src/at_commands/mod.rs | 7 + src/at_commands/query.rs | 78 ++++++++ src/at_commands/structs.rs | 122 +++++++++++++ src/at_commands/utils.rs | 43 +++++ src/http/routers/v1.rs | 4 + src/http/routers/v1/at_commands.rs | 169 ++++++++++++++++++ src/http/routers/v1/chat.rs | 2 - src/lsp.rs | 4 +- src/main.rs | 1 + src/restream.rs | 9 +- src/scratchpad_abstract.rs | 3 +- src/scratchpads/chat_generic.rs | 26 +-- src/scratchpads/chat_llama2.rs | 29 ++- src/scratchpads/chat_passthrough.rs | 47 ++--- src/scratchpads/chat_utils_rag.rs | 160 ++++------------- src/scratchpads/completion_single_file_fim.rs | 3 +- src/scratchpads/mod.rs | 18 +- src/vecdb/handler.rs | 50 +++++- src/vecdb/vecdb.rs | 24 ++- src/vecdb/vectorizer_service.rs | 12 ++ 25 files changed, 868 insertions(+), 204 deletions(-) create mode 100644 src/at_commands/at_commands.rs create mode 100644 src/at_commands/at_file.rs create mode 100644 src/at_commands/at_params.rs create mode 100644 src/at_commands/at_workspace.rs create mode 100644 src/at_commands/mod.rs create mode 100644 src/at_commands/query.rs create mode 100644 src/at_commands/structs.rs create mode 100644 src/at_commands/utils.rs create mode 100644 src/http/routers/v1/at_commands.rs diff --git a/Cargo.toml b/Cargo.toml index ebd0cd7b8..bbf95d065 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,3 +59,4 @@ tokio-rusqlite = "0.5.0" vectordb = "=0.4.0" walkdir = "2.3" which = "5.0.0" +strsim = "0.8.0" diff --git a/src/at_commands/at_commands.rs b/src/at_commands/at_commands.rs new file mode 100644 index 000000000..fdd715fda --- /dev/null +++ b/src/at_commands/at_commands.rs @@ -0,0 +1,15 @@ +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex as AMutex; + +use crate::at_commands::at_workspace::AtWorkspace; +use crate::at_commands::at_file::AtFile; +use crate::at_commands::structs::AtCommandKind; + + +pub async fn at_commands_dict() -> HashMap>> { + return HashMap::from([ + ("@workspace".to_string(), Arc::new(AMutex::new(AtCommandKind::AtWorkspace(AtWorkspace::new())))), + ("@file".to_string(), Arc::new(AMutex::new(AtCommandKind::AtFile(AtFile::new())))), + ]); +} diff --git a/src/at_commands/at_file.rs b/src/at_commands/at_file.rs new file mode 100644 index 000000000..a3f8b841f --- /dev/null +++ b/src/at_commands/at_file.rs @@ -0,0 +1,89 @@ +use std::sync::Arc; +use async_trait::async_trait; +use serde_json::{json, Value}; +use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam, AtParamKind}; +use crate::at_commands::at_params::AtParamFilePath; +use tokio::sync::Mutex as AMutex; +use crate::at_commands::utils::compose_context_file_msg_from_result; +use crate::call_validation::{ChatMessage, ContextFile}; +use crate::vecdb::vecdb::FileSearchResult; + +pub struct AtFile { + pub name: String, + pub params: Vec>>, +} + +impl AtFile { + pub fn new() -> Self { + AtFile { + name: "@file".to_string(), + params: vec![ + Arc::new(AMutex::new(AtParamKind::AtParamFilePath(AtParamFilePath::new()))) + ], + } + } +} + +fn search2messages(result: &FileSearchResult) -> Vec { + // TODO: change to context_file, encode json including line1 line2 + vec![ChatMessage { + role: "user".to_string(), + content: format!("FILENAME:\n{}\nTEXT:\n{}\n", result.file_path, result.file_text) + }] +} + +fn search2json(result: &FileSearchResult) -> Value { + let cf = ContextFile { + file_name: result.file_path.clone().rsplit("/").next().unwrap_or(&result.file_path).to_string(), + file_content: result.file_text.clone(), + line1: 0, + line2: result.file_text.lines().count() as i32, + }; + compose_context_file_msg_from_result(&serde_json::to_value(&vec![cf]).unwrap_or(json!(null))) +} + +#[async_trait] +impl AtCommand for AtFile { + fn name(&self) -> &String { + &self.name + } + fn params(&self) -> &Vec>> { + &self.params + } + async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec { + let mut results = Vec::new(); + for (arg, param) in args.iter().zip(self.params.iter()) { + let param = param.lock().await; + results.push(param.is_value_valid(arg, context).await); + } + results + } + + async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool { + if self.are_args_valid(args, context).await.iter().any(|&x| x == false) || args.len() != self.params.len() { + return false; + } + return true; + } + + async fn execute(&self, _query: &String, args: &Vec, _top_n: usize, context: &AtCommandsContext) -> Result<(Vec, Value), String> { + let can_execute = self.can_execute(args, context).await; + match *context.global_context.read().await.vec_db.lock().await { + Some(ref db) => { + if !can_execute { + return Err("incorrect arguments".to_string()); + } + let file_path = match args.get(0) { + Some(x) => x, + None => return Err("no file path".to_string()), + }; + let search_result = db.get_file_orig_text(file_path.clone()).await; + Ok(( + search2messages(&search_result), + search2json(&search_result) + )) + } + None => Err("vecdb is not available".to_string()) + } + } +} diff --git a/src/at_commands/at_params.rs b/src/at_commands/at_params.rs new file mode 100644 index 000000000..d421cf305 --- /dev/null +++ b/src/at_commands/at_params.rs @@ -0,0 +1,62 @@ +use crate::at_commands::structs::{AtCommandsContext, AtParam}; +use async_trait::async_trait; +use itertools::Itertools; +use strsim::jaro_winkler; + +pub struct AtParamFilePath { + pub name: String, +} + +impl AtParamFilePath { + pub fn new() -> Self { + Self { + name: "file_path".to_string() + } + } +} + +#[async_trait] +impl AtParam for AtParamFilePath { + fn name(&self) -> &String { + &self.name + } + async fn is_value_valid(&self, value: &String, context: &AtCommandsContext) -> bool { + match *context.global_context.read().await.vec_db.lock().await { + Some(ref db) => { + let index_file_paths = db.get_indexed_file_paths().await; + let index_file_paths = index_file_paths.lock().await; + index_file_paths.iter().any(|path| path.to_str().unwrap() == value) + } + None => false, + } + } + async fn complete(&self, value: &String, context: &AtCommandsContext, top_n: usize) -> Vec { + match *context.global_context.read().await.vec_db.lock().await { + Some(ref db) => { + let index_file_paths = db.get_indexed_file_paths().await; + let index_file_paths = index_file_paths.lock().await; + let mapped_paths = index_file_paths.iter().map(|f| { + ( + f, + jaro_winkler( + if value.starts_with("/") { + f.to_str().unwrap() + } else { + f.file_name().unwrap().to_str().unwrap() + }, + &value.to_string() + ) + ) + }); + let sorted_paths = mapped_paths + .sorted_by(|(_, dist1), (_, dist2)| dist1.partial_cmp(dist2).unwrap()) + .rev() + .map(|(path, _)| path.to_str().unwrap().to_string()) + .take(top_n) + .collect::>(); + return sorted_paths; + } + None => vec![] + } + } +} diff --git a/src/at_commands/at_workspace.rs b/src/at_commands/at_workspace.rs new file mode 100644 index 000000000..924fb93de --- /dev/null +++ b/src/at_commands/at_workspace.rs @@ -0,0 +1,94 @@ +use std::sync::Arc; +use async_trait::async_trait; +use serde_json::{json, Value}; +use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam, AtParamKind}; +use tokio::sync::Mutex as AMutex; +use crate::at_commands::utils::compose_context_file_msg_from_result; +use crate::call_validation::{ChatMessage, ContextFile}; +use crate::vecdb::structs::{Record, VecdbSearch}; + + +pub struct AtWorkspace { + pub name: String, + pub params: Vec>>, +} + +impl AtWorkspace { + pub fn new() -> Self { + AtWorkspace { + name: "@workspace".to_string(), + params: vec![], + } + } +} + +fn record2chat_message(record: &Record) -> ChatMessage { + ChatMessage { + role: "context_file".to_string(), + content: json!(ContextFile { + file_name: record.file_path.to_str().unwrap().to_string(), //.rsplit('/').next().unwrap_or(&record.file_path.to_str().unwrap()).to_string(), + file_content: record.window_text.clone(), + line1: record.start_line as i32, + line2: record.end_line as i32, + }).to_string() + } +} + +fn search2messages(results: &Vec) -> Vec { + let mut messages = vec![]; + for r in results { + messages.push(record2chat_message(r)); + } + messages +} + +pub fn search2json( + results: &Vec +) -> Value { + let context_files: Vec = results + .iter() + .map(|x| { record2chat_message(x) }).collect(); + compose_context_file_msg_from_result(&serde_json::to_value(&context_files).unwrap_or(json!(null))) +} + +#[async_trait] +impl AtCommand for AtWorkspace { + fn name(&self) -> &String { + &self.name + } + fn params(&self) -> &Vec>> + { + &self.params + } + async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec { + let mut results = Vec::new(); + for (arg, param) in args.iter().zip(self.params.iter()) { + let param = param.lock().await; + results.push(param.is_value_valid(arg, context).await); + } + results + } + + async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool { + if self.are_args_valid(args, context).await.iter().any(|&x| x == false) || args.len() != self.params.len() { + return false; + } + return true; + } + + async fn execute(&self, query: &String, _args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result<(Vec, Value), String> { + match *context.global_context.read().await.vec_db.lock().await { + Some(ref db) => { + let search_result = db.search(query.clone(), top_n).await?; + let mut results = search_result.results.clone(); + results.dedup_by(|a, b| a.file_path == b.file_path && a.window_text == b.window_text); + + Ok(( + search2messages(&results), + search2json(&results) + )) + } + None => Err("vecdb is not available".to_string()) + } + } +} diff --git a/src/at_commands/mod.rs b/src/at_commands/mod.rs new file mode 100644 index 000000000..1c1c55818 --- /dev/null +++ b/src/at_commands/mod.rs @@ -0,0 +1,7 @@ +pub mod at_commands; +pub mod at_file; +pub mod at_workspace; +pub mod structs; +pub mod at_params; +pub mod utils; +pub mod query; diff --git a/src/at_commands/query.rs b/src/at_commands/query.rs new file mode 100644 index 000000000..150862dd2 --- /dev/null +++ b/src/at_commands/query.rs @@ -0,0 +1,78 @@ +#[derive(Clone)] +pub struct QueryLine { + pub value: String, + pub cursor_line_start: i64, + pub args: Vec, // the first argument is the command, the rest are arguments +} + +impl QueryLine { + pub fn new( + value: String, + cursor_rel: i64, + cursor_line_start: i64 + ) -> Self { + QueryLine { + value: value.clone(), + cursor_line_start, + args: parse_args_from_line(&value).iter_mut().map(|x| { + x.pos2 += 1; + x.focused = cursor_rel >= x.pos1 && cursor_rel <= x.pos2; + x.pos1 += cursor_line_start; + x.pos2 += cursor_line_start; + x.clone() + }).collect(), + } + } + + pub fn command(&self) -> Option<&QueryLineArg> { + self.args.first() + } + + pub fn get_args(&self) -> Vec<&QueryLineArg> { + self.args.iter().skip(1).collect() + } +} + +#[derive(Clone)] +pub struct QueryLineArg { + pub value: String, + pub pos1: i64, + pub pos2: i64, + pub focused: bool, + pub type_name: String, +} + +fn parse_args_from_line(line: &String) -> Vec { + let mut pos1: i64 = -1; + let mut value: String = "".to_string(); + let mut args: Vec = vec![]; + for (idx, ch) in line.chars().enumerate() { + let idx = idx as i64; + if value.is_empty() && ch.to_string() != " " { + pos1 = idx; + } + + if ch.to_string() != " " { + value.push(ch); + } + + if pos1 != -1 && (ch.to_string() == " " || idx == (line.len() -1) as i64) { + args.push(QueryLineArg{ + value: value.clone(), + pos1, + pos2: idx, + focused: false, + type_name: { + if value.starts_with("@") { + "command".to_string() + } else { + "arg".to_string() + } + } + }); + pos1 = -1; + value = "".to_string(); + } + } + args +} diff --git a/src/at_commands/structs.rs b/src/at_commands/structs.rs new file mode 100644 index 000000000..7f29566f5 --- /dev/null +++ b/src/at_commands/structs.rs @@ -0,0 +1,122 @@ +use std::collections::HashMap; +use std::sync::Arc; +use async_trait::async_trait; +use serde_json::Value; +use crate::global_context::GlobalContext; +use crate::at_commands::at_commands::at_commands_dict; +use tokio::sync::RwLock as ARwLock; +use tokio::sync::Mutex as AMutex; +use crate::at_commands::at_file::AtFile; +use crate::at_commands::at_params::AtParamFilePath; +use crate::at_commands::at_workspace::AtWorkspace; +use crate::call_validation::ChatMessage; + +pub struct AtCommandsContext { + pub global_context: Arc>, + pub at_commands: HashMap>>, +} + +impl AtCommandsContext { + pub async fn new(global_context: Arc>) -> Self { + AtCommandsContext { + global_context, + at_commands: at_commands_dict().await, + } + } +} + +#[async_trait] +pub trait AtCommand: Send + Sync { + fn name(&self) -> &String; + fn params(&self) -> &Vec>>; + async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec; + async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool; + async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result<(Vec, Value), String>; +} + +#[async_trait] +pub trait AtParam { + fn name(&self) -> &String; + async fn is_value_valid(&self, value: &String, context: &AtCommandsContext) -> bool; + async fn complete(&self, value: &String, context: &AtCommandsContext, top_n: usize) -> Vec; +} + +pub struct AtCommandCall { + pub command: Arc>, + pub args: Vec, +} + +impl AtCommandCall { + pub fn new(command: Arc>, args: Vec) -> Self { + AtCommandCall { + command, + args + } + } +} + +pub enum AtCommandKind { + AtWorkspace(AtWorkspace), + AtFile(AtFile), +} + +#[async_trait] +impl AtCommand for AtCommandKind { + fn name(&self) -> &String { + match self { + AtCommandKind::AtWorkspace(workspace) => workspace.name(), + AtCommandKind::AtFile(file) => file.name(), + } + } + + fn params(&self) -> &Vec>> { + match self { + AtCommandKind::AtWorkspace(workspace) => workspace.params(), + AtCommandKind::AtFile(file) => file.params(), + } + } + + async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec { + match self { + AtCommandKind::AtWorkspace(workspace) => workspace.are_args_valid(args, context).await, + AtCommandKind::AtFile(file) => file.are_args_valid(args, context).await, + } + } + async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool { + match self { + AtCommandKind::AtWorkspace(workspace) => workspace.can_execute(args, context).await, + AtCommandKind::AtFile(file) => file.can_execute(args, context).await, + } + } + + async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result<(Vec, Value), String> { + match self { + AtCommandKind::AtWorkspace(workspace) => workspace.execute(query, args, top_n, context).await, + AtCommandKind::AtFile(file) => file.execute(query, args, top_n, context).await, + } + } +} +pub enum AtParamKind { + AtParamFilePath(AtParamFilePath), +} + +#[async_trait] +impl AtParam for AtParamKind { + fn name(&self) -> &String { + match self { + AtParamKind::AtParamFilePath(param) => param.name(), + } + } + + async fn is_value_valid(&self, value: &String, context: &AtCommandsContext) -> bool { + match self { + AtParamKind::AtParamFilePath(param) => param.is_value_valid(value, context).await, + } + } + + async fn complete(&self, value: &String, context: &AtCommandsContext, top_n: usize) -> Vec { + match self { + AtParamKind::AtParamFilePath(param) => param.complete(value, context, top_n).await, + } + } +} diff --git a/src/at_commands/utils.rs b/src/at_commands/utils.rs new file mode 100644 index 000000000..e1818e790 --- /dev/null +++ b/src/at_commands/utils.rs @@ -0,0 +1,43 @@ +use serde_json::{json, Value}; +use crate::at_commands::structs::{AtCommand, AtCommandCall, AtCommandsContext}; + +pub fn compose_context_file_msg_from_result( + in_json: &Value, +) -> Value { + return json!({ + "choices": [{ + "delta": { + "content": in_json.clone(), + "role": "context_file" + }, + "finish_reason": null, + "index": 0 + }], + }); +} + +pub async fn find_valid_at_commands_in_query( + query: &String, + context: &AtCommandsContext, +) -> Vec { + let mut results = vec![]; + for line in query.lines() { + let line_words: Vec<&str> = line.split_whitespace().collect(); + let q_cmd_args = line_words.iter().skip(1).map(|x|x.to_string()).collect::>(); + + let q_cmd = match line_words.first() { + Some(x) => x, + None => continue, + }; + + let (_, cmd) = match context.at_commands.iter().find(|&(k, _v)| k == q_cmd) { + Some(x) => x, + None => continue, + }; + if !cmd.lock().await.can_execute(&q_cmd_args, context).await { + continue; + } + results.push(AtCommandCall::new(cmd.clone(), q_cmd_args.clone())); + } + results +} diff --git a/src/http/routers/v1.rs b/src/http/routers/v1.rs index 1d74cb60a..049ab5364 100644 --- a/src/http/routers/v1.rs +++ b/src/http/routers/v1.rs @@ -21,6 +21,7 @@ use crate::http::routers::v1::lsp_like_handlers::handle_v1_lsp_initialize; use crate::http::routers::v1::lsp_like_handlers::handle_v1_lsp_did_change; use crate::http::utils::telemetry_wrapper; use crate::http::routers::v1::vecdb::{handle_v1_vecdb_search, handle_v1_vecdb_status, handle_v1_vecdb_caps}; +use crate::http::routers::v1::at_commands::{handle_v1_command_completion, handle_v1_command_preview}; pub mod code_completion; pub mod chat; @@ -30,6 +31,7 @@ pub mod caps; pub mod graceful_shutdown; pub mod lsp_like_handlers; pub mod vecdb; +mod at_commands; pub fn make_v1_router() -> Router { Router::new() @@ -44,6 +46,8 @@ pub fn make_v1_router() -> Router { .route("/vdb-search", telemetry_post!(handle_v1_vecdb_search)) .route("/vdb-status", telemetry_get!(handle_v1_vecdb_status)) .route("/vdb-caps", telemetry_get!(handle_v1_vecdb_caps)) + .route("/at-command-completion", telemetry_post!(handle_v1_command_completion)) + .route("/at-command-preview", telemetry_post!(handle_v1_command_preview)) .route("/lsp-initialize", telemetry_post!(handle_v1_lsp_initialize)) .route("/lsp-did-changed", telemetry_post!(handle_v1_lsp_did_change)) diff --git a/src/http/routers/v1/at_commands.rs b/src/http/routers/v1/at_commands.rs new file mode 100644 index 000000000..1b95d9618 --- /dev/null +++ b/src/http/routers/v1/at_commands.rs @@ -0,0 +1,169 @@ +use axum::response::Result; +use axum::Extension; +use hyper::{Body, Response, StatusCode}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use itertools::Itertools; +use serde_json::{Value}; +use tokio::sync::RwLock as ARwLock; +use strsim::jaro_winkler; +use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam}; +use crate::at_commands::query::QueryLine; + +use crate::custom_error::ScratchError; +use crate::global_context::GlobalContext; + +#[derive(Serialize, Deserialize, Clone)] +struct CommandCompletionPost { + query: String, + cursor: i64, +} +#[derive(Serialize, Deserialize, Clone)] +struct CommandCompletionResponse { + completions: Vec, + replace: (i64, i64), + is_cmd_executable: bool, +} + +#[derive(Serialize, Deserialize, Clone)] +struct CommandPreviewPost { + query: String, +} + +#[derive(Serialize, Deserialize, Clone)] +struct CommandPreviewResponse { + messages: Vec, +} + +pub async fn handle_v1_command_completion( + Extension(global_context): Extension>>, + body_bytes: hyper::body::Bytes, +) -> Result, ScratchError> { + let context = AtCommandsContext::new(global_context.clone()).await; + let post = serde_json::from_slice::(&body_bytes) + .map_err(|e| ScratchError::new(StatusCode::UNPROCESSABLE_ENTITY, format!("JSON problem: {}", e)))?; + + let (query_line_val, cursor_rel, cursor_line_start) = get_line_with_cursor(&post.query, post.cursor.clone())?; + let query_line_val = query_line_val.chars().take(cursor_rel as usize).collect::(); + // info!(query_line_val); + let query_line = QueryLine::new(query_line_val, cursor_rel, cursor_line_start); + // for arg in query_line.args.iter() { + // info!("value: {}, focused: {}, type_name: {}; pos1: {}; pos2: {}", arg.value, arg.focused, arg.type_name, arg.pos1, arg.pos2); + // } + let (completions, is_cmd_executable, pos1, pos2) = command_completion(&query_line, &context).await?; + + let response = CommandCompletionResponse { + completions: completions.clone(), + replace: (pos1, pos2), + is_cmd_executable, + }; + + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(serde_json::to_string(&response).unwrap())) + .unwrap()) +} + +pub async fn handle_v1_command_preview( + Extension(global_context): Extension>>, + body_bytes: hyper::body::Bytes, +) -> Result, ScratchError> { + let context = AtCommandsContext::new(global_context.clone()).await; + let post = serde_json::from_slice::(&body_bytes) + .map_err(|e| ScratchError::new(StatusCode::UNPROCESSABLE_ENTITY, format!("JSON problem: {}", e)))?; + let valid_commands = crate::at_commands::utils::find_valid_at_commands_in_query(&post.query, &context).await; + if valid_commands.is_empty() { + return Err(ScratchError::new(StatusCode::OK, "no valid commands in query".to_string())); + } + + let mut preview_msgs = vec![]; + for cmd in valid_commands { + match cmd.command.lock().await.execute(&post.query, &cmd.args, 5, &context).await { + Ok((_, in_json)) => { + preview_msgs.push(in_json); + }, + Err(_) => {} + } + } + + let response = CommandPreviewResponse { + messages: preview_msgs, + }; + + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(serde_json::to_string(&response).unwrap())) + .unwrap()) +} + +fn get_line_with_cursor(query: &String, mut cursor: i64) -> Result<(String, i64, i64), ScratchError> { + let mut cursor_rel = cursor; + for line in query.lines() { + let line_length = line.len() as i64; + if cursor_rel <= line_length { + if !line.starts_with("@") { + return Err(ScratchError::new(StatusCode::OK, "no command provided".to_string())); + } + return Ok((line.to_string(), cursor_rel, cursor - cursor_rel)); + } + cursor_rel -= line_length + 1; // +1 to account for the newline character + } + return Err(ScratchError::new(StatusCode::OK, "cursor is incorrect".to_string())); +} + +async fn command_completion( + query_line: &QueryLine, + context: &AtCommandsContext, +) -> Result<(Vec, bool, i64, i64), ScratchError> { // returns ([possible, completions], good_as_it_is) + let q_cmd = match query_line.command() { + Some(x) => x, + None => { return Err(ScratchError::new(StatusCode::OK, "no command given".to_string()));} + }; + + let (_, cmd) = match context.at_commands.iter().find(|&(k, _v)| k == &q_cmd.value) { + Some(x) => x, + None => { + if !q_cmd.focused { + return Err(ScratchError::new(StatusCode::OK, "incorrect command given".to_string())); + } + return Ok((command_completion_options(&q_cmd.value, &context, 5).await, false, q_cmd.pos1, q_cmd.pos2)); + } + }; + if cmd.lock().await.can_execute(&query_line.get_args().iter().map(|x|x.value.clone()).collect(), context).await { + return Ok((vec![], true, -1, -1)); + } + + for (arg, param) in query_line.get_args().iter().zip(cmd.lock().await.params()) { + let is_valid = param.lock().await.is_value_valid(&arg.value, context).await; + if !is_valid { + return if arg.focused { + Ok((param.lock().await.complete(&arg.value, context, 5).await, false, arg.pos1, arg.pos2)) + } else { + Err(ScratchError::new(StatusCode::OK, "invalid parameter".to_string())) + } + } + + } + + return Ok((vec![], false, -1, -1)); +} + + +async fn command_completion_options( + q_cmd: &String, + context: &AtCommandsContext, + top_n: usize, +) -> Vec { + let at_commands_names = context.at_commands.iter().map(|(name, _cmd)| name.clone()).collect::>(); + at_commands_names + .iter() + .filter(|command| command.starts_with(q_cmd)) + .map(|command| { + (command, jaro_winkler(&command, q_cmd)) + }) + .sorted_by(|(_, dist1), (_, dist2)| dist1.partial_cmp(dist2).unwrap()) + .rev() + .take(top_n) + .map(|(command, _)| command.clone()) + .collect() +} diff --git a/src/http/routers/v1/chat.rs b/src/http/routers/v1/chat.rs index ad777dc7f..d20f73e71 100644 --- a/src/http/routers/v1/chat.rs +++ b/src/http/routers/v1/chat.rs @@ -55,7 +55,6 @@ pub async fn handle_v1_chat( let cx_locked = global_context.write().await; (cx_locked.http_client.clone(), cx_locked.cmdline.api_key.clone()) }; - let vecdb_search = global_context.read().await.vec_db.clone(); let mut scratchpad = scratchpads::create_chat_scratchpad( global_context.clone(), caps, @@ -63,7 +62,6 @@ pub async fn handle_v1_chat( chat_post.clone(), &scratchpad_name, &scratchpad_patch, - vecdb_search, ).await.map_err(|e| ScratchError::new(StatusCode::BAD_REQUEST, e) )?; diff --git a/src/lsp.rs b/src/lsp.rs index d290d65ae..303557233 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -202,7 +202,7 @@ impl LanguageServer for Backend { if let Some(folders) = params.workspace_folders { match *self.gcx.read().await.vec_db.lock().await { - Some(ref db) => db.init_folders(folders).await, + Some(ref mut db) => db.init_folders(folders).await, None => {}, }; } @@ -326,7 +326,7 @@ impl LanguageServer for Backend { .filter(|x| is_valid_file(&x)) .collect(); match *self.gcx.read().await.vec_db.lock().await { - Some(ref mut db) => db.add_or_update_files(files, false).await, + Some(ref db) => db.add_or_update_files(files, false).await, None => {} }; } diff --git a/src/main.rs b/src/main.rs index d3743a3df..7f7b14212 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,6 +25,7 @@ mod background_tasks; mod receive_workspace_changes; mod vecdb; mod fetch_embedding; +mod at_commands; #[tokio::main] diff --git a/src/restream.rs b/src/restream.rs index 3c786cf7e..907657b40 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -156,12 +156,11 @@ pub async fn scratchpad_interaction_stream( loop { let value_maybe = scratch.response_spontaneous(); if let Ok(value) = value_maybe { - if value == json!(null) { - break; + for el in value { + let value_str = format!("data: {}\n\n", serde_json::to_string(&el).unwrap()); + info!("yield: {:?}", value_str); + yield Result::<_, String>::Ok(value_str); } - let value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap()); - info!("yield: {:?}", value_str); - yield Result::<_, String>::Ok(value_str); } else { let err_str = value_maybe.unwrap_err(); error!("response_spontaneous error: {}", err_str); diff --git a/src/scratchpad_abstract.rs b/src/scratchpad_abstract.rs index 9d2ce9e08..09a4c4df7 100644 --- a/src/scratchpad_abstract.rs +++ b/src/scratchpad_abstract.rs @@ -4,6 +4,7 @@ use std::sync::RwLock; use tokenizers::Tokenizer; use crate::call_validation::SamplingParameters; use async_trait::async_trait; +use serde_json::Value; #[async_trait] @@ -32,7 +33,7 @@ pub trait ScratchpadAbstract: Send { stop_length: bool, ) -> Result<(serde_json::Value, bool), String>; - fn response_spontaneous(&mut self) -> Result; + fn response_spontaneous(&mut self) -> Result, String> ; } diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index 13c8109c5..464050940 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -2,22 +2,25 @@ use std::sync::Arc; use std::sync::RwLock; use async_trait::async_trait; +use serde_json::Value; use tokenizers::Tokenizer; use tokio::sync::Mutex as AMutex; +use tokio::sync::RwLock as ARwLock; use tracing::info; use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; +use crate::global_context::GlobalContext; use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdb, HasVecdbResults}; +use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdbResults}; use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; -pub struct GenericChatScratchpad { +pub struct GenericChatScratchpad { pub t: HasTokenizerAndEot, pub dd: DeltaDeltaChatStreamer, pub post: ChatPost, @@ -26,16 +29,16 @@ pub struct GenericChatScratchpad { pub keyword_user: String, pub keyword_asst: String, pub default_system_message: String, - pub vecdb_search: Arc>>, pub has_vecdb_results: HasVecdbResults, + pub global_context: Arc>, } -impl GenericChatScratchpad { +impl GenericChatScratchpad { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: Arc>>, - ) -> Self where T: VecdbSearch + 'static + Sync { + global_context: Arc>, + ) -> Self { GenericChatScratchpad { t: HasTokenizerAndEot::new(tokenizer), dd: DeltaDeltaChatStreamer::new(), @@ -45,14 +48,14 @@ impl GenericChatScratchpad { keyword_user: "".to_string(), keyword_asst: "".to_string(), default_system_message: "".to_string(), - vecdb_search, has_vecdb_results: HasVecdbResults::new(), + global_context, } } } #[async_trait] -impl ScratchpadAbstract for GenericChatScratchpad { +impl ScratchpadAbstract for GenericChatScratchpad { fn apply_model_adaptation_patch( &mut self, patch: &serde_json::Value, @@ -86,10 +89,7 @@ impl ScratchpadAbstract for GenericChatScratchpad< context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - match *self.vecdb_search.lock().await { - Some(ref db) => chat_functions_middleware(db, &mut self.post, 6, &mut self.has_vecdb_results).await, - None => {} - } + chat_functions_middleware(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; let limited_msgs: Vec = limit_messages_history(&self.t, &self.post.messages, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); @@ -152,7 +152,7 @@ impl ScratchpadAbstract for GenericChatScratchpad< self.dd.response_streaming(delta, stop_toks) } - fn response_spontaneous(&mut self) -> Result { + fn response_spontaneous(&mut self) -> Result, String> { return self.has_vecdb_results.response_streaming(); } } diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index 5f9f5597a..bcb51e002 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -2,40 +2,41 @@ use std::sync::Arc; use std::sync::RwLock as StdRwLock; use async_trait::async_trait; +use serde_json::Value; use tokenizers::Tokenizer; -use tokio::sync::Mutex as AMutex; +use tokio::sync::RwLock as ARwLock; use tracing::info; use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; +use crate::global_context::GlobalContext; use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdb, HasVecdbResults}; -use crate::vecdb::structs::VecdbSearch; +use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdbResults}; const DEBUG: bool = true; // #[derive(Debug)] -pub struct ChatLlama2 { +pub struct ChatLlama2 { pub t: HasTokenizerAndEot, pub dd: DeltaDeltaChatStreamer, pub post: ChatPost, pub keyword_s: String, // "SYSTEM:" keyword means it's not one token pub keyword_slash_s: String, pub default_system_message: String, - pub vecdb_search: Arc>>, pub has_vecdb_results: HasVecdbResults, + pub global_context: Arc>, } -impl ChatLlama2 { +impl ChatLlama2 { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: Arc>>, - ) -> Self where T: VecdbSearch + Send { + global_context: Arc>, + ) -> Self { ChatLlama2 { t: HasTokenizerAndEot::new(tokenizer), dd: DeltaDeltaChatStreamer::new(), @@ -43,14 +44,14 @@ impl ChatLlama2 { keyword_s: "".to_string(), keyword_slash_s: "".to_string(), default_system_message: "".to_string(), - vecdb_search, has_vecdb_results: HasVecdbResults::new(), + global_context, } } } #[async_trait] -impl ScratchpadAbstract for ChatLlama2 { +impl ScratchpadAbstract for ChatLlama2 { fn apply_model_adaptation_patch( &mut self, patch: &serde_json::Value, @@ -72,10 +73,8 @@ impl ScratchpadAbstract for ChatLlama2 { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - match *self.vecdb_search.lock().await { - Some(ref db) => chat_functions_middleware(db, &mut self.post, 6, &mut self.has_vecdb_results).await, - None => {} - } + chat_functions_middleware(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; + let limited_msgs: Vec = limit_messages_history(&self.t, &self.post.messages, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // loosely adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 @@ -140,7 +139,7 @@ impl ScratchpadAbstract for ChatLlama2 { self.dd.response_streaming(delta, stop_toks) } - fn response_spontaneous(&mut self) -> Result { + fn response_spontaneous(&mut self) -> Result, String> { return self.has_vecdb_results.response_streaming(); } } diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index 0d70ec230..d10c75d79 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -1,46 +1,49 @@ use std::sync::Arc; use async_trait::async_trait; +use serde_json::Value; use tokio::sync::Mutex as AMutex; -use tracing::info; +use tracing::{error, info}; +use tokio::sync::RwLock as ARwLock; use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; +use crate::global_context::GlobalContext; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes; -use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdb, HasVecdbResults}; +use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdbResults}; use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; // #[derive(Debug)] -pub struct ChatPassthrough { +pub struct ChatPassthrough { pub post: ChatPost, pub default_system_message: String, pub limit_bytes: usize, - pub vecdb_search: Arc>>, pub has_vecdb_results: HasVecdbResults, + pub global_context: Arc>, } const DEFAULT_LIMIT_BYTES: usize = 4096*6; -impl ChatPassthrough { +impl ChatPassthrough { pub fn new( post: ChatPost, - vecdb_search: Arc>>, - ) -> Self where T: VecdbSearch + 'static + Sync { + global_context: Arc>, + ) -> Self { ChatPassthrough { post, default_system_message: "".to_string(), limit_bytes: DEFAULT_LIMIT_BYTES, // one token translates to 3 bytes (not unicode chars) - vecdb_search, has_vecdb_results: HasVecdbResults::new(), + global_context, } } } #[async_trait] -impl ScratchpadAbstract for ChatPassthrough { +impl ScratchpadAbstract for ChatPassthrough { fn apply_model_adaptation_patch( &mut self, patch: &serde_json::Value, @@ -55,10 +58,8 @@ impl ScratchpadAbstract for ChatPassthrough { _context_size: usize, _sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - match *self.vecdb_search.lock().await { - Some(ref db) => chat_functions_middleware(db, &mut self.post, 6, &mut self.has_vecdb_results).await, - None => {} - } + chat_functions_middleware(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; + let limited_msgs: Vec = limit_messages_history_in_bytes(&self.post.messages, self.limit_bytes, &self.default_system_message)?; info!("chat passthrough {} messages -> {} messages after applying limits and possibly adding the default system message", &limited_msgs.len(), &limited_msgs.len()); let mut filtered_msgs: Vec = Vec::::new(); @@ -66,12 +67,18 @@ impl ScratchpadAbstract for ChatPassthrough { if msg.role == "assistant" || msg.role == "system" || msg.role == "user" { filtered_msgs.push(msg.clone()); } else if msg.role == "context_file" { - let vector_of_context_files: Vec = serde_json::from_str(&msg.content).unwrap(); // FIXME unwrap - for context_file in &vector_of_context_files { - filtered_msgs.push(ChatMessage { - role: "user".to_string(), - content: format!("{}\n```\n{}```", context_file.file_name, context_file.file_content), - }); + info!(msg.content); + match serde_json::from_str(&msg.content) { + Ok(res) => { + let vector_of_context_files: Vec = res; + for context_file in &vector_of_context_files { + filtered_msgs.push(ChatMessage { + role: "user".to_string(), + content: format!("{}\n```\n{}```", context_file.file_name, context_file.file_content), + }); + } + }, + Err(e) => { error!("error parsing context file: {}", e); } } } } @@ -125,7 +132,7 @@ impl ScratchpadAbstract for ChatPassthrough { }); Ok((ans, finished)) } - fn response_spontaneous(&mut self) -> Result { + fn response_spontaneous(&mut self) -> Result, String> { return self.has_vecdb_results.response_streaming(); } } diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs index 621eb1b5c..d2d1b417a 100644 --- a/src/scratchpads/chat_utils_rag.rs +++ b/src/scratchpads/chat_utils_rag.rs @@ -1,156 +1,58 @@ -use async_trait::async_trait; -use serde_json::json; -use tracing::info; +use std::sync::Arc; +use serde_json::Value; +use tokio::sync::RwLock as ARwLock; +use crate::at_commands::structs::{AtCommand, AtCommandsContext}; -use crate::call_validation::{ChatMessage, ChatPost, ContextFile}; -use crate::vecdb::structs::{SearchResult, VecdbSearch}; +use crate::call_validation::ChatPost; +use crate::global_context::GlobalContext; -pub async fn chat_functions_middleware( - vecdb: &T, - post: &mut ChatPost, - limit_examples_cnt: usize, - has_vecdb: &mut dyn HasVecdb, -) where T: VecdbSearch { - let latest_msg_cont = &post.messages.last().unwrap().content; - if latest_msg_cont.starts_with("@workspace") { - embed_vecdb_results(vecdb, post, limit_examples_cnt, has_vecdb).await; - } -} - -async fn embed_vecdb_results( - vecdb: &T, +pub async fn chat_functions_middleware( + global_context: Arc>, post: &mut ChatPost, - limit_examples_cnt: usize, - has_vecdb: &mut dyn HasVecdb, -) where T: VecdbSearch { - let latest_msg_cont = &post.messages.last().unwrap().content; - let vdb_resp = vecdb.search(latest_msg_cont.clone(), limit_examples_cnt).await; - - has_vecdb.add2messages( - vdb_resp, - &mut post.messages, - ).await; -} - -fn vecdb_resp_to_json( - resp: &Result -) -> serde_json::Result { - let mut context_files: Vec = match resp { - Ok(search_res) => { - search_res.results.iter().map(|x| ContextFile { - file_name: x.file_path.to_str().unwrap().to_string(), - file_content: x.window_text.clone(), - line1: x.start_line as i32, - line2: x.end_line as i32, - }).collect() + top_n: usize, + has_vecdb: &mut HasVecdbResults, +) { + let context = AtCommandsContext::new(global_context.clone()).await; + let query = &post.messages.last().unwrap().content.clone(); // latest_msg_cont + let valid_commands = crate::at_commands::utils::find_valid_at_commands_in_query(&query, &context).await; + + for cmd in valid_commands { + match cmd.command.lock().await.execute(query, &cmd.args, top_n, &context).await { + Ok((msgs, in_json)) => { + post.messages.extend(msgs); + has_vecdb.push_in_json(in_json); + }, + Err(_) => {} } - Err(_) => vec![], - }; - - context_files.dedup_by(|a, b| { - a.file_name == b.file_name && a.file_content == b.file_content - }); - - context_files.iter_mut().for_each(|file| { - file.file_name = file.file_name - .rsplit('/') - .next() - .unwrap_or(&file.file_name) - .to_string(); - }); - - serde_json::to_value(&context_files) -} - -fn vecdb_resp_to_prompt( - resp_mb: &Result -) -> String { - let mut cont = "".to_string(); - - if resp_mb.is_err() { - info!("VECDB ERR"); - return cont - } - let resp = resp_mb.as_ref().unwrap(); - let mut results = resp.results.clone(); - results.dedup_by(|a, b| a.file_path == b.file_path && a.window_text == b.window_text); - - cont.push_str("CONTEXT:\n"); - for res in results.iter() { - cont.push_str("FILENAME:\n"); - cont.push_str(res.file_path.clone().to_str().unwrap_or_else( || "")); - cont.push_str("\nTEXT:"); - cont.push_str(res.window_text.clone().as_str()); - cont.push_str("\n"); } - cont.push_str("\nRefer to the context to answer my next question.\n"); - info!("VECDB prompt:\n{}", cont); - cont } - pub struct HasVecdbResults { pub was_sent: bool, - pub in_json: serde_json::Value, + pub in_json: Vec, } impl HasVecdbResults { pub fn new() -> Self { HasVecdbResults { was_sent: false, - in_json: json!(null) + in_json: vec![], } } } -#[async_trait] -pub trait HasVecdb: Send { - async fn add2messages( - &mut self, - vdb_result_mb: Result, - messages: &mut Vec, - ); - fn response_streaming(&mut self) -> Result; -} - -#[async_trait] -impl HasVecdb for HasVecdbResults { - async fn add2messages( - &mut self, - result_mb: Result, - messages: &mut Vec, - ) { - // if messages.len() > 1 { - // return; - // } - *messages = [ - &messages[..messages.len() -1], - &[ChatMessage { - role: "user".to_string(), - content: vecdb_resp_to_prompt(&result_mb), - }], - &messages[messages.len() -1..], - ].concat(); - - self.in_json = vecdb_resp_to_json(&result_mb).unwrap_or_else(|_| json!(null)); +impl HasVecdbResults { + pub fn push_in_json(&mut self, value: Value) { + self.in_json.push(value); } - fn response_streaming(&mut self) -> Result { - if self.was_sent == true || self.in_json.is_null() { - return Ok(json!(null)); + pub fn response_streaming(&mut self) -> Result, String> { + if self.was_sent == true || self.in_json.is_empty() { + return Ok(vec![]); } self.was_sent = true; - return Ok(json!({ - "choices": [{ - "delta": { - "content": self.in_json.clone(), - "role": "context_file" - }, - "finish_reason": serde_json::Value::Null, - "index": 0 - }], - })); + Ok(self.in_json.clone()) } } diff --git a/src/scratchpads/completion_single_file_fim.rs b/src/scratchpads/completion_single_file_fim.rs index 7ba20c21c..7c45cd715 100644 --- a/src/scratchpads/completion_single_file_fim.rs +++ b/src/scratchpads/completion_single_file_fim.rs @@ -11,6 +11,7 @@ use tokenizers::Tokenizer; use ropey::Rope; use tracing::info; use async_trait::async_trait; +use serde_json::Value; use crate::completion_cache; use crate::telemetry::telemetry_structs; @@ -285,7 +286,7 @@ impl ScratchpadAbstract for SingleFileFIM { }); Ok((ans, finished)) } - fn response_spontaneous(&mut self) -> Result { + fn response_spontaneous(&mut self) -> Result, String> { return Err("".to_string()); } } diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index a5498623c..f379efccb 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -10,7 +10,7 @@ pub mod chat_llama2; pub mod chat_passthrough; pub mod chat_utils_deltadelta; pub mod chat_utils_limit_history; -mod chat_utils_rag; +pub mod chat_utils_rag; use crate::call_validation::CodeCompletionPost; use crate::call_validation::ChatPost; @@ -52,25 +52,23 @@ pub async fn create_code_completion_scratchpad( Ok(result) } -pub async fn create_chat_scratchpad( +pub async fn create_chat_scratchpad( global_context: Arc>, caps: Arc>, model_name_for_tokenizer: String, post: ChatPost, scratchpad_name: &str, scratchpad_patch: &serde_json::Value, - vecdb_search: Arc>>, -) -> Result, String> - where T: VecdbSearch + 'static + Sync { +) -> Result, String> { let mut result: Box; if scratchpad_name == "CHAT-GENERIC" { - let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, global_context, model_name_for_tokenizer).await?; - result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post, vecdb_search)); + let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, global_context.clone(), model_name_for_tokenizer).await?; + result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post, global_context.clone())); } else if scratchpad_name == "CHAT-LLAMA2" { - let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, global_context, model_name_for_tokenizer).await?; - result = Box::new(chat_llama2::ChatLlama2::new(tokenizer_arc, post, vecdb_search)); + let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, global_context.clone(), model_name_for_tokenizer).await?; + result = Box::new(chat_llama2::ChatLlama2::new(tokenizer_arc, post, global_context.clone())); } else if scratchpad_name == "PASSTHROUGH" { - result = Box::new(chat_passthrough::ChatPassthrough::new(post, vecdb_search)); + result = Box::new(chat_passthrough::ChatPassthrough::new(post, global_context.clone())); } else { return Err(format!("This rust binary doesn't have chat scratchpad \"{}\" compiled in", scratchpad_name)); } diff --git a/src/vecdb/handler.rs b/src/vecdb/handler.rs index c278c787f..7ea838236 100644 --- a/src/vecdb/handler.rs +++ b/src/vecdb/handler.rs @@ -6,7 +6,7 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use std::time::SystemTime; - +use tokio::sync::Mutex as AMutex; use arrow::array::ArrayData; use arrow::buffer::Buffer; use arrow::compute::concat_batches; @@ -24,7 +24,6 @@ use log::info; use rusqlite::{OpenFlags, params, Result}; use tempfile::{tempdir, TempDir}; use tokio::fs; -use tokio::sync::Mutex as AMutex; use tokio_rusqlite::Connection; use tracing::error; use vectordb::database::Database; @@ -47,6 +46,7 @@ pub struct VecDBHandler { schema: SchemaRef, data_table_hashes: HashSet, embedding_size: i32, + indexed_file_paths: Arc>>, } fn cosine_similarity(vec1: &Vec, vec2: &Vec) -> f32 { @@ -156,6 +156,7 @@ impl VecDBHandler { data_table, data_table_hashes: HashSet::new(), embedding_size, + indexed_file_paths: Arc::new(AMutex::new(vec![])), }) } @@ -369,6 +370,51 @@ impl VecDBHandler { }) } + pub async fn select_all_file_paths(&self) -> Vec { + let mut file_paths: HashSet = HashSet::new(); + let records: Vec = self.data_table + .filter(format!("file_path in (select file_path from data)")) + .execute() + .await.unwrap() + .try_collect::>() + .await.unwrap(); + + for rec_batch in records { + for record in VecDBHandler::parse_table_iter(rec_batch, false, None).unwrap() { + file_paths.insert(record.file_path.clone()); + } + } + return file_paths.into_iter().collect(); + } + + pub async fn get_file_orig_text(&mut self, file_path: String) -> String{ + let batches: Vec = self.data_table + .filter(format!("file_path == '{}'", file_path)) + .execute() + .await.unwrap() + .try_collect::>() + .await.unwrap(); + + let mut records = vec![]; + for rec_batch in batches { + for record in VecDBHandler::parse_table_iter(rec_batch, false, None).unwrap() { + records.push((record.start_line, record.end_line, record.window_text)); + } + } + records.sort_by(|a, b| a.1.cmp(&b.1)); + let text: String = records.into_iter().map(|rec| rec.2).collect::>().join("\n"); + text + } + + pub async fn update_indexed_file_paths(&mut self) { + let res = self.select_all_file_paths().await; + self.indexed_file_paths = Arc::new(AMutex::new(res)); + } + + pub async fn get_indexed_file_paths(&self) -> Arc>> { + return self.indexed_file_paths.clone(); + } + pub async fn try_add_from_cache(&mut self, data: Vec) -> Vec { if data.is_empty() { return vec![]; diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 2b47fdded..cb171cc22 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -42,6 +42,11 @@ pub struct VecDb { constants: VecdbConstants, } +#[derive(Debug, Serialize, Clone)] +pub struct FileSearchResult { + pub file_path: String, + pub file_text: String, +} #[derive(Debug, Serialize)] pub struct VecDbCaps { @@ -158,7 +163,6 @@ async fn do_i_need_to_reload_vecdb( return (true, Some(consts)); } - pub async fn vecdb_background_reload( global_context: Arc>, ) { @@ -218,7 +222,7 @@ impl VecDb { return self.retriever_service.lock().await.start_background_tasks(self.vecdb_emb_client.clone()).await; } - pub async fn add_or_update_file(&mut self, file_path: PathBuf, force: bool) { + pub async fn add_or_update_file(&self, file_path: PathBuf, force: bool) { self.retriever_service.lock().await.process_file(file_path, force).await; } @@ -239,7 +243,16 @@ impl VecDb { folders.iter().map(|x| PathBuf::from(x.uri.path())).collect() ).await; self.add_or_update_files(files, true).await; - info!("vecdb: init_folders complete"); + info!("init_folders complete"); + } + + pub async fn get_indexed_file_paths(&self) -> Arc>> { + return self.vecdb_handler.lock().await.get_indexed_file_paths().await; + } + + pub async fn get_file_orig_text(&self, file_path: String) -> FileSearchResult { + let text = self.vecdb_handler.lock().await.get_file_orig_text(file_path.clone()).await; + FileSearchResult { file_path, file_text: text } } pub async fn caps(&self) -> VecDbCaps { @@ -270,7 +283,10 @@ impl VecdbSearch for VecDb { let mut handler_locked = self.vecdb_handler.lock().await; let t1 = std::time::Instant::now(); - let results = handler_locked.search(embedding_mb.unwrap(), top_n).await.unwrap(); + let results = match handler_locked.search(embedding_mb.unwrap(), top_n).await { + Ok(res) => res, + Err(_) => {return Err("error during search occurred".to_string()) }, + }; info!("search itself {:.3}s", t1.elapsed().as_secs_f64()); for rec in results.iter() { let last_30_chars: String = rec.file_path.display().to_string().chars().rev().take(30).collect::().chars().rev().collect(); diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index 16ef87db9..0aefb2b43 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -106,6 +106,10 @@ async fn vectorize_thread( None => { // No files left to process if !reported_vecdb_complete { + let t0 = std::time::Instant::now(); + vecdb_handler_ref.lock().await.update_indexed_file_paths().await; + info!("update_indexed_file_paths: it took {:.3}s", t0.elapsed().as_secs_f64()); + reported_vecdb_complete = true; info!("VECDB Creating index"); match vecdb_handler_ref.lock().await.create_index().await { @@ -309,4 +313,12 @@ impl FileVectorizerService { }; Ok(status) } + + pub async fn get_all_file_paths(&self) -> Arc>> { + return self.vecdb_handler.lock().await.get_indexed_file_paths().await; + } + + pub async fn get_file_orig_text(&self, file_path: String) -> String { + return self.vecdb_handler.lock().await.get_file_orig_text(file_path).await; + } } From 3f806bc5b1e62cb4a492ab4210dc350ded4b9cae Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Wed, 31 Jan 2024 10:11:42 +0100 Subject: [PATCH 30/54] logs --- src/at_commands/at_file.rs | 1 + src/at_commands/at_params.rs | 1 + src/at_commands/at_workspace.rs | 1 + src/at_commands/structs.rs | 4 ++++ src/at_commands/utils.rs | 4 ++++ src/http/routers/v1/at_commands.rs | 2 +- src/scratchpads/chat_generic.rs | 2 -- src/scratchpads/chat_passthrough.rs | 2 -- 8 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/at_commands/at_file.rs b/src/at_commands/at_file.rs index a3f8b841f..4e3568fb9 100644 --- a/src/at_commands/at_file.rs +++ b/src/at_commands/at_file.rs @@ -8,6 +8,7 @@ use crate::at_commands::utils::compose_context_file_msg_from_result; use crate::call_validation::{ChatMessage, ContextFile}; use crate::vecdb::vecdb::FileSearchResult; +#[derive(Debug)] pub struct AtFile { pub name: String, pub params: Vec>>, diff --git a/src/at_commands/at_params.rs b/src/at_commands/at_params.rs index d421cf305..f731ade9c 100644 --- a/src/at_commands/at_params.rs +++ b/src/at_commands/at_params.rs @@ -3,6 +3,7 @@ use async_trait::async_trait; use itertools::Itertools; use strsim::jaro_winkler; +#[derive(Debug)] pub struct AtParamFilePath { pub name: String, } diff --git a/src/at_commands/at_workspace.rs b/src/at_commands/at_workspace.rs index 924fb93de..22b383a0a 100644 --- a/src/at_commands/at_workspace.rs +++ b/src/at_commands/at_workspace.rs @@ -8,6 +8,7 @@ use crate::call_validation::{ChatMessage, ContextFile}; use crate::vecdb::structs::{Record, VecdbSearch}; +#[derive(Debug)] pub struct AtWorkspace { pub name: String, pub params: Vec>>, diff --git a/src/at_commands/structs.rs b/src/at_commands/structs.rs index 7f29566f5..43c936705 100644 --- a/src/at_commands/structs.rs +++ b/src/at_commands/structs.rs @@ -41,6 +41,7 @@ pub trait AtParam { async fn complete(&self, value: &String, context: &AtCommandsContext, top_n: usize) -> Vec; } +#[derive(Debug)] pub struct AtCommandCall { pub command: Arc>, pub args: Vec, @@ -55,6 +56,7 @@ impl AtCommandCall { } } +#[derive(Debug)] pub enum AtCommandKind { AtWorkspace(AtWorkspace), AtFile(AtFile), @@ -96,6 +98,8 @@ impl AtCommand for AtCommandKind { } } } + +#[derive(Debug)] pub enum AtParamKind { AtParamFilePath(AtParamFilePath), } diff --git a/src/at_commands/utils.rs b/src/at_commands/utils.rs index e1818e790..dc4205776 100644 --- a/src/at_commands/utils.rs +++ b/src/at_commands/utils.rs @@ -1,5 +1,7 @@ use serde_json::{json, Value}; use crate::at_commands::structs::{AtCommand, AtCommandCall, AtCommandsContext}; +use tracing::info; + pub fn compose_context_file_msg_from_result( in_json: &Value, @@ -35,8 +37,10 @@ pub async fn find_valid_at_commands_in_query( None => continue, }; if !cmd.lock().await.can_execute(&q_cmd_args, context).await { + info!("command {:?} is not executable with arguments {:?}", q_cmd, q_cmd_args); continue; } + info!("command {:?} is perfectly good", q_cmd); results.push(AtCommandCall::new(cmd.clone(), q_cmd_args.clone())); } results diff --git a/src/http/routers/v1/at_commands.rs b/src/http/routers/v1/at_commands.rs index 1b95d9618..a8bf9925a 100644 --- a/src/http/routers/v1/at_commands.rs +++ b/src/http/routers/v1/at_commands.rs @@ -96,7 +96,7 @@ pub async fn handle_v1_command_preview( .unwrap()) } -fn get_line_with_cursor(query: &String, mut cursor: i64) -> Result<(String, i64, i64), ScratchError> { +fn get_line_with_cursor(query: &String, cursor: i64) -> Result<(String, i64, i64), ScratchError> { let mut cursor_rel = cursor; for line in query.lines() { let line_length = line.len() as i64; diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index 464050940..6b1d68ea6 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -4,7 +4,6 @@ use std::sync::RwLock; use async_trait::async_trait; use serde_json::Value; use tokenizers::Tokenizer; -use tokio::sync::Mutex as AMutex; use tokio::sync::RwLock as ARwLock; use tracing::info; @@ -15,7 +14,6 @@ use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdbResults}; -use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index d10c75d79..d56e301ce 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -2,7 +2,6 @@ use std::sync::Arc; use async_trait::async_trait; use serde_json::Value; -use tokio::sync::Mutex as AMutex; use tracing::{error, info}; use tokio::sync::RwLock as ARwLock; @@ -11,7 +10,6 @@ use crate::global_context::GlobalContext; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes; use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdbResults}; -use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = true; From 856af6a9b10aef99bd91641ed58447985a49d24f Mon Sep 17 00:00:00 2001 From: Valeryi Date: Wed, 31 Jan 2024 13:56:44 +0000 Subject: [PATCH 31/54] init --- src/at_commands/utils.rs | 11 +++++++++-- src/http/routers/v1/at_commands.rs | 3 ++- src/scratchpads/chat_utils_rag.rs | 16 +++++++++++----- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/at_commands/utils.rs b/src/at_commands/utils.rs index dc4205776..100d30a82 100644 --- a/src/at_commands/utils.rs +++ b/src/at_commands/utils.rs @@ -19,11 +19,12 @@ pub fn compose_context_file_msg_from_result( } pub async fn find_valid_at_commands_in_query( - query: &String, + query: &mut String, context: &AtCommandsContext, ) -> Vec { let mut results = vec![]; - for line in query.lines() { + let mut valid_command_lines = vec![]; + for (idx, line) in query.lines().enumerate() { let line_words: Vec<&str> = line.split_whitespace().collect(); let q_cmd_args = line_words.iter().skip(1).map(|x|x.to_string()).collect::>(); @@ -42,6 +43,12 @@ pub async fn find_valid_at_commands_in_query( } info!("command {:?} is perfectly good", q_cmd); results.push(AtCommandCall::new(cmd.clone(), q_cmd_args.clone())); + valid_command_lines.push(idx); } + // remove the lines that are valid commands from query + *query = query.lines().enumerate() + .filter(|(idx, _line)| !valid_command_lines.contains(idx)) + .map(|(_idx, line)| line) + .collect::>().join("\n"); results } diff --git a/src/http/routers/v1/at_commands.rs b/src/http/routers/v1/at_commands.rs index a8bf9925a..5849fe0d8 100644 --- a/src/http/routers/v1/at_commands.rs +++ b/src/http/routers/v1/at_commands.rs @@ -71,7 +71,8 @@ pub async fn handle_v1_command_preview( let context = AtCommandsContext::new(global_context.clone()).await; let post = serde_json::from_slice::(&body_bytes) .map_err(|e| ScratchError::new(StatusCode::UNPROCESSABLE_ENTITY, format!("JSON problem: {}", e)))?; - let valid_commands = crate::at_commands::utils::find_valid_at_commands_in_query(&post.query, &context).await; + let mut query = post.query.clone(); + let valid_commands = crate::at_commands::utils::find_valid_at_commands_in_query(&mut query, &context).await; if valid_commands.is_empty() { return Err(ScratchError::new(StatusCode::OK, "no valid commands in query".to_string())); } diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs index d2d1b417a..3c9cd710e 100644 --- a/src/scratchpads/chat_utils_rag.rs +++ b/src/scratchpads/chat_utils_rag.rs @@ -1,9 +1,9 @@ use std::sync::Arc; -use serde_json::Value; +use serde_json::{json, Value}; use tokio::sync::RwLock as ARwLock; use crate::at_commands::structs::{AtCommand, AtCommandsContext}; -use crate::call_validation::ChatPost; +use crate::call_validation::{ChatMessage, ChatPost}; use crate::global_context::GlobalContext; @@ -14,11 +14,11 @@ pub async fn chat_functions_middleware( has_vecdb: &mut HasVecdbResults, ) { let context = AtCommandsContext::new(global_context.clone()).await; - let query = &post.messages.last().unwrap().content.clone(); // latest_msg_cont - let valid_commands = crate::at_commands::utils::find_valid_at_commands_in_query(&query, &context).await; + let mut query = post.messages.last().unwrap().content.clone(); // latest_msg_cont + let valid_commands = crate::at_commands::utils::find_valid_at_commands_in_query(&mut query, &context).await; for cmd in valid_commands { - match cmd.command.lock().await.execute(query, &cmd.args, top_n, &context).await { + match cmd.command.lock().await.execute(&query, &cmd.args, top_n, &context).await { Ok((msgs, in_json)) => { post.messages.extend(msgs); has_vecdb.push_in_json(in_json); @@ -26,6 +26,12 @@ pub async fn chat_functions_middleware( Err(_) => {} } } + let msg = ChatMessage{ + role: "user".to_string(), + content: query, + }; + post.messages.push(msg.clone()); + has_vecdb.push_in_json(json!(msg)); } pub struct HasVecdbResults { From a6e89316973f96d8820d52a4b30cca50470ed23d Mon Sep 17 00:00:00 2001 From: Valeryi Date: Thu, 1 Feb 2024 13:41:40 +0000 Subject: [PATCH 32/54] Oleg's suggestion: content in context_file is a string now --- src/at_commands/utils.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/at_commands/utils.rs b/src/at_commands/utils.rs index 100d30a82..edd4a2a63 100644 --- a/src/at_commands/utils.rs +++ b/src/at_commands/utils.rs @@ -6,10 +6,11 @@ use tracing::info; pub fn compose_context_file_msg_from_result( in_json: &Value, ) -> Value { + let in_json_string = in_json.to_string(); return json!({ "choices": [{ "delta": { - "content": in_json.clone(), + "content": in_json_string, "role": "context_file" }, "finish_reason": null, From 180deabccfa0cd9d5f0c6df1b37dd5a9fa0ab999 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 1 Feb 2024 17:07:30 +0100 Subject: [PATCH 33/54] comments, spaces, rename --- src/at_commands/structs.rs | 2 ++ src/restream.rs | 3 +-- src/scratchpad_abstract.rs | 2 +- src/scratchpads/chat_generic.rs | 4 ++-- src/scratchpads/chat_llama2.rs | 4 ++-- src/scratchpads/chat_passthrough.rs | 5 +++-- src/scratchpads/chat_utils_rag.rs | 6 +++--- src/scratchpads/completion_single_file_fim.rs | 1 + 8 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/at_commands/structs.rs b/src/at_commands/structs.rs index 43c936705..464d75121 100644 --- a/src/at_commands/structs.rs +++ b/src/at_commands/structs.rs @@ -62,6 +62,7 @@ pub enum AtCommandKind { AtFile(AtFile), } +// TODO: remove this, replace with dyn AtCommand #[async_trait] impl AtCommand for AtCommandKind { fn name(&self) -> &String { @@ -84,6 +85,7 @@ impl AtCommand for AtCommandKind { AtCommandKind::AtFile(file) => file.are_args_valid(args, context).await, } } + async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool { match self { AtCommandKind::AtWorkspace(workspace) => workspace.can_execute(args, context).await, diff --git a/src/restream.rs b/src/restream.rs index 907657b40..65f1617b7 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -153,7 +153,7 @@ pub async fn scratchpad_interaction_stream( let mut save_url: String = String::new(); let permit = slowdown_arc.acquire().await; loop { - loop { + { let value_maybe = scratch.response_spontaneous(); if let Ok(value) = value_maybe { for el in value { @@ -167,7 +167,6 @@ pub async fn scratchpad_interaction_stream( let value_str = format!("data: {}\n\n", serde_json::to_string(&json!({"detail": err_str})).unwrap()); yield Result::<_, String>::Ok(value_str); } - break; } let event_source_maybe = if endpoint_style == "hf" { diff --git a/src/scratchpad_abstract.rs b/src/scratchpad_abstract.rs index 09a4c4df7..110ee04bb 100644 --- a/src/scratchpad_abstract.rs +++ b/src/scratchpad_abstract.rs @@ -33,7 +33,7 @@ pub trait ScratchpadAbstract: Send { stop_length: bool, ) -> Result<(serde_json::Value, bool), String>; - fn response_spontaneous(&mut self) -> Result, String> ; + fn response_spontaneous(&mut self) -> Result, String>; } diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index 6b1d68ea6..c20d7a865 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -13,7 +13,7 @@ use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdbResults}; +use crate::scratchpads::chat_utils_rag::{run_at_commands, HasVecdbResults}; const DEBUG: bool = true; @@ -87,7 +87,7 @@ impl ScratchpadAbstract for GenericChatScratchpad { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - chat_functions_middleware(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; + run_at_commands(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; let limited_msgs: Vec = limit_messages_history(&self.t, &self.post.messages, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index bcb51e002..8de6edcdc 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -13,7 +13,7 @@ use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdbResults}; +use crate::scratchpads::chat_utils_rag::{run_at_commands, HasVecdbResults}; const DEBUG: bool = true; @@ -73,7 +73,7 @@ impl ScratchpadAbstract for ChatLlama2 { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - chat_functions_middleware(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; + run_at_commands(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; let limited_msgs: Vec = limit_messages_history(&self.t, &self.post.messages, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index d56e301ce..60dccf39c 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -9,7 +9,7 @@ use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParamet use crate::global_context::GlobalContext; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes; -use crate::scratchpads::chat_utils_rag::{chat_functions_middleware, HasVecdbResults}; +use crate::scratchpads::chat_utils_rag::{run_at_commands, HasVecdbResults}; const DEBUG: bool = true; @@ -56,7 +56,7 @@ impl ScratchpadAbstract for ChatPassthrough { _context_size: usize, _sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - chat_functions_middleware(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; + run_at_commands(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; let limited_msgs: Vec = limit_messages_history_in_bytes(&self.post.messages, self.limit_bytes, &self.default_system_message)?; info!("chat passthrough {} messages -> {} messages after applying limits and possibly adding the default system message", &limited_msgs.len(), &limited_msgs.len()); @@ -130,6 +130,7 @@ impl ScratchpadAbstract for ChatPassthrough { }); Ok((ans, finished)) } + fn response_spontaneous(&mut self) -> Result, String> { return self.has_vecdb_results.response_streaming(); } diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs index 3c9cd710e..6ef4a9d09 100644 --- a/src/scratchpads/chat_utils_rag.rs +++ b/src/scratchpads/chat_utils_rag.rs @@ -7,12 +7,13 @@ use crate::call_validation::{ChatMessage, ChatPost}; use crate::global_context::GlobalContext; -pub async fn chat_functions_middleware( +pub async fn run_at_commands( global_context: Arc>, post: &mut ChatPost, top_n: usize, has_vecdb: &mut HasVecdbResults, ) { + // TODO: don't operate on `post`, return a copy of the messages let context = AtCommandsContext::new(global_context.clone()).await; let mut query = post.messages.last().unwrap().content.clone(); // latest_msg_cont let valid_commands = crate::at_commands::utils::find_valid_at_commands_in_query(&mut query, &context).await; @@ -28,7 +29,7 @@ pub async fn chat_functions_middleware( } let msg = ChatMessage{ role: "user".to_string(), - content: query, + content: query, // stream back to the user, without commands }; post.messages.push(msg.clone()); has_vecdb.push_in_json(json!(msg)); @@ -48,7 +49,6 @@ impl HasVecdbResults { } } - impl HasVecdbResults { pub fn push_in_json(&mut self, value: Value) { self.in_json.push(value); diff --git a/src/scratchpads/completion_single_file_fim.rs b/src/scratchpads/completion_single_file_fim.rs index 7c45cd715..0e2d181e7 100644 --- a/src/scratchpads/completion_single_file_fim.rs +++ b/src/scratchpads/completion_single_file_fim.rs @@ -286,6 +286,7 @@ impl ScratchpadAbstract for SingleFileFIM { }); Ok((ans, finished)) } + fn response_spontaneous(&mut self) -> Result, String> { return Err("".to_string()); } From 9774caf3cdca50a591a79d2f8600c70ce6208be6 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 1 Feb 2024 17:23:23 +0100 Subject: [PATCH 34/54] remove choices [] in fixing user messages --- src/at_commands/utils.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/at_commands/utils.rs b/src/at_commands/utils.rs index edd4a2a63..874d94fe0 100644 --- a/src/at_commands/utils.rs +++ b/src/at_commands/utils.rs @@ -8,14 +8,8 @@ pub fn compose_context_file_msg_from_result( ) -> Value { let in_json_string = in_json.to_string(); return json!({ - "choices": [{ - "delta": { - "content": in_json_string, - "role": "context_file" - }, - "finish_reason": null, - "index": 0 - }], + "content": in_json_string, + "role": "context_file", }); } From 604ef23d6807662ae3b6c388a9d68839eaa61490 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 1 Feb 2024 18:31:16 +0100 Subject: [PATCH 35/54] 50/50 working prompt for luke1 --- examples/luke1.py | 110 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 examples/luke1.py diff --git a/examples/luke1.py b/examples/luke1.py new file mode 100644 index 000000000..4dd8f917c --- /dev/null +++ b/examples/luke1.py @@ -0,0 +1,110 @@ +import requests, json + + +my_prompt = """ +You are a bot good at explaining the purpose for the given code. + +STEPS: + +In the [ORIGINAL_CODE_STEP] user will provide code surrounding the code snippet in question, and then the snippet itself will start with 🔥code and backquotes. + +In the [PROVIDE_COMMANDS_STEP] you have to ask for an extra context to completely understand the 🔥code and it's role in the project. +Run several commands in a single message. Don't write any explanations on this step. +Write the number of commands you plan to issue as a first line of your response, +and then write all the commands. +Commands available: + +🔍SEARCH to find more information in other source files in the project or documentation. It's good for looking up definitions and usage. + +🔍FILE to dump whole file text. + +Ask for definitions of types used in the 🔥code. +Ask for usages of the class or function defined in the 🔥code. +Don't look up symbols you already have. + +An examples of commands: + +🔍SEARCH usages of function f + +🔍SEARCH definition of Type2 + +🔍FILE repo1/test_file.cpp + +In the [GENERATE_DOCUMENTATION_STEP] you have to generate a comprehensive context-aware explanation of the 🔥code. +Don't explain line-by-line, rather answer the questions "why it exists", "how does it fit into broader context". +""" + +to_explain = """pub struct DeltaDeltaChatStreamer { + pub delta1: String, + pub delta2: String, + pub finished: bool, + pub stop_list: Vec, + pub role: String, +} +""" + +messages = [ +{"role": "system", "content": my_prompt}, +{"role": "user", "content": + "[ORIGINAL_CODE_STEP]\n" + + "@file /home/user/.refact/tmp/unpacked-files/refact-lsp/src/scratchpads/chat_utils_deltadelta.rs\n" + + "Why this 🔥code exists:\n```\n[CODE]```\n".replace("[CODE]", to_explain) + + "[PROVIDE_COMMANDS_STEP]\n"}, +] + +def ask_chat(messages): + response = requests.post( + "http://127.0.0.1:8001/v1/chat", + json={ + "messages": messages, + "temperature": 0.1, + "max_tokens": 300, + "model": "gpt-3.5-turbo", + }, + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer XXX", + }, + timeout=60, + ) + # data: {"choices":[{"delta":{"content":"The","role":"assistant"},"finish_reason":null,"index":0}],"created":1706779319.409,"model":"gpt-3.5-turbo"} + # data: {"choices":[{"delta":{"content":" code","role":"assistant"},"finish_reason":null,"index":0}],"created":1706779319.409,"model":"gpt-3.5-turbo"} + # Collect all delta/content from the response + messages_back = [] + accum_content = "" + accum_role = "" + # print(response.text) + for x in response.text.splitlines(): + if not x.strip(): + continue + if not x.startswith("data: "): + print(x) + print("ERROR: unexpected response format") + continue + if x[6:].startswith("[DONE]"): + break + j = json.loads(x[6:]) + if "choices" in j: + # streaming + choice0 = j["choices"][0] + accum_role = choice0["delta"]["role"] + accum_content += choice0["delta"]["content"] + else: + # content/role without streaming, replacing the last user message + messages_back.append({"role": j["role"], "content": j["content"]}) + if accum_role: + messages_back.append({"role": accum_role, "content": accum_content}) + return messages_back + + +import termcolor +for msgdict in messages: + role = termcolor.colored(msgdict["role"], "blue") + content = termcolor.colored(msgdict["content"], "green") + print(role, content) + +messages_back = ask_chat(messages) +for msgdict in messages_back: + role = termcolor.colored(msgdict["role"], "blue") + content = termcolor.colored(msgdict["content"], "red") + print(role, content) From d0dfb090a777b816dcb575263fd59632d4ef8a39 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 1 Feb 2024 20:50:18 +0100 Subject: [PATCH 36/54] simplify at commands --- examples/luke1.py | 59 ++++++++++++++++++++++------- src/at_commands/at_file.rs | 40 +++++++------------ src/at_commands/at_workspace.rs | 59 +++++++++++------------------ src/at_commands/structs.rs | 5 +-- src/at_commands/utils.rs | 11 ------ src/http/routers/v1/at_commands.rs | 6 +-- src/scratchpads/chat_passthrough.rs | 11 ++++-- src/scratchpads/chat_utils_rag.rs | 6 +-- 8 files changed, 97 insertions(+), 100 deletions(-) diff --git a/examples/luke1.py b/examples/luke1.py index 4dd8f917c..b48adf7b6 100644 --- a/examples/luke1.py +++ b/examples/luke1.py @@ -1,4 +1,4 @@ -import requests, json +import requests, json, termcolor my_prompt = """ @@ -30,8 +30,8 @@ 🔍FILE repo1/test_file.cpp -In the [GENERATE_DOCUMENTATION_STEP] you have to generate a comprehensive context-aware explanation of the 🔥code. -Don't explain line-by-line, rather answer the questions "why it exists", "how does it fit into broader context". +In the [GENERATE_DOCUMENTATION_STEP] you have to generate an explanation of the 🔥code. +Answer questions "why it exists", "how does it fit into broader context". Don't explain line-by-line. Don't explain class data fields. """ to_explain = """pub struct DeltaDeltaChatStreamer { @@ -43,7 +43,7 @@ } """ -messages = [ +initial_messages = [ {"role": "system", "content": my_prompt}, {"role": "user", "content": "[ORIGINAL_CODE_STEP]\n" + @@ -97,14 +97,45 @@ def ask_chat(messages): return messages_back -import termcolor -for msgdict in messages: - role = termcolor.colored(msgdict["role"], "blue") - content = termcolor.colored(msgdict["content"], "green") - print(role, content) +def rewrite_assistant_says_to_at_commands(ass): + out = "" + for s in ass.splitlines(): + s = s.strip() + if not s: + continue + if s.startswith("🔍SEARCH"): + out += "@workspace " + s[8:] + "\n" + if s.startswith("🔍FILE"): + out += "@file " + s[6:] + "\n" + return out + + +def dialog_turn(messages): + for msgdict in messages: + print(termcolor.colored(msgdict["role"], "blue")) + print(termcolor.colored(msgdict["content"], "green")) + messages_back = ask_chat(messages) + for msgdict in messages_back: + print(termcolor.colored(msgdict["role"], "blue")) + print(termcolor.colored(msgdict["content"], "red")) + assistant_says = messages_back[-1]["content"] + messages_without_last_user = messages[:-1] + next_step_messages = messages_without_last_user + messages_back + automated_new_user = rewrite_assistant_says_to_at_commands(assistant_says) + if not automated_new_user: + return next_step_messages, False + automated_new_user += "[GENERATE_DOCUMENTATION_STEP]" + next_step_messages.append({"role": "user", "content": automated_new_user}) + return next_step_messages, True + + +def do_all(): + messages = initial_messages.copy() + for step in range(2): + print("-"*40, "STEP%02d" % step, "-"*40) + messages, need_automated_post = dialog_turn(messages) + if not need_automated_post: + break + -messages_back = ask_chat(messages) -for msgdict in messages_back: - role = termcolor.colored(msgdict["role"], "blue") - content = termcolor.colored(msgdict["content"], "red") - print(role, content) +do_all() diff --git a/src/at_commands/at_file.rs b/src/at_commands/at_file.rs index 4e3568fb9..854c55bc8 100644 --- a/src/at_commands/at_file.rs +++ b/src/at_commands/at_file.rs @@ -1,10 +1,9 @@ use std::sync::Arc; use async_trait::async_trait; -use serde_json::{json, Value}; +use serde_json::json; use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam, AtParamKind}; use crate::at_commands::at_params::AtParamFilePath; use tokio::sync::Mutex as AMutex; -use crate::at_commands::utils::compose_context_file_msg_from_result; use crate::call_validation::{ChatMessage, ContextFile}; use crate::vecdb::vecdb::FileSearchResult; @@ -25,24 +24,6 @@ impl AtFile { } } -fn search2messages(result: &FileSearchResult) -> Vec { - // TODO: change to context_file, encode json including line1 line2 - vec![ChatMessage { - role: "user".to_string(), - content: format!("FILENAME:\n{}\nTEXT:\n{}\n", result.file_path, result.file_text) - }] -} - -fn search2json(result: &FileSearchResult) -> Value { - let cf = ContextFile { - file_name: result.file_path.clone().rsplit("/").next().unwrap_or(&result.file_path).to_string(), - file_content: result.file_text.clone(), - line1: 0, - line2: result.file_text.lines().count() as i32, - }; - compose_context_file_msg_from_result(&serde_json::to_value(&vec![cf]).unwrap_or(json!(null))) -} - #[async_trait] impl AtCommand for AtFile { fn name(&self) -> &String { @@ -67,7 +48,7 @@ impl AtCommand for AtFile { return true; } - async fn execute(&self, _query: &String, args: &Vec, _top_n: usize, context: &AtCommandsContext) -> Result<(Vec, Value), String> { + async fn execute(&self, _query: &String, args: &Vec, _top_n: usize, context: &AtCommandsContext) -> Result { let can_execute = self.can_execute(args, context).await; match *context.global_context.read().await.vec_db.lock().await { Some(ref db) => { @@ -78,11 +59,18 @@ impl AtCommand for AtFile { Some(x) => x, None => return Err("no file path".to_string()), }; - let search_result = db.get_file_orig_text(file_path.clone()).await; - Ok(( - search2messages(&search_result), - search2json(&search_result) - )) + let path_and_text: FileSearchResult = db.get_file_orig_text(file_path.clone()).await; + let mut vector_of_context_file: Vec = vec![]; + vector_of_context_file.push(ContextFile { + file_name: path_and_text.file_path, + file_content: path_and_text.file_text.clone(), + line1: 0, + line2: path_and_text.file_text.lines().count() as i32, + }); + Ok(ChatMessage { + role: "context_file".to_string(), + content: json!(vector_of_context_file).to_string(), + }) } None => Err("vecdb is not available".to_string()) } diff --git a/src/at_commands/at_workspace.rs b/src/at_commands/at_workspace.rs index 22b383a0a..e15d8ed68 100644 --- a/src/at_commands/at_workspace.rs +++ b/src/at_commands/at_workspace.rs @@ -1,9 +1,8 @@ use std::sync::Arc; use async_trait::async_trait; -use serde_json::{json, Value}; +use serde_json::json; use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam, AtParamKind}; use tokio::sync::Mutex as AMutex; -use crate::at_commands::utils::compose_context_file_msg_from_result; use crate::call_validation::{ChatMessage, ContextFile}; use crate::vecdb::structs::{Record, VecdbSearch}; @@ -23,33 +22,20 @@ impl AtWorkspace { } } -fn record2chat_message(record: &Record) -> ChatMessage { +fn results2message(results: &Vec) -> ChatMessage { + let mut vector_of_context_file: Vec = vec![]; + for r in results { + vector_of_context_file.push(ContextFile { + file_name: r.file_path.to_str().unwrap().to_string(), + file_content: r.window_text.clone(), + line1: r.start_line as i32, + line2: r.end_line as i32, + }); + } ChatMessage { role: "context_file".to_string(), - content: json!(ContextFile { - file_name: record.file_path.to_str().unwrap().to_string(), //.rsplit('/').next().unwrap_or(&record.file_path.to_str().unwrap()).to_string(), - file_content: record.window_text.clone(), - line1: record.start_line as i32, - line2: record.end_line as i32, - }).to_string() - } -} - -fn search2messages(results: &Vec) -> Vec { - let mut messages = vec![]; - for r in results { - messages.push(record2chat_message(r)); + content: json!(vector_of_context_file).to_string(), } - messages -} - -pub fn search2json( - results: &Vec -) -> Value { - let context_files: Vec = results - .iter() - .map(|x| { record2chat_message(x) }).collect(); - compose_context_file_msg_from_result(&serde_json::to_value(&context_files).unwrap_or(json!(null))) } #[async_trait] @@ -57,10 +43,12 @@ impl AtCommand for AtWorkspace { fn name(&self) -> &String { &self.name } + fn params(&self) -> &Vec>> { &self.params } + async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec { let mut results = Vec::new(); for (arg, param) in args.iter().zip(self.params.iter()) { @@ -70,24 +58,21 @@ impl AtCommand for AtWorkspace { results } - async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool { - if self.are_args_valid(args, context).await.iter().any(|&x| x == false) || args.len() != self.params.len() { - return false; - } + async fn can_execute(&self, _args: &Vec, _context: &AtCommandsContext) -> bool { return true; } - async fn execute(&self, query: &String, _args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result<(Vec, Value), String> { + async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result { match *context.global_context.read().await.vec_db.lock().await { Some(ref db) => { - let search_result = db.search(query.clone(), top_n).await?; + let mut db_query = args.join(" "); + if db_query.is_empty() { + db_query = query.clone(); + } + let search_result = db.search(db_query, top_n).await?; let mut results = search_result.results.clone(); results.dedup_by(|a, b| a.file_path == b.file_path && a.window_text == b.window_text); - - Ok(( - search2messages(&results), - search2json(&results) - )) + Ok(results2message(&results)) } None => Err("vecdb is not available".to_string()) } diff --git a/src/at_commands/structs.rs b/src/at_commands/structs.rs index 464d75121..cbcc197cd 100644 --- a/src/at_commands/structs.rs +++ b/src/at_commands/structs.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; -use serde_json::Value; use crate::global_context::GlobalContext; use crate::at_commands::at_commands::at_commands_dict; use tokio::sync::RwLock as ARwLock; @@ -31,7 +30,7 @@ pub trait AtCommand: Send + Sync { fn params(&self) -> &Vec>>; async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec; async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool; - async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result<(Vec, Value), String>; + async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result; } #[async_trait] @@ -93,7 +92,7 @@ impl AtCommand for AtCommandKind { } } - async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result<(Vec, Value), String> { + async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result { match self { AtCommandKind::AtWorkspace(workspace) => workspace.execute(query, args, top_n, context).await, AtCommandKind::AtFile(file) => file.execute(query, args, top_n, context).await, diff --git a/src/at_commands/utils.rs b/src/at_commands/utils.rs index 874d94fe0..ebdf96e6b 100644 --- a/src/at_commands/utils.rs +++ b/src/at_commands/utils.rs @@ -1,18 +1,7 @@ -use serde_json::{json, Value}; use crate::at_commands::structs::{AtCommand, AtCommandCall, AtCommandsContext}; use tracing::info; -pub fn compose_context_file_msg_from_result( - in_json: &Value, -) -> Value { - let in_json_string = in_json.to_string(); - return json!({ - "content": in_json_string, - "role": "context_file", - }); -} - pub async fn find_valid_at_commands_in_query( query: &mut String, context: &AtCommandsContext, diff --git a/src/http/routers/v1/at_commands.rs b/src/http/routers/v1/at_commands.rs index 5849fe0d8..be8f42153 100644 --- a/src/http/routers/v1/at_commands.rs +++ b/src/http/routers/v1/at_commands.rs @@ -4,7 +4,7 @@ use hyper::{Body, Response, StatusCode}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use itertools::Itertools; -use serde_json::{Value}; +use serde_json::{json, Value}; use tokio::sync::RwLock as ARwLock; use strsim::jaro_winkler; use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam}; @@ -80,8 +80,8 @@ pub async fn handle_v1_command_preview( let mut preview_msgs = vec![]; for cmd in valid_commands { match cmd.command.lock().await.execute(&post.query, &cmd.args, 5, &context).await { - Ok((_, in_json)) => { - preview_msgs.push(in_json); + Ok(msg) => { + preview_msgs.push(json!(msg)); }, Err(_) => {} } diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index 60dccf39c..10fc011d9 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -56,10 +56,10 @@ impl ScratchpadAbstract for ChatPassthrough { _context_size: usize, _sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { + info!("chat passthrough {} messages at start", &self.post.messages.len()); run_at_commands(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; - let limited_msgs: Vec = limit_messages_history_in_bytes(&self.post.messages, self.limit_bytes, &self.default_system_message)?; - info!("chat passthrough {} messages -> {} messages after applying limits and possibly adding the default system message", &limited_msgs.len(), &limited_msgs.len()); + info!("chat passthrough {} messages -> {} messages after applying at-commands and limits, possibly adding the default system message", &limited_msgs.len(), &limited_msgs.len()); let mut filtered_msgs: Vec = Vec::::new(); for msg in &limited_msgs { if msg.role == "assistant" || msg.role == "system" || msg.role == "user" { @@ -83,7 +83,12 @@ impl ScratchpadAbstract for ChatPassthrough { let prompt = "PASSTHROUGH ".to_string() + &serde_json::to_string(&filtered_msgs).unwrap(); if DEBUG { for msg in &filtered_msgs { - info!("filtered message: {:?}", msg); + let mut last_30_chars: String = msg.content.chars().take(30).collect(); + if last_30_chars.len() == 30 { + last_30_chars.push_str("..."); + } + last_30_chars = last_30_chars.replace("\n", "\\n"); + info!("filtered message role={} {}", msg.role, &last_30_chars); } } Ok(prompt.to_string()) diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs index 6ef4a9d09..5156d46c5 100644 --- a/src/scratchpads/chat_utils_rag.rs +++ b/src/scratchpads/chat_utils_rag.rs @@ -20,9 +20,9 @@ pub async fn run_at_commands( for cmd in valid_commands { match cmd.command.lock().await.execute(&query, &cmd.args, top_n, &context).await { - Ok((msgs, in_json)) => { - post.messages.extend(msgs); - has_vecdb.push_in_json(in_json); + Ok(msg) => { + post.messages.push(msg.clone()); + has_vecdb.push_in_json(json!(msg)); }, Err(_) => {} } From e5c2fc2b988f68f61a2ea9cd1290c3e8696eee50 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Thu, 1 Feb 2024 13:20:52 +0000 Subject: [PATCH 37/54] Changelist (Marc's suggestions): * CommandCompletion API update: added top_n field * Command Completion works even if argument is empty --- src/at_commands/query.rs | 8 +++++++- src/http/routers/v1/at_commands.rs | 22 +++++++++++++++++----- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/at_commands/query.rs b/src/at_commands/query.rs index 150862dd2..549ea3b8f 100644 --- a/src/at_commands/query.rs +++ b/src/at_commands/query.rs @@ -60,7 +60,13 @@ fn parse_args_from_line(line: &String) -> Vec { args.push(QueryLineArg{ value: value.clone(), pos1, - pos2: idx, + pos2: { + if ch.to_string() != " " { + idx + } else { + idx - 1 + } + }, focused: false, type_name: { if value.starts_with("@") { diff --git a/src/http/routers/v1/at_commands.rs b/src/http/routers/v1/at_commands.rs index be8f42153..db415be3e 100644 --- a/src/http/routers/v1/at_commands.rs +++ b/src/http/routers/v1/at_commands.rs @@ -17,6 +17,7 @@ use crate::global_context::GlobalContext; struct CommandCompletionPost { query: String, cursor: i64, + top_n: usize, } #[derive(Serialize, Deserialize, Clone)] struct CommandCompletionResponse { @@ -43,14 +44,15 @@ pub async fn handle_v1_command_completion( let post = serde_json::from_slice::(&body_bytes) .map_err(|e| ScratchError::new(StatusCode::UNPROCESSABLE_ENTITY, format!("JSON problem: {}", e)))?; - let (query_line_val, cursor_rel, cursor_line_start) = get_line_with_cursor(&post.query, post.cursor.clone())?; + let (query_line_val, cursor_rel, cursor_line_start) = get_line_with_cursor(&post.query, post.cursor)?; let query_line_val = query_line_val.chars().take(cursor_rel as usize).collect::(); // info!(query_line_val); + // info!("cursor_rel: {}, cursor_line_start: {}", cursor_rel, cursor_line_start); let query_line = QueryLine::new(query_line_val, cursor_rel, cursor_line_start); // for arg in query_line.args.iter() { // info!("value: {}, focused: {}, type_name: {}; pos1: {}; pos2: {}", arg.value, arg.focused, arg.type_name, arg.pos1, arg.pos2); // } - let (completions, is_cmd_executable, pos1, pos2) = command_completion(&query_line, &context).await?; + let (completions, is_cmd_executable, pos1, pos2) = command_completion(&query_line, &context, post.cursor, post.top_n).await?; let response = CommandCompletionResponse { completions: completions.clone(), @@ -115,6 +117,8 @@ fn get_line_with_cursor(query: &String, cursor: i64) -> Result<(String, i64, i64 async fn command_completion( query_line: &QueryLine, context: &AtCommandsContext, + cursor_abs: i64, + top_n: usize, ) -> Result<(Vec, bool, i64, i64), ScratchError> { // returns ([possible, completions], good_as_it_is) let q_cmd = match query_line.command() { Some(x) => x, @@ -127,7 +131,7 @@ async fn command_completion( if !q_cmd.focused { return Err(ScratchError::new(StatusCode::OK, "incorrect command given".to_string())); } - return Ok((command_completion_options(&q_cmd.value, &context, 5).await, false, q_cmd.pos1, q_cmd.pos2)); + return Ok((command_completion_options(&q_cmd.value, &context, top_n).await, false, q_cmd.pos1, q_cmd.pos2)); } }; if cmd.lock().await.can_execute(&query_line.get_args().iter().map(|x|x.value.clone()).collect(), context).await { @@ -138,12 +142,20 @@ async fn command_completion( let is_valid = param.lock().await.is_value_valid(&arg.value, context).await; if !is_valid { return if arg.focused { - Ok((param.lock().await.complete(&arg.value, context, 5).await, false, arg.pos1, arg.pos2)) + Ok((param.lock().await.complete(&arg.value, context, top_n).await, false, arg.pos1, arg.pos2)) } else { Err(ScratchError::new(StatusCode::OK, "invalid parameter".to_string())) } } - + } + // if command is not focused, and the argument is empty we should make suggestions + if !q_cmd.focused { + match cmd.lock().await.params().get(query_line.get_args().len()) { + Some(param) => { + return Ok((param.lock().await.complete(&"".to_string(), context, top_n).await, false, cursor_abs, cursor_abs)); + }, + None => {} + } } return Ok((vec![], false, -1, -1)); From 7b1a8a4f94900ee5d24b6db7b6db700216307ec9 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Thu, 1 Feb 2024 16:51:54 +0000 Subject: [PATCH 38/54] no more enums --- src/at_commands/at_commands.rs | 10 ++-- src/at_commands/at_file.rs | 9 ++-- src/at_commands/at_workspace.rs | 7 ++- src/at_commands/structs.rs | 85 ++------------------------------- src/at_commands/utils.rs | 5 +- 5 files changed, 20 insertions(+), 96 deletions(-) diff --git a/src/at_commands/at_commands.rs b/src/at_commands/at_commands.rs index fdd715fda..b9cd65c55 100644 --- a/src/at_commands/at_commands.rs +++ b/src/at_commands/at_commands.rs @@ -4,12 +4,12 @@ use tokio::sync::Mutex as AMutex; use crate::at_commands::at_workspace::AtWorkspace; use crate::at_commands::at_file::AtFile; -use crate::at_commands::structs::AtCommandKind; +use crate::at_commands::structs::AtCommand; -pub async fn at_commands_dict() -> HashMap>> { +pub async fn at_commands_dict() -> HashMap>>> { return HashMap::from([ - ("@workspace".to_string(), Arc::new(AMutex::new(AtCommandKind::AtWorkspace(AtWorkspace::new())))), - ("@file".to_string(), Arc::new(AMutex::new(AtCommandKind::AtFile(AtFile::new())))), + ("@workspace".to_string(), Arc::new(AMutex::new(Box::new(AtWorkspace::new()) as Box))), + ("@file".to_string(), Arc::new(AMutex::new(Box::new(AtFile::new()) as Box))), ]); -} +} \ No newline at end of file diff --git a/src/at_commands/at_file.rs b/src/at_commands/at_file.rs index 854c55bc8..31f29f6c3 100644 --- a/src/at_commands/at_file.rs +++ b/src/at_commands/at_file.rs @@ -1,16 +1,15 @@ use std::sync::Arc; use async_trait::async_trait; use serde_json::json; -use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam, AtParamKind}; +use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam}; use crate::at_commands::at_params::AtParamFilePath; use tokio::sync::Mutex as AMutex; use crate::call_validation::{ChatMessage, ContextFile}; use crate::vecdb::vecdb::FileSearchResult; -#[derive(Debug)] pub struct AtFile { pub name: String, - pub params: Vec>>, + pub params: Vec>>, } impl AtFile { @@ -18,7 +17,7 @@ impl AtFile { AtFile { name: "@file".to_string(), params: vec![ - Arc::new(AMutex::new(AtParamKind::AtParamFilePath(AtParamFilePath::new()))) + Arc::new(AMutex::new(AtParamFilePath::new())) ], } } @@ -29,7 +28,7 @@ impl AtCommand for AtFile { fn name(&self) -> &String { &self.name } - fn params(&self) -> &Vec>> { + fn params(&self) -> &Vec>> { &self.params } async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec { diff --git a/src/at_commands/at_workspace.rs b/src/at_commands/at_workspace.rs index e15d8ed68..d6208d221 100644 --- a/src/at_commands/at_workspace.rs +++ b/src/at_commands/at_workspace.rs @@ -1,16 +1,15 @@ use std::sync::Arc; use async_trait::async_trait; use serde_json::json; -use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam, AtParamKind}; +use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam}; use tokio::sync::Mutex as AMutex; use crate::call_validation::{ChatMessage, ContextFile}; use crate::vecdb::structs::{Record, VecdbSearch}; -#[derive(Debug)] pub struct AtWorkspace { pub name: String, - pub params: Vec>>, + pub params: Vec>>, } impl AtWorkspace { @@ -44,7 +43,7 @@ impl AtCommand for AtWorkspace { &self.name } - fn params(&self) -> &Vec>> + fn params(&self) -> &Vec>> { &self.params } diff --git a/src/at_commands/structs.rs b/src/at_commands/structs.rs index cbcc197cd..b8e095d9f 100644 --- a/src/at_commands/structs.rs +++ b/src/at_commands/structs.rs @@ -5,14 +5,11 @@ use crate::global_context::GlobalContext; use crate::at_commands::at_commands::at_commands_dict; use tokio::sync::RwLock as ARwLock; use tokio::sync::Mutex as AMutex; -use crate::at_commands::at_file::AtFile; -use crate::at_commands::at_params::AtParamFilePath; -use crate::at_commands::at_workspace::AtWorkspace; use crate::call_validation::ChatMessage; pub struct AtCommandsContext { pub global_context: Arc>, - pub at_commands: HashMap>>, + pub at_commands: HashMap>>>, } impl AtCommandsContext { @@ -27,101 +24,29 @@ impl AtCommandsContext { #[async_trait] pub trait AtCommand: Send + Sync { fn name(&self) -> &String; - fn params(&self) -> &Vec>>; + fn params(&self) -> &Vec>>; async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec; async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool; async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result; } #[async_trait] -pub trait AtParam { +pub trait AtParam: Send + Sync { fn name(&self) -> &String; async fn is_value_valid(&self, value: &String, context: &AtCommandsContext) -> bool; async fn complete(&self, value: &String, context: &AtCommandsContext, top_n: usize) -> Vec; } -#[derive(Debug)] pub struct AtCommandCall { - pub command: Arc>, + pub command: Arc>>, pub args: Vec, } impl AtCommandCall { - pub fn new(command: Arc>, args: Vec) -> Self { + pub fn new(command: Arc>>, args: Vec) -> Self { AtCommandCall { command, args } } } - -#[derive(Debug)] -pub enum AtCommandKind { - AtWorkspace(AtWorkspace), - AtFile(AtFile), -} - -// TODO: remove this, replace with dyn AtCommand -#[async_trait] -impl AtCommand for AtCommandKind { - fn name(&self) -> &String { - match self { - AtCommandKind::AtWorkspace(workspace) => workspace.name(), - AtCommandKind::AtFile(file) => file.name(), - } - } - - fn params(&self) -> &Vec>> { - match self { - AtCommandKind::AtWorkspace(workspace) => workspace.params(), - AtCommandKind::AtFile(file) => file.params(), - } - } - - async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec { - match self { - AtCommandKind::AtWorkspace(workspace) => workspace.are_args_valid(args, context).await, - AtCommandKind::AtFile(file) => file.are_args_valid(args, context).await, - } - } - - async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool { - match self { - AtCommandKind::AtWorkspace(workspace) => workspace.can_execute(args, context).await, - AtCommandKind::AtFile(file) => file.can_execute(args, context).await, - } - } - - async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result { - match self { - AtCommandKind::AtWorkspace(workspace) => workspace.execute(query, args, top_n, context).await, - AtCommandKind::AtFile(file) => file.execute(query, args, top_n, context).await, - } - } -} - -#[derive(Debug)] -pub enum AtParamKind { - AtParamFilePath(AtParamFilePath), -} - -#[async_trait] -impl AtParam for AtParamKind { - fn name(&self) -> &String { - match self { - AtParamKind::AtParamFilePath(param) => param.name(), - } - } - - async fn is_value_valid(&self, value: &String, context: &AtCommandsContext) -> bool { - match self { - AtParamKind::AtParamFilePath(param) => param.is_value_valid(value, context).await, - } - } - - async fn complete(&self, value: &String, context: &AtCommandsContext, top_n: usize) -> Vec { - match self { - AtParamKind::AtParamFilePath(param) => param.complete(value, context, top_n).await, - } - } -} diff --git a/src/at_commands/utils.rs b/src/at_commands/utils.rs index ebdf96e6b..58836f0e5 100644 --- a/src/at_commands/utils.rs +++ b/src/at_commands/utils.rs @@ -1,4 +1,5 @@ -use crate::at_commands::structs::{AtCommand, AtCommandCall, AtCommandsContext}; +use std::sync::Arc; +use crate::at_commands::structs::{AtCommandCall, AtCommandsContext}; use tracing::info; @@ -26,7 +27,7 @@ pub async fn find_valid_at_commands_in_query( continue; } info!("command {:?} is perfectly good", q_cmd); - results.push(AtCommandCall::new(cmd.clone(), q_cmd_args.clone())); + results.push(AtCommandCall::new(Arc::clone(&cmd), q_cmd_args.clone())); valid_command_lines.push(idx); } // remove the lines that are valid commands from query From 10d4b8a6353646a9a83ecba868401286ccf9767f Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 1 Feb 2024 20:59:05 +0100 Subject: [PATCH 39/54] log fix --- src/scratchpads/chat_passthrough.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index 10fc011d9..2bf32b00a 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -59,7 +59,7 @@ impl ScratchpadAbstract for ChatPassthrough { info!("chat passthrough {} messages at start", &self.post.messages.len()); run_at_commands(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; let limited_msgs: Vec = limit_messages_history_in_bytes(&self.post.messages, self.limit_bytes, &self.default_system_message)?; - info!("chat passthrough {} messages -> {} messages after applying at-commands and limits, possibly adding the default system message", &limited_msgs.len(), &limited_msgs.len()); + info!("chat passthrough {} messages -> {} messages after applying at-commands and limits, possibly adding the default system message", &self.post.messages.len(), &limited_msgs.len()); let mut filtered_msgs: Vec = Vec::::new(); for msg in &limited_msgs { if msg.role == "assistant" || msg.role == "system" || msg.role == "user" { From d00769daf554cc5fc31a175dd7d844a6d8638694 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Fri, 2 Feb 2024 07:53:29 +0100 Subject: [PATCH 40/54] nicer_logs --- src/main.rs | 1 + src/nicer_logs.rs | 15 +++++++++++++++ src/restream.rs | 3 ++- src/scratchpads/chat_passthrough.rs | 10 +++------- src/scratchpads/chat_utils_limit_history.rs | 2 ++ src/vecdb/vecdb.rs | 2 +- 6 files changed, 24 insertions(+), 9 deletions(-) create mode 100644 src/nicer_logs.rs diff --git a/src/main.rs b/src/main.rs index 7f7b14212..03d8356cb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,6 +26,7 @@ mod receive_workspace_changes; mod vecdb; mod fetch_embedding; mod at_commands; +mod nicer_logs; #[tokio::main] diff --git a/src/nicer_logs.rs b/src/nicer_logs.rs new file mode 100644 index 000000000..9f7741461 --- /dev/null +++ b/src/nicer_logs.rs @@ -0,0 +1,15 @@ +pub fn first_n_chars(msg: &String, n: usize) -> String { + let mut last_n_chars: String = msg.chars().take(n).collect(); + if last_n_chars.len() == n { + last_n_chars.push_str("..."); + } + return last_n_chars.replace("\n", "\\n"); +} + +pub fn last_n_chars(msg: &String, n: usize) -> String { + let mut last_n_chars: String = msg.chars().rev().take(n).collect::().chars().rev().collect(); + if last_n_chars.len() == n { + last_n_chars.insert_str(0, "..."); + } + return last_n_chars.replace("\n", "\\n"); +} diff --git a/src/restream.rs b/src/restream.rs index 65f1617b7..2d7481ace 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -7,6 +7,7 @@ use futures::StreamExt; use async_stream::stream; use hyper::{Body, Response, StatusCode}; +use crate::nicer_logs; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::forward_to_hf_endpoint; use crate::forward_to_openai_endpoint; @@ -158,7 +159,7 @@ pub async fn scratchpad_interaction_stream( if let Ok(value) = value_maybe { for el in value { let value_str = format!("data: {}\n\n", serde_json::to_string(&el).unwrap()); - info!("yield: {:?}", value_str); + info!("yield: {:?}", nicer_logs::first_n_chars(&value_str, 40)); yield Result::<_, String>::Ok(value_str); } } else { diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index 2bf32b00a..ed3201cca 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -57,7 +57,8 @@ impl ScratchpadAbstract for ChatPassthrough { _sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { info!("chat passthrough {} messages at start", &self.post.messages.len()); - run_at_commands(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; + let top_n = 6; + run_at_commands(self.global_context.clone(), &mut self.post, top_n, &mut self.has_vecdb_results).await; let limited_msgs: Vec = limit_messages_history_in_bytes(&self.post.messages, self.limit_bytes, &self.default_system_message)?; info!("chat passthrough {} messages -> {} messages after applying at-commands and limits, possibly adding the default system message", &self.post.messages.len(), &limited_msgs.len()); let mut filtered_msgs: Vec = Vec::::new(); @@ -83,12 +84,7 @@ impl ScratchpadAbstract for ChatPassthrough { let prompt = "PASSTHROUGH ".to_string() + &serde_json::to_string(&filtered_msgs).unwrap(); if DEBUG { for msg in &filtered_msgs { - let mut last_30_chars: String = msg.content.chars().take(30).collect(); - if last_30_chars.len() == 30 { - last_30_chars.push_str("..."); - } - last_30_chars = last_30_chars.replace("\n", "\\n"); - info!("filtered message role={} {}", msg.role, &last_30_chars); + info!("filtered message role={} {}", msg.role, crate::nicer_logs::first_n_chars(&msg.content, 40)); } } Ok(prompt.to_string()) diff --git a/src/scratchpads/chat_utils_limit_history.rs b/src/scratchpads/chat_utils_limit_history.rs index 6803031c6..7f7c4c4a0 100644 --- a/src/scratchpads/chat_utils_limit_history.rs +++ b/src/scratchpads/chat_utils_limit_history.rs @@ -1,5 +1,6 @@ use crate::scratchpad_abstract::HasTokenizerAndEot; use crate::call_validation::ChatMessage; +use tracing::info; pub fn limit_messages_history( @@ -73,6 +74,7 @@ pub fn limit_messages_history_in_bytes( } for i in (0..messages.len()).rev() { let bytes = messages[i].content.len(); + info!("limit_messages_history_in_bytes: message{}, bytes_used={} += {}", i, bytes_used, bytes); if !message_take[i] { if bytes_used + bytes < bytes_limit { message_take[i] = true; diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index cb171cc22..0789ff007 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -289,7 +289,7 @@ impl VecdbSearch for VecDb { }; info!("search itself {:.3}s", t1.elapsed().as_secs_f64()); for rec in results.iter() { - let last_30_chars: String = rec.file_path.display().to_string().chars().rev().take(30).collect::().chars().rev().collect(); + let last_30_chars = crate::nicer_logs::last_n_chars(&rec.file_path.display().to_string(), 30); info!("distance {:.3}, found ...{}:{}-{}, ", rec.distance, last_30_chars, rec.start_line, rec.end_line); } let t2 = std::time::Instant::now(); From b07d863329ace49d89a6e8d5e20fe425d9eb9f5a Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Fri, 2 Feb 2024 08:02:43 +0100 Subject: [PATCH 41/54] move structs to at_commands.rs --- src/at_commands/at_commands.rs | 50 ++++++++++++++++++++++++++-- src/at_commands/at_file.rs | 2 +- src/at_commands/at_params.rs | 2 +- src/at_commands/at_workspace.rs | 2 +- src/at_commands/mod.rs | 1 - src/at_commands/structs.rs | 52 ------------------------------ src/at_commands/utils.rs | 2 +- src/http/routers/v1/at_commands.rs | 4 +-- src/scratchpads/chat_utils_rag.rs | 2 +- 9 files changed, 55 insertions(+), 62 deletions(-) delete mode 100644 src/at_commands/structs.rs diff --git a/src/at_commands/at_commands.rs b/src/at_commands/at_commands.rs index b9cd65c55..cf4054e87 100644 --- a/src/at_commands/at_commands.rs +++ b/src/at_commands/at_commands.rs @@ -1,15 +1,61 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex as AMutex; +use tokio::sync::RwLock as ARwLock; +use async_trait::async_trait; +use crate::global_context::GlobalContext; use crate::at_commands::at_workspace::AtWorkspace; use crate::at_commands::at_file::AtFile; -use crate::at_commands::structs::AtCommand; +use crate::call_validation::ChatMessage; +pub struct AtCommandsContext { + pub global_context: Arc>, + pub at_commands: HashMap>>>, +} + +impl AtCommandsContext { + pub async fn new(global_context: Arc>) -> Self { + AtCommandsContext { + global_context, + at_commands: at_commands_dict().await, + } + } +} + +#[async_trait] +pub trait AtCommand: Send + Sync { + fn name(&self) -> &String; + fn params(&self) -> &Vec>>; + async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec; + async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool; + async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result; +} + +#[async_trait] +pub trait AtParam: Send + Sync { + fn name(&self) -> &String; + async fn is_value_valid(&self, value: &String, context: &AtCommandsContext) -> bool; + async fn complete(&self, value: &String, context: &AtCommandsContext, top_n: usize) -> Vec; +} + +pub struct AtCommandCall { + pub command: Arc>>, + pub args: Vec, +} + +impl AtCommandCall { + pub fn new(command: Arc>>, args: Vec) -> Self { + AtCommandCall { + command, + args + } + } +} pub async fn at_commands_dict() -> HashMap>>> { return HashMap::from([ ("@workspace".to_string(), Arc::new(AMutex::new(Box::new(AtWorkspace::new()) as Box))), ("@file".to_string(), Arc::new(AMutex::new(Box::new(AtFile::new()) as Box))), ]); -} \ No newline at end of file +} diff --git a/src/at_commands/at_file.rs b/src/at_commands/at_file.rs index 31f29f6c3..09f2271e7 100644 --- a/src/at_commands/at_file.rs +++ b/src/at_commands/at_file.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use async_trait::async_trait; use serde_json::json; -use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam}; +use crate::at_commands::at_commands::{AtCommand, AtCommandsContext, AtParam}; use crate::at_commands::at_params::AtParamFilePath; use tokio::sync::Mutex as AMutex; use crate::call_validation::{ChatMessage, ContextFile}; diff --git a/src/at_commands/at_params.rs b/src/at_commands/at_params.rs index f731ade9c..516dcd8f0 100644 --- a/src/at_commands/at_params.rs +++ b/src/at_commands/at_params.rs @@ -1,4 +1,4 @@ -use crate::at_commands::structs::{AtCommandsContext, AtParam}; +use crate::at_commands::at_commands::{AtCommandsContext, AtParam}; use async_trait::async_trait; use itertools::Itertools; use strsim::jaro_winkler; diff --git a/src/at_commands/at_workspace.rs b/src/at_commands/at_workspace.rs index d6208d221..ba66c3d29 100644 --- a/src/at_commands/at_workspace.rs +++ b/src/at_commands/at_workspace.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use async_trait::async_trait; use serde_json::json; -use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam}; +use crate::at_commands::at_commands::{AtCommand, AtCommandsContext, AtParam}; use tokio::sync::Mutex as AMutex; use crate::call_validation::{ChatMessage, ContextFile}; use crate::vecdb::structs::{Record, VecdbSearch}; diff --git a/src/at_commands/mod.rs b/src/at_commands/mod.rs index 1c1c55818..070648a15 100644 --- a/src/at_commands/mod.rs +++ b/src/at_commands/mod.rs @@ -1,7 +1,6 @@ pub mod at_commands; pub mod at_file; pub mod at_workspace; -pub mod structs; pub mod at_params; pub mod utils; pub mod query; diff --git a/src/at_commands/structs.rs b/src/at_commands/structs.rs deleted file mode 100644 index b8e095d9f..000000000 --- a/src/at_commands/structs.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use async_trait::async_trait; -use crate::global_context::GlobalContext; -use crate::at_commands::at_commands::at_commands_dict; -use tokio::sync::RwLock as ARwLock; -use tokio::sync::Mutex as AMutex; -use crate::call_validation::ChatMessage; - -pub struct AtCommandsContext { - pub global_context: Arc>, - pub at_commands: HashMap>>>, -} - -impl AtCommandsContext { - pub async fn new(global_context: Arc>) -> Self { - AtCommandsContext { - global_context, - at_commands: at_commands_dict().await, - } - } -} - -#[async_trait] -pub trait AtCommand: Send + Sync { - fn name(&self) -> &String; - fn params(&self) -> &Vec>>; - async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec; - async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool; - async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result; -} - -#[async_trait] -pub trait AtParam: Send + Sync { - fn name(&self) -> &String; - async fn is_value_valid(&self, value: &String, context: &AtCommandsContext) -> bool; - async fn complete(&self, value: &String, context: &AtCommandsContext, top_n: usize) -> Vec; -} - -pub struct AtCommandCall { - pub command: Arc>>, - pub args: Vec, -} - -impl AtCommandCall { - pub fn new(command: Arc>>, args: Vec) -> Self { - AtCommandCall { - command, - args - } - } -} diff --git a/src/at_commands/utils.rs b/src/at_commands/utils.rs index 58836f0e5..6d1da294e 100644 --- a/src/at_commands/utils.rs +++ b/src/at_commands/utils.rs @@ -1,5 +1,5 @@ use std::sync::Arc; -use crate::at_commands::structs::{AtCommandCall, AtCommandsContext}; +use crate::at_commands::at_commands::{AtCommandCall, AtCommandsContext}; use tracing::info; diff --git a/src/http/routers/v1/at_commands.rs b/src/http/routers/v1/at_commands.rs index db415be3e..cf13ae2d3 100644 --- a/src/http/routers/v1/at_commands.rs +++ b/src/http/routers/v1/at_commands.rs @@ -3,11 +3,11 @@ use axum::Extension; use hyper::{Body, Response, StatusCode}; use serde::{Deserialize, Serialize}; use std::sync::Arc; -use itertools::Itertools; use serde_json::{json, Value}; use tokio::sync::RwLock as ARwLock; use strsim::jaro_winkler; -use crate::at_commands::structs::{AtCommand, AtCommandsContext, AtParam}; +use itertools::Itertools; +use crate::at_commands::at_commands::AtCommandsContext; use crate::at_commands::query::QueryLine; use crate::custom_error::ScratchError; diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs index 5156d46c5..d51d49804 100644 --- a/src/scratchpads/chat_utils_rag.rs +++ b/src/scratchpads/chat_utils_rag.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use serde_json::{json, Value}; use tokio::sync::RwLock as ARwLock; -use crate::at_commands::structs::{AtCommand, AtCommandsContext}; +use crate::at_commands::at_commands::AtCommandsContext; use crate::call_validation::{ChatMessage, ChatPost}; use crate::global_context::GlobalContext; From 268370a4963da6e01d5526d01601e8f60b0e12ae Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Fri, 2 Feb 2024 08:29:58 +0100 Subject: [PATCH 42/54] usefullness --- src/at_commands/at_file.rs | 1 + src/at_commands/at_workspace.rs | 4 +++- src/call_validation.rs | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/at_commands/at_file.rs b/src/at_commands/at_file.rs index 09f2271e7..0dbc1de32 100644 --- a/src/at_commands/at_file.rs +++ b/src/at_commands/at_file.rs @@ -65,6 +65,7 @@ impl AtCommand for AtFile { file_content: path_and_text.file_text.clone(), line1: 0, line2: path_and_text.file_text.lines().count() as i32, + usefullness: 100.0, }); Ok(ChatMessage { role: "context_file".to_string(), diff --git a/src/at_commands/at_workspace.rs b/src/at_commands/at_workspace.rs index ba66c3d29..c863dfc71 100644 --- a/src/at_commands/at_workspace.rs +++ b/src/at_commands/at_workspace.rs @@ -23,12 +23,14 @@ impl AtWorkspace { fn results2message(results: &Vec) -> ChatMessage { let mut vector_of_context_file: Vec = vec![]; - for r in results { + for i in 0..results.len() { + let r = &results[i]; vector_of_context_file.push(ContextFile { file_name: r.file_path.to_str().unwrap().to_string(), file_content: r.window_text.clone(), line1: r.start_line as i32, line2: r.end_line as i32, + usefullness: 100.0 / ((i + 1) as f32), }); } ChatMessage { diff --git a/src/call_validation.rs b/src/call_validation.rs index fb469d6f8..eeb7d4c75 100644 --- a/src/call_validation.rs +++ b/src/call_validation.rs @@ -171,7 +171,6 @@ mod tests { }; assert!(validate_post(post).is_err()); } - } #[derive(Debug, Serialize, Deserialize, Clone)] @@ -180,6 +179,7 @@ pub struct ContextFile { pub file_content: String, pub line1: i32, pub line2: i32, + pub usefullness: f32, // the higher the better } #[derive(Debug, Serialize, Deserialize, Clone)] From 54e557d3bff7e54eef5689103f0e5a1e3be8e11b Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Fri, 2 Feb 2024 10:24:21 +0100 Subject: [PATCH 43/54] postprocess_at_results() --- examples/luke1.py | 21 +++- src/caps.rs | 10 ++ src/scratchpads/chat_utils_limit_history.rs | 1 + src/scratchpads/chat_utils_rag.rs | 101 +++++++++++++++++++- 4 files changed, 124 insertions(+), 9 deletions(-) diff --git a/examples/luke1.py b/examples/luke1.py index b48adf7b6..bbcd3e994 100644 --- a/examples/luke1.py +++ b/examples/luke1.py @@ -59,7 +59,7 @@ def ask_chat(messages): "messages": messages, "temperature": 0.1, "max_tokens": 300, - "model": "gpt-3.5-turbo", + "model": "gpt-3.5-turbo-0125", }, headers={ "Content-Type": "application/json", @@ -67,13 +67,9 @@ def ask_chat(messages): }, timeout=60, ) - # data: {"choices":[{"delta":{"content":"The","role":"assistant"},"finish_reason":null,"index":0}],"created":1706779319.409,"model":"gpt-3.5-turbo"} - # data: {"choices":[{"delta":{"content":" code","role":"assistant"},"finish_reason":null,"index":0}],"created":1706779319.409,"model":"gpt-3.5-turbo"} - # Collect all delta/content from the response messages_back = [] accum_content = "" accum_role = "" - # print(response.text) for x in response.text.splitlines(): if not x.strip(): continue @@ -118,6 +114,7 @@ def dialog_turn(messages): for msgdict in messages_back: print(termcolor.colored(msgdict["role"], "blue")) print(termcolor.colored(msgdict["content"], "red")) + assistant_says = messages_back[-1]["content"] messages_without_last_user = messages[:-1] next_step_messages = messages_without_last_user + messages_back @@ -139,3 +136,17 @@ def do_all(): do_all() + + + + + +def do_less(): + messages = [ + {"role": "user", "content": "@workspace DeltaDeltaChatStreamer\n@file /home/user/.refact/tmp/unpacked-files/refact-lsp/src/scratchpads/chat_utils_deltadelta.rs\nHello world"} + ] + messages_back = ask_chat(messages) + print(messages_back) + + +# do_less() diff --git a/src/caps.rs b/src/caps.rs index c950019a4..08b52c6a4 100644 --- a/src/caps.rs +++ b/src/caps.rs @@ -160,6 +160,16 @@ const KNOWN_MODELS: &str = r####" "similar_models": [ ] }, + "gpt-3.5-turbo-0125": { + "n_ctx": 16384, + "supports_scratchpads": { + "PASSTHROUGH": { + "default_system_message": "You are a coding assistant that outputs short answers, gives links to documentation." + } + }, + "similar_models": [ + ] + }, "gpt-4": { "n_ctx": 4096, "supports_scratchpads": { diff --git a/src/scratchpads/chat_utils_limit_history.rs b/src/scratchpads/chat_utils_limit_history.rs index 7f7c4c4a0..c1402bc76 100644 --- a/src/scratchpads/chat_utils_limit_history.rs +++ b/src/scratchpads/chat_utils_limit_history.rs @@ -80,6 +80,7 @@ pub fn limit_messages_history_in_bytes( message_take[i] = true; bytes_used += bytes; } else { + info!("limit_messages_history_in_bytes: overflow, drop message {} and before", i); break; } } diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs index d51d49804..6f564c62d 100644 --- a/src/scratchpads/chat_utils_rag.rs +++ b/src/scratchpads/chat_utils_rag.rs @@ -1,12 +1,96 @@ use std::sync::Arc; +use std::cmp::Ordering; +use tracing::info; use serde_json::{json, Value}; use tokio::sync::RwLock as ARwLock; use crate::at_commands::at_commands::AtCommandsContext; -use crate::call_validation::{ChatMessage, ChatPost}; +use crate::call_validation::{ChatMessage, ChatPost, ContextFile}; use crate::global_context::GlobalContext; +const SMALL_GAP_LINES: i32 = 10; // lines + +pub fn postprocess_at_results( + messages: Vec, + max_bytes: usize, +) -> Vec { + // 1. Decode all + let mut cxfile_list: Vec = vec![]; + for msg in messages { + cxfile_list.extend(serde_json::from_str::>(&msg.content).unwrap()); + } + // 2. Sort by usefullness + cxfile_list.sort_by(|a, b| { + b.usefullness.partial_cmp(&a.usefullness).unwrap_or(Ordering::Equal) + }); + for cxfile in cxfile_list.iter() { + info!("sorted file {}:{}-{} usefullness {}", crate::nicer_logs::last_n_chars(&cxfile.file_name, 40), cxfile.line1, cxfile.line2, cxfile.usefullness); + } + // 3. Truncate less useful to max_bytes + let mut total_bytes: usize = cxfile_list.iter().map(|x| x.file_content.len()).sum(); + while total_bytes > max_bytes { + let least_useful = cxfile_list.pop(); + match least_useful { + Some(file) => { + total_bytes -= file.file_content.len(); + }, + None => break, + } + } + // 4. Remove small gaps in lines and deduplicate + let mut merged: Vec = vec![]; + let list_len = cxfile_list.len(); + let mut eaten: Vec = vec![false; list_len]; + loop { + let mut merged_anything = false; + let cxfile_list_copy = cxfile_list.clone(); // unnecessary operation because of rust borrow rules :/ + for i in 0..list_len { + if eaten[i] { + continue; + } + let x: &mut ContextFile = cxfile_list.get_mut(i).unwrap(); + for j in (i+1)..list_len { + if eaten[j] { + continue; + } + let y: &ContextFile = cxfile_list_copy.get(j).unwrap(); + if x.file_name != y.file_name { + continue; + } + let possible_merge_line1 = x.line1.min(y.line1); + let possible_merge_line2 = x.line2.max(y.line2); + if possible_merge_line2 - possible_merge_line1 <= (x.line2 - x.line1) + (y.line2 - y.line1) + SMALL_GAP_LINES { + // good, makes sense to merge + info!("merging file {} range {}-{} with range {}-{}", x.file_name, x.line1, x.line2, y.line1, y.line2); + eaten[j] = true; + x.line1 = possible_merge_line1; + x.line2 = possible_merge_line2; + merged_anything = true; + } + } + } + if !merged_anything { + break; + } + } + for i in 0..list_len { + if eaten[i] { + continue; + } + merged.push(cxfile_list[i].clone()); + info!("merged {}:{}-{}", cxfile_list[i].file_name, cxfile_list[i].line1, cxfile_list[i].line2); + } + // 5. Encode back into a single message + let mut processed_messages: Vec = vec![]; + let message = ChatMessage { + role: "user".to_string(), + content: serde_json::to_string(&merged).unwrap(), + }; + processed_messages.push(message); + processed_messages +} + pub async fn run_at_commands( global_context: Arc>, post: &mut ChatPost, @@ -18,16 +102,25 @@ pub async fn run_at_commands( let mut query = post.messages.last().unwrap().content.clone(); // latest_msg_cont let valid_commands = crate::at_commands::utils::find_valid_at_commands_in_query(&mut query, &context).await; + let mut messages_for_postprocessing = vec![]; for cmd in valid_commands { match cmd.command.lock().await.execute(&query, &cmd.args, top_n, &context).await { Ok(msg) => { - post.messages.push(msg.clone()); - has_vecdb.push_in_json(json!(msg)); + messages_for_postprocessing.push(msg); }, Err(_) => {} } } - let msg = ChatMessage{ + let max_bytes = 5*1024; + let processed = postprocess_at_results( + messages_for_postprocessing, + max_bytes + ); + for msg in processed { + post.messages.push(msg.clone()); + has_vecdb.push_in_json(json!(msg)); + } + let msg = ChatMessage { role: "user".to_string(), content: query, // stream back to the user, without commands }; From f1bd3c8c1224649c2fdd6100eee34101dac28227 Mon Sep 17 00:00:00 2001 From: Valerii Date: Tue, 6 Feb 2024 11:02:07 +0000 Subject: [PATCH 44/54] Race fix (#79) --- src/vecdb/file_watcher_service.rs | 42 ++++++++++++++++++------------- src/vecdb/vecdb.rs | 14 ++++++----- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/src/vecdb/file_watcher_service.rs b/src/vecdb/file_watcher_service.rs index 7edb36d2f..c651a53ff 100644 --- a/src/vecdb/file_watcher_service.rs +++ b/src/vecdb/file_watcher_service.rs @@ -14,6 +14,7 @@ use tokio::io::BufReader; use tokio::sync::RwLock as ARwLock; use crate::global_context::GlobalContext; +use crate::vecdb::vecdb::VecDb; fn make_async_watcher() -> notify::Result<(RecommendedWatcher, Receiver>)> { let (mut tx, rx) = channel(1); @@ -53,35 +54,40 @@ async fn parse_jsonl(path: &PathBuf) -> Result, String> { Ok(paths) } +pub async fn read_and_load_jsonl( + files_set_path: &PathBuf, + vec_db: &VecDb +) { + let filenames_vec = match parse_jsonl(&files_set_path).await { + Ok(data) => data, + Err(_) => { + info!("invalid jsonl file: {:?}", files_set_path); + vec![] + } + }; + vec_db.add_or_update_files(filenames_vec, true).await; +} + pub async fn file_watcher_task( + files_set_path: PathBuf, global_context: Arc>, ) -> () { let (mut watcher, mut rx) = make_async_watcher().expect("Failed to make file watcher"); - let maybe_path = global_context.read().await.cmdline.files_set_path.clone(); - if maybe_path.is_empty() { - info!("files_set_path is empty: no files to watch"); + if files_set_path.to_str().unwrap_or("").is_empty() { + info!("files_set_path is empty. Exiting."); return; } - let path = PathBuf::from(maybe_path); let load_data = || async { - let filenames_vec = match parse_jsonl(&path).await { - Ok(data) => data, - Err(_) => { - info!("invalid jsonl file: {:?}", path); - vec![] - } - }; match *global_context.read().await.vec_db.lock().await { - Some(ref mut db) => db.add_or_update_files(filenames_vec, true).await, + Some(ref mut db) => read_and_load_jsonl(&files_set_path, db).await, None => {} - }; + } }; - if watcher.watch(path.as_ref(), RecursiveMode::Recursive).is_err() { - error!("file watcher {:?} failed to start watching", path); + if watcher.watch(&files_set_path, RecursiveMode::Recursive).is_err() { + error!("file watcher {:?} failed to start watching", files_set_path); return; } - load_data().await; while let Some(res) = rx.next().await { match res { Ok(event) => { @@ -89,13 +95,13 @@ pub async fn file_watcher_task( EventKind::Any => {} EventKind::Access(_) => {} EventKind::Create(_) => { - info!("file {:?} was created", path) + info!("file {:?} was created", files_set_path) } EventKind::Modify(_) => { load_data().await; } EventKind::Remove(_) => { - info!("file {:?} was removed", path) + info!("file {:?} was removed", files_set_path) // TODO: should we remove everything inside the database? } EventKind::Other => {} diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 0789ff007..9aeb4d5dd 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -15,6 +15,7 @@ use crate::background_tasks::BackgroundTasksHolder; use crate::fetch_embedding; use crate::vecdb; use crate::vecdb::file_filter; +use crate::vecdb::file_watcher_service::read_and_load_jsonl; use crate::vecdb::handler::VecDBHandler; use crate::vecdb::vectorizer_service::FileVectorizerService; use crate::vecdb::structs::{SearchResult, VecdbSearch, VecDbStatus, VecdbConstants}; @@ -117,15 +118,16 @@ async fn create_vecdb( db.init_folders(folders).await; } } - // FIXME: - // Race between file_watcher_task that puts files into the queue and - // vectorizer_service that can report empty queue as "vectorization finished" - let mut tasks = vec_db.start_background_tasks().await; - tasks.extend(vec![tokio::spawn(vecdb::file_watcher_service::file_watcher_task(global_context.clone()))]); - background_tasks.extend(tasks); + let files_set_path = PathBuf::from(gcx_locked.cmdline.files_set_path.clone()); + read_and_load_jsonl(&files_set_path, &vec_db).await; + let mut tasks = vec_db.start_background_tasks().await; gcx_locked.vec_db = Arc::new(AMutex::new(Some(vec_db))); + tasks.extend(vec![ + tokio::spawn(vecdb::file_watcher_service::file_watcher_task(files_set_path, global_context.clone())) + ]); + background_tasks.extend(tasks); } Ok(()) } From 6a19e3a932737e1fc39c438de0c99cc659d0efab Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Tue, 6 Feb 2024 16:04:01 +0100 Subject: [PATCH 45/54] chat_with_at_command example and friends --- examples/chat_with_at_command.py | 57 +++++++++++++++++++++++++++++ examples/http_completion.sh | 1 + examples/http_vecdb_search.sh | 7 ++++ examples/luke1.py | 60 +++---------------------------- src/scratchpads/chat_utils_rag.rs | 2 +- 5 files changed, 70 insertions(+), 57 deletions(-) create mode 100644 examples/chat_with_at_command.py create mode 100644 examples/http_vecdb_search.sh diff --git a/examples/chat_with_at_command.py b/examples/chat_with_at_command.py new file mode 100644 index 000000000..6b79b9d77 --- /dev/null +++ b/examples/chat_with_at_command.py @@ -0,0 +1,57 @@ +import requests, json, termcolor + + +initial_messages = [ +{"role": "user", "content": + "@file /home/user/.refact/tmp/unpacked-files/refact-lsp/src/scratchpads/chat_utils_deltadelta.rs\n" + + "summarize the code in one paragraph" +}, +] + +def ask_chat(messages): + response = requests.post( + "http://127.0.0.1:8001/v1/chat", + json={ + "messages": messages, + "temperature": 0.1, + "max_tokens": 300, + "model": "gpt-3.5-turbo-0125", + }, + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer XXX", + }, + timeout=60, + ) + messages_back = [] + accum_content = "" + accum_role = "" + for x in response.text.splitlines(): + if not x.strip(): + continue + if not x.startswith("data: "): + print(x) + print("ERROR: unexpected response format") + continue + if x[6:].startswith("[DONE]"): + break + j = json.loads(x[6:]) + if "choices" in j: + # streaming + choice0 = j["choices"][0] + accum_role = choice0["delta"]["role"] + accum_content += choice0["delta"]["content"] + else: + # content/role without streaming, replacing the last user message + messages_back.append({"role": j["role"], "content": j["content"]}) + if accum_role: + messages_back.append({"role": accum_role, "content": accum_content}) + return messages_back + + + +messages_back = ask_chat(initial_messages) +for msgdict in messages_back: + print(termcolor.colored(msgdict["role"], "blue")) + print(termcolor.colored(msgdict["content"], "white")) + diff --git a/examples/http_completion.sh b/examples/http_completion.sh index dcc64cf03..35b4f32d4 100644 --- a/examples/http_completion.sh +++ b/examples/http_completion.sh @@ -11,6 +11,7 @@ curl http://127.0.0.1:8001/v1/code-completion -k \ "multiline": true }, "stream": false, + "no_cache": true, "parameters": { "temperature": 0.1, "max_new_tokens": 20 diff --git a/examples/http_vecdb_search.sh b/examples/http_vecdb_search.sh new file mode 100644 index 000000000..e67e79c73 --- /dev/null +++ b/examples/http_vecdb_search.sh @@ -0,0 +1,7 @@ +curl http://127.0.0.1:8001/v1/vdb-search -k \ + -H 'Content-Type: application/json' \ + -d '{ + "query": "Hello world", + "top_n": 3 +}' + diff --git a/examples/luke1.py b/examples/luke1.py index bbcd3e994..52abcdb21 100644 --- a/examples/luke1.py +++ b/examples/luke1.py @@ -1,4 +1,5 @@ -import requests, json, termcolor +import termcolor +import chat_with_at_command my_prompt = """ @@ -52,46 +53,6 @@ "[PROVIDE_COMMANDS_STEP]\n"}, ] -def ask_chat(messages): - response = requests.post( - "http://127.0.0.1:8001/v1/chat", - json={ - "messages": messages, - "temperature": 0.1, - "max_tokens": 300, - "model": "gpt-3.5-turbo-0125", - }, - headers={ - "Content-Type": "application/json", - "Authorization": "Bearer XXX", - }, - timeout=60, - ) - messages_back = [] - accum_content = "" - accum_role = "" - for x in response.text.splitlines(): - if not x.strip(): - continue - if not x.startswith("data: "): - print(x) - print("ERROR: unexpected response format") - continue - if x[6:].startswith("[DONE]"): - break - j = json.loads(x[6:]) - if "choices" in j: - # streaming - choice0 = j["choices"][0] - accum_role = choice0["delta"]["role"] - accum_content += choice0["delta"]["content"] - else: - # content/role without streaming, replacing the last user message - messages_back.append({"role": j["role"], "content": j["content"]}) - if accum_role: - messages_back.append({"role": accum_role, "content": accum_content}) - return messages_back - def rewrite_assistant_says_to_at_commands(ass): out = "" @@ -110,10 +71,10 @@ def dialog_turn(messages): for msgdict in messages: print(termcolor.colored(msgdict["role"], "blue")) print(termcolor.colored(msgdict["content"], "green")) - messages_back = ask_chat(messages) + messages_back = chat_with_at_command.ask_chat(messages) for msgdict in messages_back: print(termcolor.colored(msgdict["role"], "blue")) - print(termcolor.colored(msgdict["content"], "red")) + print(termcolor.colored(msgdict["content"], "white")) assistant_says = messages_back[-1]["content"] messages_without_last_user = messages[:-1] @@ -137,16 +98,3 @@ def do_all(): do_all() - - - - -def do_less(): - messages = [ - {"role": "user", "content": "@workspace DeltaDeltaChatStreamer\n@file /home/user/.refact/tmp/unpacked-files/refact-lsp/src/scratchpads/chat_utils_deltadelta.rs\nHello world"} - ] - messages_back = ask_chat(messages) - print(messages_back) - - -# do_less() diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs index 6f564c62d..257837496 100644 --- a/src/scratchpads/chat_utils_rag.rs +++ b/src/scratchpads/chat_utils_rag.rs @@ -84,7 +84,7 @@ pub fn postprocess_at_results( // 5. Encode back into a single message let mut processed_messages: Vec = vec![]; let message = ChatMessage { - role: "user".to_string(), + role: "context_file".to_string(), content: serde_json::to_string(&merged).unwrap(), }; processed_messages.push(message); From 6890ac6632fe8eda00f7a189c3244b35be454244 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Tue, 6 Feb 2024 16:31:44 +0100 Subject: [PATCH 46/54] chat_with_at_command -- improved version --- examples/chat_with_at_command.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/examples/chat_with_at_command.py b/examples/chat_with_at_command.py index 6b79b9d77..ff3e4df9c 100644 --- a/examples/chat_with_at_command.py +++ b/examples/chat_with_at_command.py @@ -3,7 +3,8 @@ initial_messages = [ {"role": "user", "content": - "@file /home/user/.refact/tmp/unpacked-files/refact-lsp/src/scratchpads/chat_utils_deltadelta.rs\n" + + "@workspace definition of DeltaDeltaChatStreamer\n" + + # "@file /home/user/.refact/tmp/unpacked-files/refact-lsp/src/scratchpads/chat_utils_deltadelta.rs\n" + "summarize the code in one paragraph" }, ] @@ -49,9 +50,17 @@ def ask_chat(messages): return messages_back +def example_single_response(): + messages_back = ask_chat(initial_messages) + for msgdict in messages_back: + print(termcolor.colored(msgdict["role"], "blue")) + if msgdict["role"] == "context_file": + for x in json.loads(msgdict["content"]): + print("%s:%i-%i" % (x["file_name"], x["line1"], x["line2"])) + else: + print(termcolor.colored(msgdict["content"], "white")) + -messages_back = ask_chat(initial_messages) -for msgdict in messages_back: - print(termcolor.colored(msgdict["role"], "blue")) - print(termcolor.colored(msgdict["content"], "white")) +if __name__ == "__main__": + example_single_response() From 9fbc479e49876850b340dd54f42f1180c2d5c621 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Tue, 6 Feb 2024 16:41:55 +0100 Subject: [PATCH 47/54] chat_at_command_completion.py --- examples/chat_at_command_completion.py | 48 ++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 examples/chat_at_command_completion.py diff --git a/examples/chat_at_command_completion.py b/examples/chat_at_command_completion.py new file mode 100644 index 000000000..a93f77a01 --- /dev/null +++ b/examples/chat_at_command_completion.py @@ -0,0 +1,48 @@ +import json +import requests +import termcolor + + +url = "http://localhost:8001/v1/at-command-completion" + + +def test_at_command_completion(query): + query_real = query.replace("", "") + payload = json.dumps({ + "query": query_real, + "cursor": query.find(""), + "top_n": 3, + }) + response = requests.post(url, data=payload) + print(payload) + print(termcolor.colored(response.text, 'red')) + j = json.loads(response.text) + r = j["replace"] + if len(j["completions"]) > 0: + query_completed = query_real[:r[0]] + j["completions"][0] + query_real[r[1]:] + print(query_completed) + else: + print("no completions") + + +test_at_command_completion(""" +other line -3 +other line -2 +other line -1 +@file deltadelta.rs +other line 1 +other line 2 +other line 3 +""") + +test_at_command_completion(""" +other line -3 +other line -2 +other line -1 +@work +other line 1 +other line 2 +other line 3 +""") + + From c418516e85d9ddd98013ffb05c67bd38b4d69db4 Mon Sep 17 00:00:00 2001 From: JegernOUTT Date: Thu, 8 Feb 2024 18:10:27 +1030 Subject: [PATCH 48/54] turn the embedding index creation off --- src/vecdb/handler.rs | 4 +--- src/vecdb/vectorizer_service.rs | 14 ++++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/vecdb/handler.rs b/src/vecdb/handler.rs index 7ea838236..bd2da7101 100644 --- a/src/vecdb/handler.rs +++ b/src/vecdb/handler.rs @@ -572,9 +572,7 @@ impl VecDBHandler { pub async fn create_index(&mut self) -> vectordb::error::Result<()> { let size = self.size().await.unwrap_or(0); if size == 0 { - return Err(vectordb::error::Error::Lance { - message: "The vector database is empty".to_string(), - }.into()); + return Ok(()); } self.data_table.create_index( IvfPQIndexBuilder::default() diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index 0aefb2b43..f8fd56d8f 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -111,11 +111,12 @@ async fn vectorize_thread( info!("update_indexed_file_paths: it took {:.3}s", t0.elapsed().as_secs_f64()); reported_vecdb_complete = true; - info!("VECDB Creating index"); - match vecdb_handler_ref.lock().await.create_index().await { - Ok(_) => info!("VECDB CREATED INDEX"), - Err(err) => info!("VECDB Error creating index: {}", err) - } + // By the time we do not create index 'cause it hurts quality of retrieval + // info!("VECDB Creating index"); + // match vecdb_handler_ref.lock().await.create_index().await { + // Ok(_) => info!("VECDB CREATED INDEX"), + // Err(err) => info!("VECDB Error creating index: {}", err) + // } write!(std::io::stderr(), "VECDB COMPLETE\n").unwrap(); info!("VECDB COMPLETE"); // you can see "VECDB COMPLETE" sometimes faster vs logs } @@ -220,7 +221,8 @@ async fn cleanup_thread(vecdb_handler: Arc>) { { let mut vecdb = vecdb_handler.lock().await; let _ = vecdb.cleanup_old_records().await; - let _ = vecdb.create_index().await; + // By the time we do not create index 'cause it hurts quality of retrieval + // let _ = vecdb.create_index().await; } tokio::time::sleep(tokio::time::Duration::from_secs(2 * 3600)).await; } From 8df70b8e9a90f285989971871e5bc57206cf34c2 Mon Sep 17 00:00:00 2001 From: Valerii Date: Thu, 8 Feb 2024 09:08:15 +0000 Subject: [PATCH 49/54] at-commands work even in args are spelled incorrectly, tries to correct them first --- src/at_commands/at_file.rs | 6 ++--- src/at_commands/utils.rs | 49 ++++++++++++++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/at_commands/at_file.rs b/src/at_commands/at_file.rs index 0dbc1de32..5821009d9 100644 --- a/src/at_commands/at_file.rs +++ b/src/at_commands/at_file.rs @@ -49,11 +49,11 @@ impl AtCommand for AtFile { async fn execute(&self, _query: &String, args: &Vec, _top_n: usize, context: &AtCommandsContext) -> Result { let can_execute = self.can_execute(args, context).await; + if !can_execute { + return Err("incorrect arguments".to_string()); + } match *context.global_context.read().await.vec_db.lock().await { Some(ref db) => { - if !can_execute { - return Err("incorrect arguments".to_string()); - } let file_path = match args.get(0) { Some(x) => x, None => return Err("no file path".to_string()), diff --git a/src/at_commands/utils.rs b/src/at_commands/utils.rs index 6d1da294e..cae9eb51e 100644 --- a/src/at_commands/utils.rs +++ b/src/at_commands/utils.rs @@ -1,6 +1,7 @@ use std::sync::Arc; -use crate::at_commands::at_commands::{AtCommandCall, AtCommandsContext}; +use crate::at_commands::at_commands::{AtCommandCall, AtCommandsContext, AtParam}; use tracing::info; +use tokio::sync::Mutex as AMutex; pub async fn find_valid_at_commands_in_query( @@ -22,10 +23,15 @@ pub async fn find_valid_at_commands_in_query( Some(x) => x, None => continue, }; - if !cmd.lock().await.can_execute(&q_cmd_args, context).await { - info!("command {:?} is not executable with arguments {:?}", q_cmd, q_cmd_args); - continue; - } + let can_execute = cmd.lock().await.can_execute(&q_cmd_args, context).await; + let q_cmd_args = match correct_arguments_if_needed(cmd.lock().await.params(), &q_cmd_args, can_execute, context).await { + Ok(x) => x, + Err(e) => { + info!("command {:?} is not executable with arguments {:?}; error: {:?}", q_cmd, q_cmd_args, e); + continue; + } + }; + info!("command {:?} is perfectly good", q_cmd); results.push(AtCommandCall::new(Arc::clone(&cmd), q_cmd_args.clone())); valid_command_lines.push(idx); @@ -37,3 +43,36 @@ pub async fn find_valid_at_commands_in_query( .collect::>().join("\n"); results } + +pub async fn correct_arguments_if_needed( + params: &Vec>>, + args: &Vec, + can_execute: bool, + context: &AtCommandsContext, +) -> Result, String> { + if can_execute { + return Ok(args.clone()); + } + if params.len() != args.len() { + return Err(format!("incorrect number of arguments: {} given; {} required", args.len(), params.len())); + } + let mut args_new = vec![]; + for (param, arg) in params.iter().zip(args.iter()) { + let param = param.lock().await; + if param.is_value_valid(arg, context).await { + args_new.push(arg.clone()); + continue; + } + let completion = param.complete(arg, context, 1).await; + let arg_completed = match completion.get(0) { + Some(x) => x, + None => return Err(format!("arg '{}' is not valid and correction failed", arg)), + }; + if !param.is_value_valid(arg_completed, context).await { + return Err(format!("arg '{}' is not valid and correction failed", arg)); + } + info!("arg '{}' is corrected as '{}'", arg, arg_completed); + args_new.push(arg_completed.clone()); + } + Ok(args_new) +} From 699ac57a7a606da4f0dae1e0cdb1f834cbdc5e0f Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 8 Feb 2024 14:19:18 +0100 Subject: [PATCH 50/54] typo --- src/scratchpads/chat_generic.rs | 2 +- src/scratchpads/chat_llama2.rs | 2 +- src/scratchpads/completion_single_file_fim.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index c20d7a865..997b3e0bf 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -128,7 +128,7 @@ impl ScratchpadAbstract for GenericChatScratchpad { } if DEBUG { info!("chat prompt\n{}", prompt); - info!("chat re-encode whole prompt again gives {} tokes", self.t.count_tokens(prompt.as_str())?); + info!("chat re-encode whole prompt again gives {} tokens", self.t.count_tokens(prompt.as_str())?); } Ok(prompt) } diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index 8de6edcdc..24d86be2d 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -117,7 +117,7 @@ impl ScratchpadAbstract for ChatLlama2 { if DEBUG { // info!("llama2 chat vdb_suggestion {:?}", vdb_suggestion); info!("llama2 chat prompt\n{}", prompt); - info!("llama2 chat re-encode whole prompt again gives {} tokes", self.t.count_tokens(prompt.as_str())?); + info!("llama2 chat re-encode whole prompt again gives {} tokens", self.t.count_tokens(prompt.as_str())?); } Ok(prompt) } diff --git a/src/scratchpads/completion_single_file_fim.rs b/src/scratchpads/completion_single_file_fim.rs index 0e2d181e7..1b7720dac 100644 --- a/src/scratchpads/completion_single_file_fim.rs +++ b/src/scratchpads/completion_single_file_fim.rs @@ -203,7 +203,7 @@ impl ScratchpadAbstract for SingleFileFIM { if DEBUG { info!("cursor position\n{:?}", self.post.inputs.cursor); info!("prompt\n{}", prompt); - info!("re-encode whole prompt again gives {} tokes", self.t.count_tokens(prompt.as_str())?); + info!("re-encode whole prompt again gives {} tokens", self.t.count_tokens(prompt.as_str())?); } Ok(prompt) } From 0689c2741b49881e77b5d1984e405a75866db5e0 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 8 Feb 2024 14:21:46 +0100 Subject: [PATCH 51/54] fix comments, don't retry vectorization (infinite loop) --- src/vecdb/vectorizer_service.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index f8fd56d8f..1e8f418ae 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -111,14 +111,14 @@ async fn vectorize_thread( info!("update_indexed_file_paths: it took {:.3}s", t0.elapsed().as_secs_f64()); reported_vecdb_complete = true; - // By the time we do not create index 'cause it hurts quality of retrieval + // For now we do not create index 'cause it hurts quality of retrieval // info!("VECDB Creating index"); // match vecdb_handler_ref.lock().await.create_index().await { // Ok(_) => info!("VECDB CREATED INDEX"), // Err(err) => info!("VECDB Error creating index: {}", err) // } write!(std::io::stderr(), "VECDB COMPLETE\n").unwrap(); - info!("VECDB COMPLETE"); // you can see "VECDB COMPLETE" sometimes faster vs logs + info!("VECDB COMPLETE"); // you can see stderr "VECDB COMPLETE" sometimes faster vs logs } tokio::time::sleep(tokio::time::Duration::from_millis(5000)).await; continue; @@ -202,7 +202,7 @@ async fn vectorize_thread( } Err(e) => { info!("Error retrieving embeddings for {}: {}", data_res.file_path.to_str().unwrap(), e); - queue.lock().await.push_back(data_res.file_path); // push it back again + // queue.lock().await.push_back(data_res.file_path); // push it back again } } } From faca349811eb90cf21e7506eed27a6f1e8d08426 Mon Sep 17 00:00:00 2001 From: Oleg Klimov Date: Thu, 8 Feb 2024 15:20:58 +0100 Subject: [PATCH 52/54] logs --- src/vecdb/vecdb.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index 9aeb4d5dd..26928d150 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -292,7 +292,7 @@ impl VecdbSearch for VecDb { info!("search itself {:.3}s", t1.elapsed().as_secs_f64()); for rec in results.iter() { let last_30_chars = crate::nicer_logs::last_n_chars(&rec.file_path.display().to_string(), 30); - info!("distance {:.3}, found ...{}:{}-{}, ", rec.distance, last_30_chars, rec.start_line, rec.end_line); + info!("distance {:.3}, found {}:{}-{}, ", rec.distance, last_30_chars, rec.start_line, rec.end_line); } let t2 = std::time::Instant::now(); handler_locked.update_record_statistic(results.clone()).await; From aaeb76f3e8efdda531364c8f750f648208be4396 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Thu, 8 Feb 2024 19:16:51 +0000 Subject: [PATCH 53/54] changed distance to normalized_damerau_levenshtein --- src/at_commands/at_params.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/at_commands/at_params.rs b/src/at_commands/at_params.rs index 516dcd8f0..dd46d4804 100644 --- a/src/at_commands/at_params.rs +++ b/src/at_commands/at_params.rs @@ -1,7 +1,7 @@ use crate::at_commands::at_commands::{AtCommandsContext, AtParam}; use async_trait::async_trait; use itertools::Itertools; -use strsim::jaro_winkler; +use strsim::normalized_damerau_levenshtein; #[derive(Debug)] pub struct AtParamFilePath { @@ -39,7 +39,7 @@ impl AtParam for AtParamFilePath { let mapped_paths = index_file_paths.iter().map(|f| { ( f, - jaro_winkler( + normalized_damerau_levenshtein( if value.starts_with("/") { f.to_str().unwrap() } else { From 5b8ffcb16a2b8e338ea01b063cae3dab2c756bba Mon Sep 17 00:00:00 2001 From: JegernOUTT Date: Thu, 15 Feb 2024 01:06:47 +1030 Subject: [PATCH 54/54] rebase fix --- src/global_context.rs | 3 --- src/main.rs | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/global_context.rs b/src/global_context.rs index 0d54ced39..dc32319bf 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -21,10 +21,7 @@ use async_trait::async_trait; use crate::caps::CodeAssistantCaps; use crate::completion_cache::CompletionCache; use crate::telemetry::telemetry_structs; -use crate::custom_error::ScratchError; -use hyper::StatusCode; use tokio::sync::mpsc::Permit; -use tower_lsp::lsp_types::WorkspaceFolder; use crate::receive_workspace_changes::Document; use crate::vecdb::vecdb::VecDb; diff --git a/src/main.rs b/src/main.rs index 03d8356cb..f4acca239 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use std::io::Write; use tokio::task::JoinHandle; -use tracing::info; +use tracing::{info, Level}; use tracing_appender; use crate::background_tasks::start_background_tasks;