diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 709dd23ff..9f5178bf9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -34,14 +34,14 @@ jobs: - os: windows-latest target: aarch64-pc-windows-msvc code-target: win32-arm64 - - os: ubuntu-20.04 + - os: ubuntu-22.04 target: x86_64-unknown-linux-gnu code-target: linux-x64 container: ubuntu:18.04 - - os: ubuntu-20.04 + - os: ubuntu-22.04 target: aarch64-unknown-linux-gnu code-target: linux-arm64 - - os: ubuntu-20.04 + - os: ubuntu-22.04 target: arm-unknown-linux-gnueabihf code-target: linux-armhf - os: macos-11 @@ -68,7 +68,12 @@ jobs: if: matrix.container == 'ubuntu:18.04' shell: bash run: | - apt-get update && apt-get install -y build-essential curl libssl-dev pkg-config + apt-get update && apt-get install -y build-essential manpages-dev software-properties-common + add-apt-repository ppa:ubuntu-toolchain-r/test + apt-get install -y sudo build-essential curl libssl-dev pkg-config wget git zlib1g zlib1g-dev gcc-11 g++-11 + unlink /usr/bin/gcc && ln -s /usr/bin/gcc-11 /usr/bin/gcc && unlink /usr/bin/g++ && ln -s /usr/bin/g++-11 /usr/bin/g++ + wget https://github.com/Kitware/CMake/releases/download/v3.20.0/cmake-3.20.0.tar.gz + tar -zxvf cmake-3.20.0.tar.gz && cd cmake-3.20.0 && ./bootstrap && make && make install && cd .. && cmake --version curl --proto '=https' --tlsv1.2 --retry 10 --retry-connrefused -fsSL "https://sh.rustup.rs" | sh -s -- --profile minimal --default-toolchain none -y echo "${CARGO_HOME:-$HOME/.cargo}/bin" >> $GITHUB_PATH @@ -88,8 +93,22 @@ jobs: - name: Install ARM target toolchain if: matrix.target == 'arm-unknown-linux-gnueabihf' - run: sudo apt-get install gcc-multilib-arm-linux-gnueabihf - + run: sudo apt-get install gcc-arm-linux-gnueabihf + + - name: Install toolchain dependencies (macOS) + if: startsWith(matrix.os, 'macos') + run: brew install protobuf + + - name: Install toolchain dependencies (Windows) + if: startsWith(matrix.os, 'windows') + run: choco install protoc + + - name: Build and install protobuf library + if: matrix.target == 'arm-unknown-linux-gnueabihf' || matrix.target == 'x86_64-unknown-linux-gnu' || matrix.target == 'aarch64-unknown-linux-gnu' || matrix.target == 'x86_64-unknown-linux-musl' + run: | + git clone --depth=1 --recursive -b v25.1 https://github.com/protocolbuffers/protobuf + cd protobuf && cmake . && cmake --build . --parallel 10 && sudo cmake --install . + - name: Build and test artifact if: matrix.target == 'x86_64-unknown-linux-gnu' run: | @@ -131,7 +150,12 @@ jobs: steps: - name: Install dependencies - run: apk add --no-cache git clang lld musl-dev nodejs npm openssl-dev pkgconfig g++ + run: apk add --no-cache sudo git clang lld musl-dev nodejs npm openssl-dev pkgconfig g++ make cmake zlib linux-headers + + - name: Build and install protobuf library + run: | + git clone --depth=1 --recursive -b v25.1 https://github.com/protocolbuffers/protobuf + cd protobuf && cmake . && cmake --build . --parallel 10 && sudo cmake --install .
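+ # note: protobuf v25.1 is compiled from source here, as in the glibc jobs above, so a recent protoc/libprotobuf is available at build time; presumably this is for the new vecdb/lance dependencies, since the distro-packaged protoc can lag behind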
- name: Checkout repository uses: actions/checkout@v3 diff --git a/Cargo.toml b/Cargo.toml index cae702a6f..bbf95d065 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,3 +38,25 @@ async-trait = "0.1.73" similar = "2.3.0" axum = "0.6.20" uuid = { version = "1", features = ["v4"] } + +arrow = "47.0.0" +arrow-array = "47.0.0" +arrow-schema = "47.0.0" +async_once = "0.2.6" +async-process = "2.0.1" +itertools = "0.11.0" +lance = "=0.9.0" +lance-linalg = "=0.9.0" +lance-index = "=0.9.0" +log = "0.4.20" +md5 = "0.7" +mockito = "0.28.0" +notify = "6.1.1" +rusqlite = { version = "0.30.0", features = ["bundled"] } +tempfile = "3.8.1" +time = "0.3.30" +tokio-rusqlite = "0.5.0" +vectordb = "=0.4.0" +walkdir = "2.3" +which = "5.0.0" +strsim = "0.8.0" diff --git a/examples/chat_at_command_completion.py b/examples/chat_at_command_completion.py new file mode 100644 index 000000000..a93f77a01 --- /dev/null +++ b/examples/chat_at_command_completion.py @@ -0,0 +1,48 @@ +import json +import requests +import termcolor + + +url = "http://localhost:8001/v1/at-command-completion" + + +# "<C>" is a sentinel marking the cursor position; it is stripped out before the query is sent +def test_at_command_completion(query): + query_real = query.replace("<C>", "") + payload = json.dumps({ + "query": query_real, + "cursor": query.find("<C>"), + "top_n": 3, + }) + response = requests.post(url, data=payload) + print(payload) + print(termcolor.colored(response.text, 'red')) + j = json.loads(response.text) + r = j["replace"] + if len(j["completions"]) > 0: + query_completed = query_real[:r[0]] + j["completions"][0] + query_real[r[1]:] + print(query_completed) + else: + print("no completions") + + +test_at_command_completion(""" +other line -3 +other line -2 +other line -1 +@file deltadelta.rs<C> +other line 1 +other line 2 +other line 3 +""") + +test_at_command_completion(""" +other line -3 +other line -2 +other line -1 +@work<C> +other line 1 +other line 2 +other line 3 +""") + + diff --git a/examples/chat_with_at_command.py b/examples/chat_with_at_command.py new file mode 100644 index 000000000..ff3e4df9c --- /dev/null +++ b/examples/chat_with_at_command.py @@ -0,0 +1,66 @@ +import requests, json, termcolor + + +initial_messages = [ +{"role": "user", "content": + "@workspace definition of DeltaDeltaChatStreamer\n" + + # "@file /home/user/.refact/tmp/unpacked-files/refact-lsp/src/scratchpads/chat_utils_deltadelta.rs\n" + + "summarize the code in one paragraph" +}, +] + +def ask_chat(messages): + response = requests.post( + "http://127.0.0.1:8001/v1/chat", + json={ + "messages": messages, + "temperature": 0.1, + "max_tokens": 300, + "model": "gpt-3.5-turbo-0125", + }, + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer XXX", + }, + timeout=60, + ) + messages_back = [] + accum_content = "" + accum_role = "" + for x in response.text.splitlines(): + if not x.strip(): + continue + if not x.startswith("data: "): + print(x) + print("ERROR: unexpected response format") + continue + if x[6:].startswith("[DONE]"): + break + j = json.loads(x[6:]) + if "choices" in j: + # streaming + choice0 = j["choices"][0] + accum_role = choice0["delta"]["role"] + accum_content += choice0["delta"]["content"] + else: + # content/role without streaming, replacing the last user message + messages_back.append({"role": j["role"], "content": j["content"]}) + if accum_role: + messages_back.append({"role": accum_role, "content": accum_content}) + return messages_back + + +def example_single_response(): + messages_back = ask_chat(initial_messages) + for msgdict in messages_back: + print(termcolor.colored(msgdict["role"], "blue")) + if 
msgdict["role"] == "context_file": + for x in json.loads(msgdict["content"]): + print("%s:%i-%i" % (x["file_name"], x["line1"], x["line2"])) + else: + print(termcolor.colored(msgdict["content"], "white")) + + +if __name__ == "__main__": + example_single_response() + diff --git a/examples/http_chat_passthrough.sh b/examples/http_chat_passthrough.sh new file mode 100644 index 000000000..041298cbf --- /dev/null +++ b/examples/http_chat_passthrough.sh @@ -0,0 +1,13 @@ +curl http://127.0.0.1:8001/v1/chat -k \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [ + {"role": "user", "content": "Who is Bill Clinton? What is his favorite programming language?"} + ], + "model": "gpt-3.5-turbo", + "parameters": { + "temperature": 0.1, + "max_new_tokens": 10 + } +}' + diff --git a/examples/http_completion.sh b/examples/http_completion.sh index dcc64cf03..35b4f32d4 100644 --- a/examples/http_completion.sh +++ b/examples/http_completion.sh @@ -11,6 +11,7 @@ curl http://127.0.0.1:8001/v1/code-completion -k \ "multiline": true }, "stream": false, + "no_cache": true, "parameters": { "temperature": 0.1, "max_new_tokens": 20 diff --git a/examples/http_vecdb_search.sh b/examples/http_vecdb_search.sh new file mode 100644 index 000000000..e67e79c73 --- /dev/null +++ b/examples/http_vecdb_search.sh @@ -0,0 +1,7 @@ +curl http://127.0.0.1:8001/v1/vdb-search -k \ + -H 'Content-Type: application/json' \ + -d '{ + "query": "Hello world", + "top_n": 3 +}' + diff --git a/examples/http_vecdb_status.sh b/examples/http_vecdb_status.sh new file mode 100644 index 000000000..4a074de63 --- /dev/null +++ b/examples/http_vecdb_status.sh @@ -0,0 +1,2 @@ +curl http://127.0.0.1:8001/v1/vdb-status + diff --git a/examples/luke1.py b/examples/luke1.py new file mode 100644 index 000000000..52abcdb21 --- /dev/null +++ b/examples/luke1.py @@ -0,0 +1,100 @@ +import termcolor +import chat_with_at_command + + +my_prompt = """ +You are a bot good at explaining the purpose for the given code. + +STEPS: + +In the [ORIGINAL_CODE_STEP] user will provide code surrounding the code snippet in question, and then the snippet itself will start with 🔥code and backquotes. + +In the [PROVIDE_COMMANDS_STEP] you have to ask for an extra context to completely understand the 🔥code and it's role in the project. +Run several commands in a single message. Don't write any explanations on this step. +Write the number of commands you plan to issue as a first line of your response, +and then write all the commands. +Commands available: + +🔍SEARCH to find more information in other source files in the project or documentation. It's good for looking up definitions and usage. + +🔍FILE to dump whole file text. + +Ask for definitions of types used in the 🔥code. +Ask for usages of the class or function defined in the 🔥code. +Don't look up symbols you already have. + +An examples of commands: + +🔍SEARCH usages of function f + +🔍SEARCH definition of Type2 + +🔍FILE repo1/test_file.cpp + +In the [GENERATE_DOCUMENTATION_STEP] you have to generate an explanation of the 🔥code. +Answer questions "why it exists", "how does it fit into broader context". Don't explain line-by-line. Don't explain class data fields. 
+""" + +to_explain = """pub struct DeltaDeltaChatStreamer { + pub delta1: String, + pub delta2: String, + pub finished: bool, + pub stop_list: Vec, + pub role: String, +} +""" + +initial_messages = [ +{"role": "system", "content": my_prompt}, +{"role": "user", "content": + "[ORIGINAL_CODE_STEP]\n" + + "@file /home/user/.refact/tmp/unpacked-files/refact-lsp/src/scratchpads/chat_utils_deltadelta.rs\n" + + "Why this 🔥code exists:\n```\n[CODE]```\n".replace("[CODE]", to_explain) + + "[PROVIDE_COMMANDS_STEP]\n"}, +] + + +def rewrite_assistant_says_to_at_commands(ass): + out = "" + for s in ass.splitlines(): + s = s.strip() + if not s: + continue + if s.startswith("🔍SEARCH"): + out += "@workspace " + s[8:] + "\n" + if s.startswith("🔍FILE"): + out += "@file " + s[6:] + "\n" + return out + + +def dialog_turn(messages): + for msgdict in messages: + print(termcolor.colored(msgdict["role"], "blue")) + print(termcolor.colored(msgdict["content"], "green")) + messages_back = chat_with_at_command.ask_chat(messages) + for msgdict in messages_back: + print(termcolor.colored(msgdict["role"], "blue")) + print(termcolor.colored(msgdict["content"], "white")) + + assistant_says = messages_back[-1]["content"] + messages_without_last_user = messages[:-1] + next_step_messages = messages_without_last_user + messages_back + automated_new_user = rewrite_assistant_says_to_at_commands(assistant_says) + if not automated_new_user: + return next_step_messages, False + automated_new_user += "[GENERATE_DOCUMENTATION_STEP]" + next_step_messages.append({"role": "user", "content": automated_new_user}) + return next_step_messages, True + + +def do_all(): + messages = initial_messages.copy() + for step in range(2): + print("-"*40, "STEP%02d" % step, "-"*40) + messages, need_automated_post = dialog_turn(messages) + if not need_automated_post: + break + + +do_all() + diff --git a/src/at_commands/at_commands.rs b/src/at_commands/at_commands.rs new file mode 100644 index 000000000..cf4054e87 --- /dev/null +++ b/src/at_commands/at_commands.rs @@ -0,0 +1,61 @@ +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex as AMutex; +use tokio::sync::RwLock as ARwLock; +use async_trait::async_trait; + +use crate::global_context::GlobalContext; +use crate::at_commands::at_workspace::AtWorkspace; +use crate::at_commands::at_file::AtFile; +use crate::call_validation::ChatMessage; + +pub struct AtCommandsContext { + pub global_context: Arc>, + pub at_commands: HashMap>>>, +} + +impl AtCommandsContext { + pub async fn new(global_context: Arc>) -> Self { + AtCommandsContext { + global_context, + at_commands: at_commands_dict().await, + } + } +} + +#[async_trait] +pub trait AtCommand: Send + Sync { + fn name(&self) -> &String; + fn params(&self) -> &Vec>>; + async fn are_args_valid(&self, args: &Vec, context: &AtCommandsContext) -> Vec; + async fn can_execute(&self, args: &Vec, context: &AtCommandsContext) -> bool; + async fn execute(&self, query: &String, args: &Vec, top_n: usize, context: &AtCommandsContext) -> Result; +} + +#[async_trait] +pub trait AtParam: Send + Sync { + fn name(&self) -> &String; + async fn is_value_valid(&self, value: &String, context: &AtCommandsContext) -> bool; + async fn complete(&self, value: &String, context: &AtCommandsContext, top_n: usize) -> Vec; +} + +pub struct AtCommandCall { + pub command: Arc>>, + pub args: Vec, +} + +impl AtCommandCall { + pub fn new(command: Arc>>, args: Vec) -> Self { + AtCommandCall { + command, + args + } + } +} + +pub async fn at_commands_dict() -> HashMap>>> { + 
+ return HashMap::from([ + ("@workspace".to_string(), Arc::new(AMutex::new(Box::new(AtWorkspace::new()) as Box<dyn AtCommand>))), + ("@file".to_string(), Arc::new(AMutex::new(Box::new(AtFile::new()) as Box<dyn AtCommand>))), + ]); +} diff --git a/src/at_commands/at_file.rs b/src/at_commands/at_file.rs new file mode 100644 index 000000000..5821009d9 --- /dev/null +++ b/src/at_commands/at_file.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; +use async_trait::async_trait; +use serde_json::json; +use crate::at_commands::at_commands::{AtCommand, AtCommandsContext, AtParam}; +use crate::at_commands::at_params::AtParamFilePath; +use tokio::sync::Mutex as AMutex; +use crate::call_validation::{ChatMessage, ContextFile}; +use crate::vecdb::vecdb::FileSearchResult; + +pub struct AtFile { + pub name: String, + pub params: Vec<Arc<AMutex<dyn AtParam>>>, +} + +impl AtFile { + pub fn new() -> Self { + AtFile { + name: "@file".to_string(), + params: vec![ + Arc::new(AMutex::new(AtParamFilePath::new())) + ], + } + } +} + +#[async_trait] +impl AtCommand for AtFile { + fn name(&self) -> &String { + &self.name + } + fn params(&self) -> &Vec<Arc<AMutex<dyn AtParam>>> { + &self.params + } + async fn are_args_valid(&self, args: &Vec<String>, context: &AtCommandsContext) -> Vec<bool> { + let mut results = Vec::new(); + for (arg, param) in args.iter().zip(self.params.iter()) { + let param = param.lock().await; + results.push(param.is_value_valid(arg, context).await); + } + results + } + + async fn can_execute(&self, args: &Vec<String>, context: &AtCommandsContext) -> bool { + if self.are_args_valid(args, context).await.iter().any(|&x| x == false) || args.len() != self.params.len() { + return false; + } + return true; + } + + async fn execute(&self, _query: &String, args: &Vec<String>, _top_n: usize, context: &AtCommandsContext) -> Result<ChatMessage, String> { + let can_execute = self.can_execute(args, context).await; + if !can_execute { + return Err("incorrect arguments".to_string()); + } + match *context.global_context.read().await.vec_db.lock().await { + Some(ref db) => { + let file_path = match args.get(0) { + Some(x) => x, + None => return Err("no file path".to_string()), + }; + let path_and_text: FileSearchResult = db.get_file_orig_text(file_path.clone()).await; + let mut vector_of_context_file: Vec<ContextFile> = vec![]; + vector_of_context_file.push(ContextFile { + file_name: path_and_text.file_path, + file_content: path_and_text.file_text.clone(), + line1: 0, + line2: path_and_text.file_text.lines().count() as i32, + usefullness: 100.0, + }); + Ok(ChatMessage { + role: "context_file".to_string(), + content: json!(vector_of_context_file).to_string(), + }) + } + None => Err("vecdb is not available".to_string()) + } + } +} diff --git a/src/at_commands/at_params.rs b/src/at_commands/at_params.rs new file mode 100644 index 000000000..dd46d4804 --- /dev/null +++ b/src/at_commands/at_params.rs @@ -0,0 +1,63 @@ +use crate::at_commands::at_commands::{AtCommandsContext, AtParam}; +use async_trait::async_trait; +use itertools::Itertools; +use strsim::normalized_damerau_levenshtein; + +#[derive(Debug)] +pub struct AtParamFilePath { + pub name: String, +} + +impl AtParamFilePath { + pub fn new() -> Self { + Self { + name: "file_path".to_string() + } + } +} + +#[async_trait] +impl AtParam for AtParamFilePath { + fn name(&self) -> &String { + &self.name + } + async fn is_value_valid(&self, value: &String, context: &AtCommandsContext) -> bool { + match *context.global_context.read().await.vec_db.lock().await { + Some(ref db) => { + let index_file_paths = db.get_indexed_file_paths().await; + let index_file_paths = index_file_paths.lock().await; 
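+ // a value is valid only if it exactly matches one of the indexed paths; fuzzy matching is left to complete() below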
index_file_paths.iter().any(|path| path.to_str().unwrap() == value) + } + None => false, + } + } + async fn complete(&self, value: &String, context: &AtCommandsContext, top_n: usize) -> Vec<String> { + match *context.global_context.read().await.vec_db.lock().await { + Some(ref db) => { + let index_file_paths = db.get_indexed_file_paths().await; + let index_file_paths = index_file_paths.lock().await; + let mapped_paths = index_file_paths.iter().map(|f| { + ( + f, + normalized_damerau_levenshtein( + if value.starts_with("/") { + f.to_str().unwrap() + } else { + f.file_name().unwrap().to_str().unwrap() + }, + &value.to_string() + ) + ) + }); + let sorted_paths = mapped_paths + .sorted_by(|(_, dist1), (_, dist2)| dist1.partial_cmp(dist2).unwrap()) + .rev() + .map(|(path, _)| path.to_str().unwrap().to_string()) + .take(top_n) + .collect::<Vec<String>>(); + return sorted_paths; + } + None => vec![] + } + } +} diff --git a/src/at_commands/at_workspace.rs b/src/at_commands/at_workspace.rs new file mode 100644 index 000000000..c863dfc71 --- /dev/null +++ b/src/at_commands/at_workspace.rs @@ -0,0 +1,81 @@ +use std::sync::Arc; +use async_trait::async_trait; +use serde_json::json; +use crate::at_commands::at_commands::{AtCommand, AtCommandsContext, AtParam}; +use tokio::sync::Mutex as AMutex; +use crate::call_validation::{ChatMessage, ContextFile}; +use crate::vecdb::structs::{Record, VecdbSearch}; + + +pub struct AtWorkspace { + pub name: String, + pub params: Vec<Arc<AMutex<dyn AtParam>>>, +} + +impl AtWorkspace { + pub fn new() -> Self { + AtWorkspace { + name: "@workspace".to_string(), + params: vec![], + } + } +} + +fn results2message(results: &Vec<Record>) -> ChatMessage { + let mut vector_of_context_file: Vec<ContextFile> = vec![]; + for i in 0..results.len() { + let r = &results[i]; + vector_of_context_file.push(ContextFile { + file_name: r.file_path.to_str().unwrap().to_string(), + file_content: r.window_text.clone(), + line1: r.start_line as i32, + line2: r.end_line as i32, + usefullness: 100.0 / ((i + 1) as f32), + }); + } + ChatMessage { + role: "context_file".to_string(), + content: json!(vector_of_context_file).to_string(), + } +} + +#[async_trait] +impl AtCommand for AtWorkspace { + fn name(&self) -> &String { + &self.name + } + + fn params(&self) -> &Vec<Arc<AMutex<dyn AtParam>>> + { + &self.params + } + + async fn are_args_valid(&self, args: &Vec<String>, context: &AtCommandsContext) -> Vec<bool> { + let mut results = Vec::new(); + for (arg, param) in args.iter().zip(self.params.iter()) { + let param = param.lock().await; + results.push(param.is_value_valid(arg, context).await); + } + results + } + + async fn can_execute(&self, _args: &Vec<String>, _context: &AtCommandsContext) -> bool { + return true; + } + + async fn execute(&self, query: &String, args: &Vec<String>, top_n: usize, context: &AtCommandsContext) -> Result<ChatMessage, String> { + match *context.global_context.read().await.vec_db.lock().await { + Some(ref db) => { + let mut db_query = args.join(" "); + if db_query.is_empty() { + db_query = query.clone(); + } + let search_result = db.search(db_query, top_n).await?; + let mut results = search_result.results.clone(); + results.dedup_by(|a, b| a.file_path == b.file_path && a.window_text == b.window_text); + Ok(results2message(&results)) + } + None => Err("vecdb is not available".to_string()) + } + } +} diff --git a/src/at_commands/mod.rs b/src/at_commands/mod.rs new file mode 100644 index 000000000..070648a15 --- /dev/null +++ b/src/at_commands/mod.rs @@ -0,0 +1,6 @@ +pub mod at_commands; +pub mod at_file; +pub mod at_workspace; +pub mod at_params; +pub mod utils; +pub mod query; diff --git 
a/src/at_commands/query.rs b/src/at_commands/query.rs new file mode 100644 index 000000000..549ea3b8f --- /dev/null +++ b/src/at_commands/query.rs @@ -0,0 +1,84 @@ +#[derive(Clone)] +pub struct QueryLine { + pub value: String, + pub cursor_line_start: i64, + pub args: Vec<QueryLineArg>, // the first argument is the command, the rest are arguments +} + +impl QueryLine { + pub fn new( + value: String, + cursor_rel: i64, + cursor_line_start: i64 + ) -> Self { + QueryLine { + value: value.clone(), + cursor_line_start, + args: parse_args_from_line(&value).iter_mut().map(|x| { + x.pos2 += 1; + x.focused = cursor_rel >= x.pos1 && cursor_rel <= x.pos2; + x.pos1 += cursor_line_start; + x.pos2 += cursor_line_start; + x.clone() + }).collect(), + } + } + + pub fn command(&self) -> Option<&QueryLineArg> { + self.args.first() + } + + pub fn get_args(&self) -> Vec<&QueryLineArg> { + self.args.iter().skip(1).collect() + } +} + +#[derive(Clone)] +pub struct QueryLineArg { + pub value: String, + pub pos1: i64, + pub pos2: i64, + pub focused: bool, + pub type_name: String, +} + +fn parse_args_from_line(line: &String) -> Vec<QueryLineArg> { + let mut pos1: i64 = -1; + let mut value: String = "".to_string(); + let mut args: Vec<QueryLineArg> = vec![]; + for (idx, ch) in line.chars().enumerate() { + let idx = idx as i64; + if value.is_empty() && ch.to_string() != " " { + pos1 = idx; + } + + if ch.to_string() != " " { + value.push(ch); + } + + if pos1 != -1 && (ch.to_string() == " " || idx == (line.len() -1) as i64) { + args.push(QueryLineArg{ + value: value.clone(), + pos1, + pos2: { + if ch.to_string() != " " { + idx + } else { + idx - 1 + } + }, + focused: false, + type_name: { + if value.starts_with("@") { + "command".to_string() + } else { + "arg".to_string() + } + } + }); + pos1 = -1; + value = "".to_string(); + } + } + args +} diff --git a/src/at_commands/utils.rs b/src/at_commands/utils.rs new file mode 100644 index 000000000..cae9eb51e --- /dev/null +++ b/src/at_commands/utils.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; +use crate::at_commands::at_commands::{AtCommandCall, AtCommandsContext, AtParam}; +use tracing::info; +use tokio::sync::Mutex as AMutex; + + +pub async fn find_valid_at_commands_in_query( + query: &mut String, + context: &AtCommandsContext, +) -> Vec<AtCommandCall> { + let mut results = vec![]; + let mut valid_command_lines = vec![]; + for (idx, line) in query.lines().enumerate() { + let line_words: Vec<&str> = line.split_whitespace().collect(); + let q_cmd_args = line_words.iter().skip(1).map(|x|x.to_string()).collect::<Vec<String>>(); + + let q_cmd = match line_words.first() { + Some(x) => x, + None => continue, + }; + + let (_, cmd) = match context.at_commands.iter().find(|&(k, _v)| k == q_cmd) { + Some(x) => x, + None => continue, + }; + let can_execute = cmd.lock().await.can_execute(&q_cmd_args, context).await; + let q_cmd_args = match correct_arguments_if_needed(cmd.lock().await.params(), &q_cmd_args, can_execute, context).await { + Ok(x) => x, + Err(e) => { + info!("command {:?} is not executable with arguments {:?}; error: {:?}", q_cmd, q_cmd_args, e); + continue; + } + }; + + info!("command {:?} is perfectly good", q_cmd); + results.push(AtCommandCall::new(Arc::clone(&cmd), q_cmd_args.clone())); + valid_command_lines.push(idx); + } + // remove the lines that are valid commands from query + *query = query.lines().enumerate() + .filter(|(idx, _line)| !valid_command_lines.contains(idx)) + .map(|(_idx, line)| line) + .collect::<Vec<&str>>().join("\n"); + results +} + +pub async fn correct_arguments_if_needed( + params: &Vec<Arc<AMutex<dyn AtParam>>>, + args: &Vec<String>, + can_execute: bool, 
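+ // when can_execute is already true the args are passed through unchanged; otherwise each invalid arg is replaced with its best completion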
+ context: &AtCommandsContext, +) -> Result<Vec<String>, String> { + if can_execute { + return Ok(args.clone()); + } + if params.len() != args.len() { + return Err(format!("incorrect number of arguments: {} given; {} required", args.len(), params.len())); + } + let mut args_new = vec![]; + for (param, arg) in params.iter().zip(args.iter()) { + let param = param.lock().await; + if param.is_value_valid(arg, context).await { + args_new.push(arg.clone()); + continue; + } + let completion = param.complete(arg, context, 1).await; + let arg_completed = match completion.get(0) { + Some(x) => x, + None => return Err(format!("arg '{}' is not valid and correction failed", arg)), + }; + if !param.is_value_valid(arg_completed, context).await { + return Err(format!("arg '{}' is not valid and correction failed", arg)); + } + info!("arg '{}' is corrected as '{}'", arg, arg_completed); + args_new.push(arg_completed.clone()); + } + Ok(args_new) +} diff --git a/src/background_tasks.rs b/src/background_tasks.rs index 42c349859..7558e41e8 100644 --- a/src/background_tasks.rs +++ b/src/background_tasks.rs @@ -1,13 +1,14 @@ +use std::iter::IntoIterator; use std::sync::Arc; use std::vec; use tokio::sync::RwLock as ARwLock; use tokio::task::JoinHandle; -use crate::global_context; -use crate::telemetry::basic_transmit; +use crate::vecdb; use crate::global_context::GlobalContext; use crate::snippets_transmit; +use crate::telemetry::basic_transmit; pub struct BackgroundTasksHolder { tasks: Vec<JoinHandle<()>>, } @@ -21,7 +22,15 @@ impl BackgroundTasksHolder { } pub fn push_back(&mut self, task: JoinHandle<()>) { - self.tasks.push(task) + self.tasks.push(task); + } + + + pub fn extend(&mut self, tasks: T) + where + T: IntoIterator<Item = JoinHandle<()>>, + { + self.tasks.extend(tasks); } pub async fn abort(self) { @@ -36,5 +45,6 @@ pub fn start_background_tasks(global_context: Arc<ARwLock<GlobalContext>>) -> Ba BackgroundTasksHolder::new(vec![ tokio::spawn(basic_transmit::telemetry_background_task(global_context.clone())), tokio::spawn(snippets_transmit::tele_snip_background_task(global_context.clone())), + tokio::spawn(vecdb::vecdb::vecdb_background_reload(global_context.clone())), ]) } diff --git a/src/call_validation.rs b/src/call_validation.rs index 03347800d..eeb7d4c75 100644 --- a/src/call_validation.rs +++ b/src/call_validation.rs @@ -171,13 +171,15 @@ mod tests { }; assert!(validate_post(post).is_err()); } - } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ContextFile { pub file_name: String, pub file_content: String, + pub line1: i32, + pub line2: i32, + pub usefullness: f32, // the higher the better } #[derive(Debug, Serialize, Deserialize, Clone)] diff --git a/src/caps.rs b/src/caps.rs index 8d3470def..08b52c6a4 100644 --- a/src/caps.rs +++ b/src/caps.rs @@ -44,6 +44,14 @@ pub struct CodeAssistantCaps { #[serde(default)] pub code_chat_models: HashMap<String, ModelRecord>, pub code_chat_default_model: String, + #[serde(default)] + pub default_embeddings_model: String, + #[serde(default)] + pub endpoint_embeddings_template: String, + #[serde(default)] + pub endpoint_embeddings_style: String, + #[serde(default)] + pub size_embeddings: i32, pub running_models: Vec<String>, #[serde(default)] pub caps_version: i64, // need to reload if it increases on server, that happens when server configuration changes @@ -152,6 +160,16 @@ const KNOWN_MODELS: &str = r####" "similar_models": [ ] }, + "gpt-3.5-turbo-0125": { + "n_ctx": 16384, + "supports_scratchpads": { + "PASSTHROUGH": { + "default_system_message": "You are a coding assistant that outputs short answers, gives links to documentation." 
+ } + }, + "similar_models": [ + ] + }, "gpt-4": { "n_ctx": 4096, "supports_scratchpads": { @@ -292,6 +310,12 @@ const HF_DEFAULT_CAPS: &str = r#" "cloud_name": "Hugging Face", "endpoint_template": "https://api-inference.huggingface.co/models/$MODEL", "endpoint_style": "hf", + + "default_embeddings_model": "BAAI/bge-small-en-v1.5", + "endpoint_embeddings_template": "https://api-inference.huggingface.co/models/$MODEL", + "endpoint_embeddings_style": "hf", + "size_embeddings": 384, + "tokenizer_path_template": "https://huggingface.co/$MODEL/resolve/main/tokenizer.json", "tokenizer_rewrite_path": { "meta-llama/Llama-2-70b-chat-hf": "TheBloke/Llama-2-70B-fp16" @@ -364,6 +388,7 @@ pub async fn load_caps( r1.endpoint_chat_passthrough = relative_to_full_url(&caps_url, &r1.endpoint_chat_passthrough)?; r1.telemetry_basic_dest = relative_to_full_url(&caps_url, &r1.telemetry_basic_dest)?; r1.telemetry_corrected_snippets_dest = relative_to_full_url(&caps_url, &r1.telemetry_corrected_snippets_dest)?; + r1.endpoint_embeddings_template = relative_to_full_url(&caps_url, &r1.endpoint_embeddings_template)?; info!("caps {} completion models", r1.code_completion_models.len()); info!("caps default completion model: \"{}\"", r1.code_completion_default_model); info!("caps {} chat models", r1.code_chat_models.len()); @@ -371,6 +396,7 @@ pub async fn load_caps( Ok(Arc::new(StdRwLock::new(r1))) } + fn relative_to_full_url( caps_url: &String, maybe_relative_url: &str, diff --git a/src/fetch_embedding.rs b/src/fetch_embedding.rs new file mode 100644 index 000000000..06af007b8 --- /dev/null +++ b/src/fetch_embedding.rs @@ -0,0 +1,59 @@ +use std::sync::Arc; + +use tokio::sync::Mutex as AMutex; +use tracing::error; + +use crate::forward_to_hf_endpoint::get_embedding_hf_style; +use crate::forward_to_openai_endpoint::get_embedding_openai_style; + +pub async fn get_embedding( + client: Arc>, + endpoint_embeddings_style: &String, + model_name: &String, + endpoint_template: &String, + text: String, + api_key: &String, +) -> Result, String> { + match endpoint_embeddings_style.to_lowercase().as_str() { + "hf" => get_embedding_hf_style(client, text, endpoint_template, model_name, api_key).await, + "openai" => get_embedding_openai_style(client, text, endpoint_template, model_name, api_key).await, + _ => { + error!("Invalid endpoint_embeddings_style: {}", endpoint_embeddings_style); + Err("Invalid endpoint_embeddings_style".to_string()) + } + } +} + + +// HF often returns 500 errors for no reason +pub async fn try_get_embedding( + client: Arc>, + endpoint_embeddings_style: &String, + model_name: &String, + endpoint_template: &String, + text: String, + api_key: &String, + max_retries: usize, +) -> Result, String> { + let sleep_on_failure_ms = 300; + let mut retries = 0; + loop { + retries += 1; + match get_embedding( + client.clone(), + endpoint_embeddings_style, + model_name, + endpoint_template, + text.clone(), + api_key, + ).await { + Ok(embedding) => return Ok(embedding), + Err(e) => { + tokio::time::sleep(tokio::time::Duration::from_millis(sleep_on_failure_ms)).await; + if retries > max_retries { + return Err(e); + } + } + } + } +} diff --git a/src/forward_to_hf_endpoint.rs b/src/forward_to_hf_endpoint.rs index 1593d5d34..191f88926 100644 --- a/src/forward_to_hf_endpoint.rs +++ b/src/forward_to_hf_endpoint.rs @@ -1,9 +1,14 @@ +use std::sync::Arc; + use reqwest::header::AUTHORIZATION; use reqwest::header::CONTENT_TYPE; use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use reqwest_eventsource::EventSource; +use 
serde::Serialize; use serde_json::json; +use tokio::sync::Mutex as AMutex; + use crate::call_validation::SamplingParameters; // Idea: use USER_AGENT @@ -47,7 +52,10 @@ pub async fn forward_to_hf_style_endpoint( if status_code != 200 { return Err(format!("{} status={} text {}", url, status_code, response_txt)); } - Ok(serde_json::from_str(&response_txt).unwrap()) // FIXME: unwrap + Ok(match serde_json::from_str(&response_txt) { + Ok(json) => json, + Err(e) => return Err(format!("{}: {}", url, e)), + }) } @@ -78,10 +86,49 @@ pub async fn forward_to_hf_style_endpoint_streaming( }); let builder = client.post(&url) - .headers(headers) - .body(data.to_string()); + .headers(headers) + .body(data.to_string()); let event_source: EventSource = EventSource::new(builder).map_err(|e| format!("can't stream from {}: {}", url, e) )?; Ok(event_source) } + + +#[derive(Serialize)] +struct EmbeddingsPayloadHF { + pub inputs: String, +} + + +pub async fn get_embedding_hf_style( + client: Arc<AMutex<reqwest::Client>>, + text: String, + endpoint_template: &String, + model_name: &String, + api_key: &String, +) -> Result<Vec<f32>, String> { + let payload = EmbeddingsPayloadHF { inputs: text }; + let url = endpoint_template.clone().replace("$MODEL", &model_name); + + let maybe_response = client.lock().await + .post(&url) + .bearer_auth(api_key.clone()) + .json(&payload) + .send() + .await; + + match maybe_response { + Ok(response) => { + if response.status().is_success() { + match response.json::<Vec<f32>>().await { + Ok(embedding) => Ok(embedding), + Err(err) => Err(format!("Failed to parse the response: {:?}", err)), + } + } else { + Err(format!("Failed to get a response: {:?}", response.status())) + } + } + Err(err) => Err(format!("Failed to send a request: {:?}", err)), + } +} diff --git a/src/forward_to_openai_endpoint.rs b/src/forward_to_openai_endpoint.rs index aae1a50de..407b69735 100644 --- a/src/forward_to_openai_endpoint.rs +++ b/src/forward_to_openai_endpoint.rs @@ -1,13 +1,17 @@ +use std::sync::Arc; + use reqwest::header::AUTHORIZATION; use reqwest::header::CONTENT_TYPE; use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use reqwest_eventsource::EventSource; +use serde::Serialize; use serde_json::json; -use crate::call_validation; -use crate::call_validation::SamplingParameters; +use tokio::sync::Mutex as AMutex; use tracing::info; +use crate::call_validation; +use crate::call_validation::SamplingParameters; pub async fn forward_to_openai_style_endpoint( save_url: &mut String, @@ -42,10 +46,10 @@ pub async fn forward_to_openai_style_endpoint( } // When cancelling requests, coroutine usually gets aborted here on the following line.
let req = client.post(&url) - .headers(headers) - .body(data.to_string()) - .send() - .await; + .headers(headers) + .body(data.to_string()) + .send() + .await; let resp = req.map_err(|e| format!("{}", e))?; let status_code = resp.status().as_u16(); let response_txt = resp.text().await.map_err(|e| @@ -92,8 +96,8 @@ pub async fn forward_to_openai_style_endpoint_streaming( data["prompt"] = serde_json::Value::String(prompt.to_string()); } let builder = client.post(&url) - .headers(headers) - .body(data.to_string()); + .headers(headers) + .body(data.to_string()); let event_source: EventSource = EventSource::new(builder).map_err(|e| format!("can't stream from {}: {}", url, e) )?; @@ -109,3 +113,51 @@ fn _passthrough_messages_to_json( let messages: Vec<call_validation::ChatMessage> = serde_json::from_str(&messages_str).unwrap(); data["messages"] = serde_json::json!(messages); } + + +#[derive(Serialize)] +struct EmbeddingsPayloadOpenAI { + pub input: String, + pub model: String, +} + + +pub async fn get_embedding_openai_style( + client: Arc<AMutex<reqwest::Client>>, + text: String, + endpoint_template: &String, + model_name: &String, + api_key: &String, +) -> Result<Vec<f32>, String> { + let payload = EmbeddingsPayloadOpenAI { + input: text, + model: model_name.clone(), + }; + let url = endpoint_template.clone(); + let api_key_clone = api_key.clone(); + let response = client.lock().await + .post(&url) + .bearer_auth(api_key_clone.clone()) + .json(&payload) + .send() + .await + .map_err(|e| format!("Failed to send a request: {:?}", e))?; + + if !response.status().is_success() { + info!("get_embedding_openai_style: {:?}", response); + return Err(format!("get_embedding_openai_style: bad status: {:?}", response.status())); + } + + let json = response.json::<serde_json::Value>() + .await + .map_err(|err| format!("get_embedding_openai_style: failed to parse the response: {:?}", err))?; + + // info!("get_embedding_openai_style: {:?}", json); + match &json["data"][0]["embedding"] { + serde_json::Value::Array(embedding) => { + serde_json::from_value(serde_json::Value::Array(embedding.clone())) + .map_err(|err| { format!("Failed to parse the response: {:?}", err) }) + } + _ => Err("Response is missing 'data[0].embedding' field or it's not an array".to_string()), + } +} diff --git a/src/global_context.rs b/src/global_context.rs index 5654482e6..dc32319bf 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -1,5 +1,5 @@ use std::collections::hash_map::DefaultHasher; -use tracing::info; +use tracing::{error, info}; use std::collections::HashMap; use std::future::Future; use std::hash::Hasher; @@ -8,19 +8,22 @@ use std::sync::{Arc, Mutex}; use std::sync::RwLock as StdRwLock; use tokio::sync::{AcquireError, Mutex as AMutex, Semaphore, SemaphorePermit}; use tokio::sync::RwLock as ARwLock; +use std::io::Write; +use tokio::signal; + use tokenizers::Tokenizer; use structopt::StructOpt; -use std::io::Write; +use hyper::StatusCode; +use tower_lsp::lsp_types::WorkspaceFolder; + +use crate::custom_error::ScratchError; use async_trait::async_trait; use crate::caps::CodeAssistantCaps; use crate::completion_cache::CompletionCache; use crate::telemetry::telemetry_structs; -use crate::vecdb_search::VecdbSearch; -use crate::custom_error::ScratchError; -use hyper::StatusCode; use tokio::sync::mpsc::Permit; -use tower_lsp::lsp_types::WorkspaceFolder; use crate::receive_workspace_changes::Document; +use crate::vecdb::vecdb::VecDb; #[derive(Debug, StructOpt, Clone)] @@ -31,7 +34,7 @@ pub struct CommandLine { pub address_url: String, #[structopt(long, short="k", default_value="", help="The API key 
to authenticate your requests, will appear in HTTP requests this binary makes.")] pub api_key: String, - #[structopt(long, short="p", default_value="8001", help="Bind 127.0.0.1:<port> to listen for HTTP requests, such as /v1/code-completion, /v1/chat, /v1/caps.")] + #[structopt(long, short="p", default_value="0", help="Bind 127.0.0.1:<port> to listen for HTTP requests, such as /v1/code-completion, /v1/chat, /v1/caps.")] pub http_port: u16, #[structopt(long, default_value="", help="End-user client version, such as version of VS Code plugin.")] pub enduser_client_version: String, @@ -47,6 +50,12 @@ pub struct CommandLine { pub insecure: bool, #[structopt(long, short="v", help="Verbose logging, lots of output")] pub verbose: bool, + #[structopt(long, help="Whether to use a vector database")] + pub vecdb: bool, + #[structopt(long, short = "f", default_value = "", help = "The path to a jsonl file which contains filtered source files")] + pub files_set_path: String, + #[structopt(long, default_value = "", help = "Vecdb forced path")] + pub vecdb_forced_path: String, } impl CommandLine { @@ -65,7 +74,6 @@ pub struct LSPBackendDocumentState { pub workspace_folders: Arc<ARwLock<Option<Vec<WorkspaceFolder>>>>, } -// #[derive(Debug)] pub struct GlobalContext { pub cmdline: CommandLine, pub http_client: reqwest::Client, @@ -76,7 +84,7 @@ pub struct GlobalContext { pub tokenizer_map: HashMap<String, Arc<StdRwLock<Tokenizer>>>, pub completions_cache: Arc<StdRwLock<CompletionCache>>, pub telemetry: Arc<StdRwLock<telemetry_structs::Storage>>, - pub vecdb_search: Arc<AMutex<Box<dyn VecdbSearch + Send>>>, + pub vec_db: Arc<AMutex<Option<VecDb>>>, pub ask_shutdown_sender: Arc<Mutex<std::sync::mpsc::Sender<String>>>, pub lsp_backend_document_state: LSPBackendDocumentState, } @@ -124,6 +132,7 @@ pub async fn try_load_caps_quickly_if_not_present( Ok(caps) }, Err(e) => { + error!("load caps failed: \"{}\"", e); return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("server is not reachable: {}", e))); } } @@ -150,6 +159,49 @@ pub async fn look_for_piggyback_fields( } } +pub async fn block_until_signal(ask_shutdown_receiver: std::sync::mpsc::Receiver<String>) { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let sigterm = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + #[cfg(not(unix))] + let sigterm = std::future::pending::<()>(); + + #[cfg(unix)] + let sigusr1 = async { + signal::unix::signal(signal::unix::SignalKind::user_defined1()) + .expect("failed to install signal handler") + .recv() + .await; + }; + #[cfg(not(unix))] + let sigusr1 = std::future::pending::<()>(); + + tokio::select! 
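+ // race all the shutdown triggers: whichever fires first wins, including a message on ask_shutdown_receiver (sent e.g. by the /v1/graceful-shutdown route)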
{ + _ = ctrl_c => { + info!("SIGINT signal received"); + }, + _ = sigterm => { + info!("SIGTERM signal received"); + }, + _ = sigusr1 => { + info!("SIGUSR1 signal received"); + }, + _ = tokio::task::spawn_blocking(move || ask_shutdown_receiver.recv()) => { + info!("graceful shutdown to store telemetry"); + } + } +} + pub async fn create_global_context( cache_dir: PathBuf, ) -> (Arc<ARwLock<GlobalContext>>, std::sync::mpsc::Receiver<String>, CommandLine) { @@ -171,7 +223,7 @@ pub async fn create_global_context( tokenizer_map: HashMap::new(), completions_cache: Arc::new(StdRwLock::new(CompletionCache::new())), telemetry: Arc::new(StdRwLock::new(telemetry_structs::Storage::new())), - vecdb_search: Arc::new(AMutex::new(Box::new(crate::vecdb_search::VecdbSearchTest::new()))), + vec_db: Arc::new(AMutex::new(None)), ask_shutdown_sender: Arc::new(Mutex::new(ask_shutdown_sender)), lsp_backend_document_state: LSPBackendDocumentState { document_map: Arc::new(ARwLock::new(HashMap::new())), diff --git a/src/http.rs b/src/http.rs index 9cc078ad8..31dcb5cb9 100644 --- a/src/http.rs +++ b/src/http.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use axum::{Extension, http::{StatusCode, Uri}, response::IntoResponse}; use hyper::Server; -use tokio::signal; use tokio::sync::RwLock as ARwLock; use tokio::task::JoinHandle; use tracing::{error, info}; @@ -19,35 +18,6 @@ async fn handler_404(path: Uri) -> impl IntoResponse { } -pub async fn shutdown_signal(ask_shutdown_receiver: std::sync::mpsc::Receiver<String>) { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - - tokio::select! 
{ - _ = ctrl_c => { - info!("SIGINT signal received"); - }, - _ = terminate => {}, - _ = tokio::task::spawn_blocking(move || ask_shutdown_receiver.recv()) => { - info!("graceful shutdown to store telemetry"); - } - } -} - pub async fn start_server( global_context: Arc<ARwLock<GlobalContext>>, ask_shutdown_receiver: std::sync::mpsc::Receiver<String>, @@ -69,7 +39,7 @@ pub async fn start_server( let router = make_refact_http_server().layer(Extension(global_context.clone())); let server = builder .serve(router.into_make_service()) - .with_graceful_shutdown(shutdown_signal(ask_shutdown_receiver)); + .with_graceful_shutdown(crate::global_context::block_until_signal(ask_shutdown_receiver)); let resp = server.await.map_err(|e| format!("HTTP server error: {}", e)); if let Err(e) = resp { error!("server error: {}", e); diff --git a/src/http/routers/v1.rs b/src/http/routers/v1.rs index 043194189..049ab5364 100644 --- a/src/http/routers/v1.rs +++ b/src/http/routers/v1.rs @@ -20,6 +20,8 @@ use crate::http::routers::v1::telemetry_network::handle_v1_telemetry_network; use crate::http::routers::v1::lsp_like_handlers::handle_v1_lsp_initialize; use crate::http::routers::v1::lsp_like_handlers::handle_v1_lsp_did_change; use crate::http::utils::telemetry_wrapper; +use crate::http::routers::v1::vecdb::{handle_v1_vecdb_search, handle_v1_vecdb_status, handle_v1_vecdb_caps}; +use crate::http::routers::v1::at_commands::{handle_v1_command_completion, handle_v1_command_preview}; pub mod code_completion; pub mod chat; @@ -27,7 +29,9 @@ pub mod telemetry_network; pub mod snippet_accepted; pub mod caps; pub mod graceful_shutdown; -mod lsp_like_handlers; +pub mod lsp_like_handlers; +pub mod vecdb; +mod at_commands; pub fn make_v1_router() -> Router { Router::new() .route("/caps", telemetry_get!(handle_v1_caps)) .route("/graceful-shutdown", telemetry_get!(handle_v1_graceful_shutdown)) + .route("/vdb-search", telemetry_post!(handle_v1_vecdb_search)) + .route("/vdb-status", telemetry_get!(handle_v1_vecdb_status)) + .route("/vdb-caps", telemetry_get!(handle_v1_vecdb_caps)) + .route("/at-command-completion", telemetry_post!(handle_v1_command_completion)) + .route("/at-command-preview", telemetry_post!(handle_v1_command_preview)) + .route("/lsp-initialize", telemetry_post!(handle_v1_lsp_initialize)) .route("/lsp-did-changed", telemetry_post!(handle_v1_lsp_did_change)) } \ No newline at end of file diff --git a/src/http/routers/v1/at_commands.rs b/src/http/routers/v1/at_commands.rs new file mode 100644 index 000000000..cf13ae2d3 --- /dev/null +++ b/src/http/routers/v1/at_commands.rs @@ -0,0 +1,182 @@ +use axum::response::Result; +use axum::Extension; +use hyper::{Body, Response, StatusCode}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use serde_json::{json, Value}; +use tokio::sync::RwLock as ARwLock; +use strsim::jaro_winkler; +use itertools::Itertools; +use crate::at_commands::at_commands::AtCommandsContext; +use crate::at_commands::query::QueryLine; + +use crate::custom_error::ScratchError; +use crate::global_context::GlobalContext; + +#[derive(Serialize, Deserialize, Clone)] +struct CommandCompletionPost { + query: String, + cursor: i64, + top_n: usize, +} +#[derive(Serialize, Deserialize, Clone)] +struct CommandCompletionResponse { + completions: Vec<String>, + replace: (i64, i64), + is_cmd_executable: bool, +} + +#[derive(Serialize, Deserialize, Clone)] +struct CommandPreviewPost { + query: String, +} + +#[derive(Serialize, Deserialize, Clone)] +struct CommandPreviewResponse { + messages: 
Vec<Value>, +} + +pub async fn handle_v1_command_completion( + Extension(global_context): Extension<Arc<ARwLock<GlobalContext>>>, + body_bytes: hyper::body::Bytes, +) -> Result<Response<Body>, ScratchError> { + let context = AtCommandsContext::new(global_context.clone()).await; + let post = serde_json::from_slice::<CommandCompletionPost>(&body_bytes) + .map_err(|e| ScratchError::new(StatusCode::UNPROCESSABLE_ENTITY, format!("JSON problem: {}", e)))?; + + let (query_line_val, cursor_rel, cursor_line_start) = get_line_with_cursor(&post.query, post.cursor)?; + let query_line_val = query_line_val.chars().take(cursor_rel as usize).collect::<String>(); + // info!(query_line_val); + // info!("cursor_rel: {}, cursor_line_start: {}", cursor_rel, cursor_line_start); + let query_line = QueryLine::new(query_line_val, cursor_rel, cursor_line_start); + // for arg in query_line.args.iter() { + // info!("value: {}, focused: {}, type_name: {}; pos1: {}; pos2: {}", arg.value, arg.focused, arg.type_name, arg.pos1, arg.pos2); + // } + let (completions, is_cmd_executable, pos1, pos2) = command_completion(&query_line, &context, post.cursor, post.top_n).await?; + + let response = CommandCompletionResponse { + completions: completions.clone(), + replace: (pos1, pos2), + is_cmd_executable, + }; + + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(serde_json::to_string(&response).unwrap())) + .unwrap()) +} + +pub async fn handle_v1_command_preview( + Extension(global_context): Extension<Arc<ARwLock<GlobalContext>>>, + body_bytes: hyper::body::Bytes, +) -> Result<Response<Body>, ScratchError> { + let context = AtCommandsContext::new(global_context.clone()).await; + let post = serde_json::from_slice::<CommandPreviewPost>(&body_bytes) + .map_err(|e| ScratchError::new(StatusCode::UNPROCESSABLE_ENTITY, format!("JSON problem: {}", e)))?; + let mut query = post.query.clone(); + let valid_commands = crate::at_commands::utils::find_valid_at_commands_in_query(&mut query, &context).await; + if valid_commands.is_empty() { + return Err(ScratchError::new(StatusCode::OK, "no valid commands in query".to_string())); + } + + let mut preview_msgs = vec![]; + for cmd in valid_commands { + match cmd.command.lock().await.execute(&post.query, &cmd.args, 5, &context).await { + Ok(msg) => { + preview_msgs.push(json!(msg)); + }, + Err(_) => {} + } + } + + let response = CommandPreviewResponse { + messages: preview_msgs, + }; + + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(serde_json::to_string(&response).unwrap())) + .unwrap()) +} + +fn get_line_with_cursor(query: &String, cursor: i64) -> Result<(String, i64, i64), ScratchError> { + let mut cursor_rel = cursor; + for line in query.lines() { + let line_length = line.len() as i64; + if cursor_rel <= line_length { + if !line.starts_with("@") { + return Err(ScratchError::new(StatusCode::OK, "no command provided".to_string())); + } + return Ok((line.to_string(), cursor_rel, cursor - cursor_rel)); + } + cursor_rel -= line_length + 1; // +1 to account for the newline character + } + return Err(ScratchError::new(StatusCode::OK, "cursor is incorrect".to_string())); +} + +async fn command_completion( + query_line: &QueryLine, + context: &AtCommandsContext, + cursor_abs: i64, + top_n: usize, +) -> Result<(Vec<String>, bool, i64, i64), ScratchError> { // returns ([possible, completions], good_as_it_is) + let q_cmd = match query_line.command() { + Some(x) => x, + None => { return Err(ScratchError::new(StatusCode::OK, "no command given".to_string()));} + }; + + let (_, cmd) = match context.at_commands.iter().find(|&(k, _v)| k == &q_cmd.value) { + Some(x) => x, + None => { + if !q_cmd.focused { + return 
Err(ScratchError::new(StatusCode::OK, "incorrect command given".to_string())); + } + return Ok((command_completion_options(&q_cmd.value, &context, top_n).await, false, q_cmd.pos1, q_cmd.pos2)); + } + }; + if cmd.lock().await.can_execute(&query_line.get_args().iter().map(|x|x.value.clone()).collect(), context).await { + return Ok((vec![], true, -1, -1)); + } + + for (arg, param) in query_line.get_args().iter().zip(cmd.lock().await.params()) { + let is_valid = param.lock().await.is_value_valid(&arg.value, context).await; + if !is_valid { + return if arg.focused { + Ok((param.lock().await.complete(&arg.value, context, top_n).await, false, arg.pos1, arg.pos2)) + } else { + Err(ScratchError::new(StatusCode::OK, "invalid parameter".to_string())) + } + } + } + // if command is not focused, and the argument is empty we should make suggestions + if !q_cmd.focused { + match cmd.lock().await.params().get(query_line.get_args().len()) { + Some(param) => { + return Ok((param.lock().await.complete(&"".to_string(), context, top_n).await, false, cursor_abs, cursor_abs)); + }, + None => {} + } + } + + return Ok((vec![], false, -1, -1)); +} + + +async fn command_completion_options( + q_cmd: &String, + context: &AtCommandsContext, + top_n: usize, +) -> Vec<String> { + let at_commands_names = context.at_commands.iter().map(|(name, _cmd)| name.clone()).collect::<Vec<String>>(); + at_commands_names + .iter() + .filter(|command| command.starts_with(q_cmd)) + .map(|command| { + (command, jaro_winkler(&command, q_cmd)) + }) + .sorted_by(|(_, dist1), (_, dist2)| dist1.partial_cmp(dist2).unwrap()) + .rev() + .take(top_n) + .map(|(command, _)| command.clone()) + .collect() +} diff --git a/src/http/routers/v1/chat.rs b/src/http/routers/v1/chat.rs index 75943beb6..d20f73e71 100644 --- a/src/http/routers/v1/chat.rs +++ b/src/http/routers/v1/chat.rs @@ -55,7 +55,6 @@ pub async fn handle_v1_chat( let cx_locked = global_context.write().await; (cx_locked.http_client.clone(), cx_locked.cmdline.api_key.clone()) }; - let vecdb_search = global_context.read().await.vecdb_search.clone(); let mut scratchpad = scratchpads::create_chat_scratchpad( global_context.clone(), caps, @@ -63,7 +62,6 @@ pub async fn handle_v1_chat( chat_post.clone(), &scratchpad_name, &scratchpad_patch, - vecdb_search, ).await.map_err(|e| ScratchError::new(StatusCode::BAD_REQUEST, e) )?; diff --git a/src/http/routers/v1/code_completion.rs b/src/http/routers/v1/code_completion.rs index 7b00ed0ad..e2c4de0dc 100644 --- a/src/http/routers/v1/code_completion.rs +++ b/src/http/routers/v1/code_completion.rs @@ -4,10 +4,8 @@ use tokio::sync::RwLock as ARwLock; use axum::Extension; use axum::response::Result; -use futures_util::future::ok; use hyper::{Body, Response, StatusCode}; -use ropey::Rope; -use tracing::{error, info}; +use tracing::info; use crate::call_validation::{CodeCompletionPost, validate_post}; use crate::caps; @@ -81,6 +79,7 @@ pub async fn handle_v1_code_completion( } } + let vecdb_search = global_context.read().await.vec_db.clone(); let mut scratchpad = scratchpads::create_code_completion_scratchpad( global_context.clone(), caps, @@ -90,6 +89,7 @@ pub async fn handle_v1_code_completion( &scratchpad_patch, cache_arc.clone(), tele_storage.clone(), + vecdb_search ).await.map_err(|e| ScratchError::new(StatusCode::BAD_REQUEST, e) )?; diff --git a/src/http/routers/v1/lsp_like_handlers.rs b/src/http/routers/v1/lsp_like_handlers.rs index a444026d4..4b9ffeaad 100644 --- a/src/http/routers/v1/lsp_like_handlers.rs +++ b/src/http/routers/v1/lsp_like_handlers.rs @@ -1,3 +1,4 @@ +use 
std::path::PathBuf; use axum::Extension; use axum::response::Result; use hyper::{Body, Response, StatusCode}; @@ -7,6 +8,7 @@ use url::Url; use crate::custom_error::ScratchError; use crate::global_context::SharedGlobalContext; use crate::receive_workspace_changes; +use crate::vecdb::file_filter::retrieve_files_by_proj_folders; #[derive(Serialize, Deserialize, Clone)] @@ -22,13 +24,21 @@ struct PostDocument { pub async fn handle_v1_lsp_initialize( - Extension(_global_context): Extension<SharedGlobalContext>, + Extension(global_context): Extension<SharedGlobalContext>, body_bytes: hyper::body::Bytes, ) -> Result<Response<Body>, ScratchError> { - let _post = serde_json::from_slice::<PostInit>(&body_bytes).map_err(|e| { + let post = serde_json::from_slice::<PostInit>(&body_bytes).map_err(|e| { ScratchError::new(StatusCode::BAD_REQUEST, format!("JSON problem: {}", e)) })?; + let files = retrieve_files_by_proj_folders( + post.project_roots.iter().map(|x| PathBuf::from(x.path())).collect() + ).await; + match *global_context.read().await.vec_db.lock().await { + Some(ref mut db) => db.add_or_update_files(files, true).await, + None => {} + }; + Ok(Response::builder() .status(StatusCode::OK) .body(Body::from(json!({"success": 1}).to_string())) diff --git a/src/http/routers/v1/vecdb.rs b/src/http/routers/v1/vecdb.rs new file mode 100644 index 000000000..73f2e1935 --- /dev/null +++ b/src/http/routers/v1/vecdb.rs @@ -0,0 +1,95 @@ +use axum::Extension; +use axum::response::Result; +use hyper::{Body, Response, StatusCode}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use crate::custom_error::ScratchError; +use crate::global_context::SharedGlobalContext; +use crate::vecdb::structs::VecdbSearch; + +#[derive(Serialize, Deserialize, Clone)] +struct VecDBPost { + query: String, + top_n: usize, +} + +pub async fn handle_v1_vecdb_search( + Extension(global_context): Extension<SharedGlobalContext>, + body_bytes: hyper::body::Bytes, +) -> Result<Response<Body>, ScratchError> { + let post = serde_json::from_slice::<VecDBPost>(&body_bytes).map_err(|e| { + ScratchError::new(StatusCode::BAD_REQUEST, format!("JSON problem: {}", e)) + })?; + + let cx_locked = global_context.read().await; + let search_res = match *cx_locked.vec_db.lock().await { + Some(ref db) => db.search(post.query.to_string(), post.top_n).await, + None => { + return Err(ScratchError::new( + StatusCode::INTERNAL_SERVER_ERROR, "Vector db is not available".to_string() + )); + } + }; + + match search_res { + Ok(search_res) => { + let json_string = serde_json::to_string_pretty(&search_res).map_err(|e| { + ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("JSON serialization problem: {}", e)) + })?; + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(json_string)) + .unwrap()) + } + Err(e) => { + Err(ScratchError::new(StatusCode::BAD_REQUEST, e)) + } + } +} + +pub async fn handle_v1_vecdb_status( + Extension(global_context): Extension<SharedGlobalContext>, + _: hyper::body::Bytes, +) -> Result<Response<Body>, ScratchError> { + let cx_locked = global_context.read().await; + let status = match *cx_locked.vec_db.lock().await { + Some(ref db) => match db.get_status().await { + Ok(status) => status, + Err(err) => { + return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, err)); + } + }, + None => { + return Err(ScratchError::new( + StatusCode::INTERNAL_SERVER_ERROR, "Vector db is not available".to_string() + )); + } + }; + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(serde_json::to_string_pretty(&status).unwrap())) + .unwrap()) +} + +pub async fn handle_v1_vecdb_caps( + Extension(global_context): Extension<SharedGlobalContext>, + _: hyper::body::Bytes, +) -> 
Result<Response<Body>, ScratchError> { + let caps = { + let cx_locked = global_context.read().await; + let db = cx_locked.vec_db.lock().await; + if let Some(ref db) = *db { + db.caps().await + } else { + return Err(ScratchError::new( + StatusCode::INTERNAL_SERVER_ERROR, + "Vector db is not available".to_string() + )); + } + }; + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(json!(caps).to_string())) + .unwrap()) +} \ No newline at end of file diff --git a/src/lsp.rs b/src/lsp.rs index 64aad047d..303557233 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::fmt::Display; +use std::path::PathBuf; use std::sync::Arc; use serde::{Deserialize, Serialize}; @@ -15,10 +16,10 @@ use crate::call_validation::{CodeCompletionInputs, CodeCompletionPost, CursorPos use crate::global_context; use crate::global_context::CommandLine; use crate::http::routers::v1::code_completion::handle_v1_code_completion; -use crate::telemetry; use crate::receive_workspace_changes; +use crate::telemetry; use crate::telemetry::snippets_collection; - +use crate::vecdb::file_filter::is_valid_file; const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -35,7 +36,6 @@ impl Display for APIError { } - // #[derive(Debug)] GlobalContext does not implement Debug pub struct Backend { pub gcx: Arc<ARwLock<GlobalContext>>, @@ -182,10 +182,10 @@ impl Backend { if is_valid { unchanged_percentage = telemetry::utils::unchanged_percentage( &params.orig_grey_text, - &grey_corrected + &grey_corrected, ); } - Ok(TestHeadTailAddedTextRes{is_valid, grey_corrected, unchanged_percentage}) + Ok(TestHeadTailAddedTextRes { is_valid, grey_corrected, unchanged_percentage }) } } @@ -196,10 +196,17 @@ impl LanguageServer for Backend { info!("LSP client_info {:?}", params.client_info); { let gcx_locked = self.gcx.write().await; - *gcx_locked.lsp_backend_document_state.workspace_folders.write().await = params.workspace_folders; + *gcx_locked.lsp_backend_document_state.workspace_folders.write().await = params.workspace_folders.clone(); info!("LSP workspace_folders {:?}", gcx_locked.lsp_backend_document_state.workspace_folders); } + if let Some(folders) = params.workspace_folders { + match *self.gcx.read().await.vec_db.lock().await { + Some(ref mut db) => db.init_folders(folders).await, + None => {}, + }; + } + let completion_options: CompletionOptions; completion_options = CompletionOptions { resolve_provider: Some(false), @@ -240,6 +247,14 @@ impl LanguageServer for Backend { } async fn did_change(&self, params: DidChangeTextDocumentParams) { + let file_path = PathBuf::from(params.text_document.uri.path()); + if is_valid_file(&file_path) { + match *self.gcx.read().await.vec_db.lock().await { + Some(ref mut db) => db.add_or_update_file(file_path, false).await, + None => {} + }; + } + receive_workspace_changes::on_did_change( self.gcx.clone(), &params.text_document.uri.to_string(), ).await } - async fn did_close(&self, params: DidCloseTextDocumentParams) { - let _uri = params.text_document.uri.to_string(); - // TODO: remove text from memory + async fn did_save(&self, params: DidSaveTextDocumentParams) { + self.client + .log_message(MessageType::INFO, "{refact-lsp} file saved") + .await; + let uri = params.text_document.uri.to_string(); + let file_path = PathBuf::from(params.text_document.uri.path()); + if is_valid_file(&file_path) { + match *self.gcx.read().await.vec_db.lock().await { + Some(ref mut db) => db.add_or_update_file(file_path, false).await, + None => {} + }; + } + 
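+ // re-indexing on save keeps the vector database in sync with what's on disk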
info!("{uri} saved"); } - async fn did_save(&self, params: DidSaveTextDocumentParams) { - let _uri = params.text_document.uri.to_string(); + async fn did_close(&self, params: DidCloseTextDocumentParams) { + self.client + .log_message(MessageType::INFO, "{refact-lsp} file closed") + .await; + let uri = params.text_document.uri.to_string(); + let file_path = PathBuf::from(params.text_document.uri.path()); + if is_valid_file(&file_path) { + match *self.gcx.read().await.vec_db.lock().await { + Some(ref mut db) => db.add_or_update_file(file_path, false).await, + None => {} + }; + } + info!("{uri} closed"); } async fn shutdown(&self) -> Result<()> { @@ -266,6 +302,34 @@ impl LanguageServer for Backend { Ok(Some(CompletionResponse::Array(vec![ ]))) } + + async fn did_delete_files(&self, params: DeleteFilesParams) { + let files = params.files + .into_iter() + .map(|x| PathBuf::from(x.uri.replace("file://", ""))) + .filter(|x| is_valid_file(&x)); + + match *self.gcx.read().await.vec_db.lock().await { + Some(ref mut db) => { + for file in files { + db.remove_file(&file).await; + } + } + None => {} + }; + } + + async fn did_create_files(&self, params: CreateFilesParams) { + let files = params.files + .into_iter() + .map(|x| PathBuf::from(x.uri.replace("file://", ""))) + .filter(|x| is_valid_file(&x)) + .collect(); + match *self.gcx.read().await.vec_db.lock().await { + Some(ref db) => db.add_or_update_files(files, false).await, + None => {} + }; + } } async fn build_lsp_service( @@ -289,7 +353,7 @@ pub async fn spawn_lsp_task( if cmdline.lsp_stdin_stdout == 0 && cmdline.lsp_port > 0 { let gcx_t = gcx.clone(); let addr: std::net::SocketAddr = ([127, 0, 0, 1], cmdline.lsp_port).into(); - return Some(tokio::spawn( async move { + return Some(tokio::spawn(async move { let listener: TcpListener = TcpListener::bind(&addr).await.unwrap(); info!("LSP listening on {}", listener.local_addr().unwrap()); loop { @@ -312,7 +376,7 @@ pub async fn spawn_lsp_task( if cmdline.lsp_stdin_stdout != 0 && cmdline.lsp_port == 0 { let gcx_t = gcx.clone(); - return Some(tokio::spawn( async move { + return Some(tokio::spawn(async move { let stdin = tokio::io::stdin(); let stdout = tokio::io::stdout(); let (lsp_service, socket) = build_lsp_service(gcx_t.clone()).await; diff --git a/src/main.rs b/src/main.rs index 63423ec65..f4acca239 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ use std::io::Write; use tokio::task::JoinHandle; - -use tracing::{error, info, Level}; +use tracing::{info, Level}; use tracing_appender; use crate::background_tasks::start_background_tasks; @@ -20,11 +19,15 @@ mod restream; mod custom_error; mod completion_cache; mod telemetry; -mod vecdb_search; mod lsp; mod http; mod background_tasks; mod receive_workspace_changes; +mod vecdb; +mod fetch_embedding; +mod at_commands; +mod nicer_logs; + #[tokio::main] async fn main() { @@ -35,11 +38,12 @@ async fn main() { tracing_appender::non_blocking(std::io::stderr()) } else { write!(std::io::stderr(), "This rust binary keeps logs as files, rotated daily. 
Try\ntail -f {}/logs/\nor use --logs-stderr for debugging.\n\n", cache_dir.display()).unwrap(); - tracing_appender::non_blocking(tracing_appender::rolling::RollingFileAppender::new( - tracing_appender::rolling::Rotation::DAILY, - cache_dir.join("logs"), - "rustbinary", - )) + tracing_appender::non_blocking(tracing_appender::rolling::RollingFileAppender::builder() + .rotation(tracing_appender::rolling::Rotation::DAILY) + .filename_prefix("rustbinary") + .max_log_files(30) + .build(cache_dir.join("logs")).unwrap() + ) }; let _tracing = tracing_subscriber::fmt() .with_max_level(if cmdline.verbose {Level::DEBUG} else {Level::INFO}) @@ -61,7 +65,7 @@ async fn main() { let mut background_tasks = start_background_tasks(gcx.clone()); let should_start_http = cmdline.http_port != 0; - let should_start_lsp = (cmdline.lsp_port == 0 && cmdline.lsp_stdin_stdout == 1) || + let should_start_lsp = (cmdline.lsp_port == 0 && cmdline.lsp_stdin_stdout == 1) || (cmdline.lsp_port != 0 && cmdline.lsp_stdin_stdout == 0); let mut main_handle: Option<JoinHandle<()>> = None; @@ -70,6 +74,7 @@ } if should_start_lsp { if main_handle.is_none() { + // FIXME: this ignores crate::global_context::block_until_signal, important because now we have a database to corrupt main_handle = spawn_lsp_task(gcx.clone(), cmdline.clone()).await; } else { background_tasks.push_back(spawn_lsp_task(gcx.clone(), cmdline.clone()).await.unwrap()) @@ -78,7 +83,7 @@ if main_handle.is_some() { let _ = main_handle.unwrap().await; } - + background_tasks.abort().await; info!("saving telemetry without sending, so should be quick"); basic_transmit::basic_telemetry_compress(gcx.clone()).await; diff --git a/src/nicer_logs.rs b/src/nicer_logs.rs new file mode 100644 index 000000000..9f7741461 --- /dev/null +++ b/src/nicer_logs.rs @@ -0,0 +1,15 @@ +pub fn first_n_chars(msg: &String, n: usize) -> String { + let mut out: String = msg.chars().take(n).collect(); + if msg.chars().count() > n { + out.push_str("..."); + } + return out.replace("\n", "\\n"); +} + +pub fn last_n_chars(msg: &String, n: usize) -> String { + let mut out: String = msg.chars().rev().take(n).collect::<String>().chars().rev().collect(); + if msg.chars().count() > n { + out.insert_str(0, "..."); + } + return out.replace("\n", "\\n"); +} diff --git a/src/restream.rs b/src/restream.rs index 841b6887a..2d7481ace 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -7,6 +7,7 @@ use futures::StreamExt; use async_stream::stream; use hyper::{Body, Response, StatusCode}; +use crate::nicer_logs; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::forward_to_hf_endpoint; use crate::forward_to_openai_endpoint; @@ -153,6 +154,22 @@ pub async fn scratchpad_interaction_stream( let mut save_url: String = String::new(); let permit = slowdown_arc.acquire().await; loop { + { + let value_maybe = scratch.response_spontaneous(); + if let Ok(value) = value_maybe { + for el in value { + let value_str = format!("data: {}\n\n", serde_json::to_string(&el).unwrap()); + info!("yield: {:?}", nicer_logs::first_n_chars(&value_str, 40)); + yield Result::<_, String>::Ok(value_str); + } + } else { + let err_str = value_maybe.unwrap_err(); + error!("response_spontaneous error: {}", err_str); + let value_str = format!("data: {}\n\n", serde_json::to_string(&json!({"detail": err_str})).unwrap()); + yield Result::<_, String>::Ok(value_str); + } + } + let event_source_maybe = if endpoint_style == "hf" {
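+ // "hf" speaks the HuggingFace text-generation-inference streaming protocol; + // other endpoint styles presumably fall through to the OpenAI-style forwarder imported above.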
forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming( &mut save_url, diff --git a/src/scratchpad_abstract.rs b/src/scratchpad_abstract.rs index 109c31d6f..110ee04bb 100644 --- a/src/scratchpad_abstract.rs +++ b/src/scratchpad_abstract.rs @@ -4,6 +4,7 @@ use std::sync::RwLock; use tokenizers::Tokenizer; use crate::call_validation::SamplingParameters; use async_trait::async_trait; +use serde_json::Value; #[async_trait] @@ -31,6 +32,8 @@ pub trait ScratchpadAbstract: Send { stop_toks: bool, stop_length: bool, ) -> Result<(serde_json::Value, bool), String>; + + fn response_spontaneous(&mut self) -> Result, String>; } diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index 544fb221f..997b3e0bf 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -1,18 +1,20 @@ -use crate::scratchpad_abstract::ScratchpadAbstract; -use crate::scratchpad_abstract::HasTokenizerAndEot; -use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; -use crate::call_validation::{ChatPost, ChatMessage, SamplingParameters, ContextFile}; -use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::vecdb_search::{VecdbSearch, embed_vecdb_results}; - use std::sync::Arc; use std::sync::RwLock; -use async_trait::async_trait; -use tokio::sync::Mutex as AMutex; +use async_trait::async_trait; +use serde_json::Value; use tokenizers::Tokenizer; +use tokio::sync::RwLock as ARwLock; use tracing::info; +use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; +use crate::global_context::GlobalContext; +use crate::scratchpad_abstract::HasTokenizerAndEot; +use crate::scratchpad_abstract::ScratchpadAbstract; +use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; +use crate::scratchpads::chat_utils_limit_history::limit_messages_history; +use crate::scratchpads::chat_utils_rag::{run_at_commands, HasVecdbResults}; + const DEBUG: bool = true; @@ -25,14 +27,15 @@ pub struct GenericChatScratchpad { pub keyword_user: String, pub keyword_asst: String, pub default_system_message: String, - pub vecdb_search: Arc>>, + pub has_vecdb_results: HasVecdbResults, + pub global_context: Arc>, } impl GenericChatScratchpad { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: Arc>>, + global_context: Arc>, ) -> Self { GenericChatScratchpad { t: HasTokenizerAndEot::new(tokenizer), @@ -43,7 +46,8 @@ impl GenericChatScratchpad { keyword_user: "".to_string(), keyword_asst: "".to_string(), default_system_message: "".to_string(), - vecdb_search + has_vecdb_results: HasVecdbResults::new(), + global_context, } } } @@ -83,8 +87,9 @@ impl ScratchpadAbstract for GenericChatScratchpad { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - // embed_vecdb_results(self.vecdb_search.clone(), &mut self.post, 3).await; - let limited_msgs: Vec = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?; + run_at_commands(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; + + let limited_msgs: Vec = limit_messages_history(&self.t, &self.post.messages, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 let mut prompt = "".to_string(); @@ -123,7 +128,7 @@ impl ScratchpadAbstract for GenericChatScratchpad { } if 
DEBUG { info!("chat prompt\n{}", prompt); - info!("chat re-encode whole prompt again gives {} tokes", self.t.count_tokens(prompt.as_str())?); + info!("chat re-encode whole prompt again gives {} tokens", self.t.count_tokens(prompt.as_str())?); } Ok(prompt) } @@ -144,5 +149,9 @@ impl ScratchpadAbstract for GenericChatScratchpad { ) -> Result<(serde_json::Value, bool), String> { self.dd.response_streaming(delta, stop_toks) } + + fn response_spontaneous(&mut self) -> Result, String> { + return self.has_vecdb_results.response_streaming(); + } } diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index b57357c42..24d86be2d 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -1,17 +1,19 @@ -use tracing::info; use std::sync::Arc; use std::sync::RwLock as StdRwLock; -use tokio::sync::Mutex as AMutex; -use tokenizers::Tokenizer; + use async_trait::async_trait; +use serde_json::Value; +use tokenizers::Tokenizer; +use tokio::sync::RwLock as ARwLock; +use tracing::info; -use crate::scratchpad_abstract::ScratchpadAbstract; +use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; +use crate::global_context::GlobalContext; use crate::scratchpad_abstract::HasTokenizerAndEot; +use crate::scratchpad_abstract::ScratchpadAbstract; use crate::scratchpads::chat_utils_deltadelta::DeltaDeltaChatStreamer; -use crate::call_validation::{ChatPost, ChatMessage, SamplingParameters, ContextFile}; use crate::scratchpads::chat_utils_limit_history::limit_messages_history; -use crate::vecdb_search::{VecdbSearch, embed_vecdb_results}; - +use crate::scratchpads::chat_utils_rag::{run_at_commands, HasVecdbResults}; const DEBUG: bool = true; @@ -24,7 +26,8 @@ pub struct ChatLlama2 { pub keyword_s: String, // "SYSTEM:" keyword means it's not one token pub keyword_slash_s: String, pub default_system_message: String, - pub vecdb_search: Arc>>, + pub has_vecdb_results: HasVecdbResults, + pub global_context: Arc>, } @@ -32,7 +35,7 @@ impl ChatLlama2 { pub fn new( tokenizer: Arc>, post: ChatPost, - vecdb_search: Arc>>, + global_context: Arc>, ) -> Self { ChatLlama2 { t: HasTokenizerAndEot::new(tokenizer), @@ -41,7 +44,8 @@ impl ChatLlama2 { keyword_s: "".to_string(), keyword_slash_s: "".to_string(), default_system_message: "".to_string(), - vecdb_search + has_vecdb_results: HasVecdbResults::new(), + global_context, } } } @@ -69,8 +73,9 @@ impl ScratchpadAbstract for ChatLlama2 { context_size: usize, sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - // embed_vecdb_results(self.vecdb_search.clone(), &mut self.post, 3).await; - let limited_msgs: Vec = limit_messages_history(&self.t, &self.post, context_size, &self.default_system_message)?; + run_at_commands(self.global_context.clone(), &mut self.post, 6, &mut self.has_vecdb_results).await; + + let limited_msgs: Vec = limit_messages_history(&self.t, &self.post.messages, self.post.parameters.max_new_tokens, context_size, &self.default_system_message)?; sampling_parameters_to_patch.stop = Some(self.dd.stop_list.clone()); // loosely adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 let mut prompt = "".to_string(); @@ -112,7 +117,7 @@ impl ScratchpadAbstract for ChatLlama2 { if DEBUG { // info!("llama2 chat vdb_suggestion {:?}", vdb_suggestion); info!("llama2 chat prompt\n{}", prompt); - info!("llama2 chat re-encode whole prompt again gives {} tokes", self.t.count_tokens(prompt.as_str())?); + info!("llama2 chat re-encode whole prompt 
again gives {} tokens", self.t.count_tokens(prompt.as_str())?); } Ok(prompt) } @@ -133,5 +138,9 @@ impl ScratchpadAbstract for ChatLlama2 { ) -> Result<(serde_json::Value, bool), String> { self.dd.response_streaming(delta, stop_toks) } + + fn response_spontaneous(&mut self) -> Result, String> { + return self.has_vecdb_results.response_streaming(); + } } diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index a93ad30c0..ed3201cca 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -1,14 +1,15 @@ -use tracing::info; use std::sync::Arc; -use tokio::sync::Mutex as AMutex; + use async_trait::async_trait; +use serde_json::Value; +use tracing::{error, info}; +use tokio::sync::RwLock as ARwLock; +use crate::call_validation::{ChatMessage, ChatPost, ContextFile, SamplingParameters}; +use crate::global_context::GlobalContext; use crate::scratchpad_abstract::ScratchpadAbstract; -use crate::call_validation::{ChatPost, ChatMessage, SamplingParameters, ContextFile}; use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes; -// use crate::vecdb_search::{VecdbSearch, embed_vecdb_results}; -use crate::vecdb_search::VecdbSearch; - +use crate::scratchpads::chat_utils_rag::{run_at_commands, HasVecdbResults}; const DEBUG: bool = true; @@ -18,21 +19,23 @@ pub struct ChatPassthrough { pub post: ChatPost, pub default_system_message: String, pub limit_bytes: usize, - pub vecdb_search: Arc>>, + pub has_vecdb_results: HasVecdbResults, + pub global_context: Arc>, } -const DEFAULT_LIMIT_BYTES: usize = 4096*3; +const DEFAULT_LIMIT_BYTES: usize = 4096*6; impl ChatPassthrough { pub fn new( post: ChatPost, - vecdb_search: Arc>>, + global_context: Arc>, ) -> Self { ChatPassthrough { post, default_system_message: "".to_string(), limit_bytes: DEFAULT_LIMIT_BYTES, // one token translates to 3 bytes (not unicode chars) - vecdb_search, + has_vecdb_results: HasVecdbResults::new(), + global_context, } } } @@ -53,26 +56,35 @@ impl ScratchpadAbstract for ChatPassthrough { _context_size: usize, _sampling_parameters_to_patch: &mut SamplingParameters, ) -> Result { - let limited_msgs: Vec = limit_messages_history_in_bytes(&self.post, self.limit_bytes, &self.default_system_message)?; - info!("chat passthrough {} messages -> {} messages after applying limits and possibly adding the default system message", &limited_msgs.len(), &limited_msgs.len()); + info!("chat passthrough {} messages at start", &self.post.messages.len()); + let top_n = 6; + run_at_commands(self.global_context.clone(), &mut self.post, top_n, &mut self.has_vecdb_results).await; + let limited_msgs: Vec = limit_messages_history_in_bytes(&self.post.messages, self.limit_bytes, &self.default_system_message)?; + info!("chat passthrough {} messages -> {} messages after applying at-commands and limits, possibly adding the default system message", &self.post.messages.len(), &limited_msgs.len()); let mut filtered_msgs: Vec = Vec::::new(); for msg in &limited_msgs { if msg.role == "assistant" || msg.role == "system" || msg.role == "user" { filtered_msgs.push(msg.clone()); } else if msg.role == "context_file" { - let vector_of_context_files: Vec = serde_json::from_str(&msg.content).unwrap(); // FIXME unwrap - for context_file in &vector_of_context_files { - filtered_msgs.push(ChatMessage { - role: "user".to_string(), - content: format!("{}\n```\n{}```", context_file.file_name, context_file.file_content), - }); + info!(msg.content); + match serde_json::from_str(&msg.content) { + 
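// A "context_file" message carries a serialized Vec<ContextFile>; each entry is unpacked + // below into a plain user message so the upstream model only sees standard chat roles. +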
Ok(res) => { + let vector_of_context_files: Vec = res; + for context_file in &vector_of_context_files { + filtered_msgs.push(ChatMessage { + role: "user".to_string(), + content: format!("{}\n```\n{}```", context_file.file_name, context_file.file_content), + }); + } + }, + Err(e) => { error!("error parsing context file: {}", e); } } } } let prompt = "PASSTHROUGH ".to_string() + &serde_json::to_string(&filtered_msgs).unwrap(); if DEBUG { for msg in &filtered_msgs { - info!("filtered message: {:?}", msg); + info!("filtered message role={} {}", msg.role, crate::nicer_logs::first_n_chars(&msg.content, 40)); } } Ok(prompt.to_string()) @@ -119,4 +131,8 @@ impl ScratchpadAbstract for ChatPassthrough { }); Ok((ans, finished)) } + + fn response_spontaneous(&mut self) -> Result, String> { + return self.has_vecdb_results.response_streaming(); + } } diff --git a/src/scratchpads/chat_utils_limit_history.rs b/src/scratchpads/chat_utils_limit_history.rs index 9c13b38bb..c1402bc76 100644 --- a/src/scratchpads/chat_utils_limit_history.rs +++ b/src/scratchpads/chat_utils_limit_history.rs @@ -1,21 +1,22 @@ use crate::scratchpad_abstract::HasTokenizerAndEot; -use crate::call_validation::ChatPost; use crate::call_validation::ChatMessage; +use tracing::info; pub fn limit_messages_history( t: &HasTokenizerAndEot, - post: &ChatPost, + messages: &Vec, + max_new_tokens: usize, context_size: usize, - default_system_mesage: &String, + default_system_message: &String, ) -> Result, String> { - let tokens_limit: i32 = context_size as i32 - post.parameters.max_new_tokens as i32; + let tokens_limit: i32 = context_size as i32 - max_new_tokens as i32; let mut tokens_used: i32 = 0; - let mut message_token_count: Vec = vec![0; post.messages.len()]; - let mut message_take: Vec = vec![false; post.messages.len()]; + let mut message_token_count: Vec = vec![0; messages.len()]; + let mut message_take: Vec = vec![false; messages.len()]; let mut have_system = false; - for (i, msg) in post.messages.iter().enumerate() { + for (i, msg) in messages.iter().enumerate() { let tcnt = (3 + t.count_tokens(msg.content.as_str())?) as i32; // 3 for role "\n\nASSISTANT:" kind of thing message_token_count[i] = tcnt; if i==0 && msg.role == "system" { @@ -24,12 +25,12 @@ pub fn limit_messages_history( have_system = true; } } - let need_default_system_msg = !have_system && default_system_mesage.len() > 0; + let need_default_system_msg = !have_system && default_system_message.len() > 0; if need_default_system_msg { - let tcnt = t.count_tokens(default_system_mesage.as_str())? as i32; + let tcnt = t.count_tokens(default_system_message.as_str())? 
as i32; tokens_used += tcnt; } - for i in (0..post.messages.len()).rev() { + for i in (0..messages.len()).rev() { let tcnt = 3 + message_token_count[i]; if !message_take[i] { if tokens_used + tcnt < tokens_limit { @@ -40,11 +41,11 @@ pub fn limit_messages_history( } } } - let mut messages_out: Vec = post.messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect(); + let mut messages_out: Vec = messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect(); if need_default_system_msg { messages_out.insert(0, ChatMessage { role: "system".to_string(), - content: default_system_mesage.clone(), + content: default_system_message.clone(), }); } Ok(messages_out) @@ -52,41 +53,43 @@ pub fn limit_messages_history( pub fn limit_messages_history_in_bytes( - post: &ChatPost, + messages: &Vec, bytes_limit: usize, - default_system_mesage: &String, + default_system_message: &String, ) -> Result, String> { let mut bytes_used: usize = 0; let mut have_system = false; - let mut message_take: Vec = vec![false; post.messages.len()]; - for (i, msg) in post.messages.iter().enumerate() { + let mut message_take: Vec = vec![false; messages.len()]; + for (i, msg) in messages.iter().enumerate() { bytes_used += msg.content.as_bytes().len(); if i==0 && msg.role == "system" { message_take[i] = true; have_system = true; } } - let need_default_system_msg = !have_system && default_system_mesage.len() > 0; + let need_default_system_msg = !have_system && default_system_message.len() > 0; if need_default_system_msg { - bytes_used += default_system_mesage.as_bytes().len(); + bytes_used += default_system_message.as_bytes().len(); } - for i in (0..post.messages.len()).rev() { - let bytes = post.messages[i].content.len(); + for i in (0..messages.len()).rev() { + let bytes = messages[i].content.len(); + info!("limit_messages_history_in_bytes: message{}, bytes_used={} += {}", i, bytes_used, bytes); if !message_take[i] { if bytes_used + bytes < bytes_limit { message_take[i] = true; bytes_used += bytes; } else { + info!("limit_messages_history_in_bytes: overflow, drop message {} and before", i); break; } } } - let mut messages_out: Vec = post.messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect(); + let mut messages_out: Vec = messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect(); if need_default_system_msg { messages_out.insert(0, ChatMessage { role: "system".to_string(), - content: default_system_mesage.clone(), + content: default_system_message.clone(), }); } Ok(messages_out) diff --git a/src/scratchpads/chat_utils_rag.rs b/src/scratchpads/chat_utils_rag.rs new file mode 100644 index 000000000..257837496 --- /dev/null +++ b/src/scratchpads/chat_utils_rag.rs @@ -0,0 +1,157 @@ +use std::sync::Arc; +use std::cmp::Ordering; +use tracing::info; +use serde_json::{json, Value}; +use tokio::sync::RwLock as ARwLock; +use crate::at_commands::at_commands::AtCommandsContext; + +use crate::call_validation::{ChatMessage, ChatPost, ContextFile}; +use crate::global_context::GlobalContext; + + +const SMALL_GAP_LINES: i32 = 10; // lines + +pub fn postprocess_at_results( + messages: Vec, + max_bytes: usize, +) -> Vec { + // 1. Decode all + let mut cxfile_list: Vec = vec![]; + for msg in messages { + cxfile_list.extend(serde_json::from_str::>(&msg.content).unwrap()); + } + // 2. 
Sort by usefulness + cxfile_list.sort_by(|a, b| { + b.usefullness.partial_cmp(&a.usefullness).unwrap_or(Ordering::Equal) + }); + for cxfile in cxfile_list.iter() { + info!("sorted file {}:{}-{} usefulness {}", crate::nicer_logs::last_n_chars(&cxfile.file_name, 40), cxfile.line1, cxfile.line2, cxfile.usefullness); + } + // 3. Truncate less useful to max_bytes + let mut total_bytes: usize = cxfile_list.iter().map(|x| x.file_content.len()).sum(); + while total_bytes > max_bytes { + let least_useful = cxfile_list.pop(); + match least_useful { + Some(file) => { + total_bytes -= file.file_content.len(); + }, + None => break, + } + } + // 4. Remove small gaps in lines and deduplicate + let mut merged: Vec<ContextFile> = vec![]; + let list_len = cxfile_list.len(); + let mut eaten: Vec<bool> = vec![false; list_len]; + loop { + let mut merged_anything = false; + let cxfile_list_copy = cxfile_list.clone(); // unnecessary operation because of rust borrow rules :/ + for i in 0..list_len { + if eaten[i] { + continue; + } + let x: &mut ContextFile = cxfile_list.get_mut(i).unwrap(); + for j in (i+1)..list_len { + if eaten[j] { + continue; + } + let y: &ContextFile = cxfile_list_copy.get(j).unwrap(); + if x.file_name != y.file_name { + continue; + } + let possible_merge_line1 = x.line1.min(y.line1); + let possible_merge_line2 = x.line2.max(y.line2); + if possible_merge_line2 - possible_merge_line1 <= (x.line2 - x.line1) + (y.line2 - y.line1) + SMALL_GAP_LINES { + // good, makes sense to merge + info!("merging file {} range {}-{} with range {}-{}", x.file_name, x.line1, x.line2, y.line1, y.line2); + eaten[j] = true; + x.line1 = possible_merge_line1; + x.line2 = possible_merge_line2; + merged_anything = true; + } + } + } + if !merged_anything { + break; + } + } + for i in 0..list_len { + if eaten[i] { + continue; + } + merged.push(cxfile_list[i].clone()); + info!("merged {}:{}-{}", cxfile_list[i].file_name, cxfile_list[i].line1, cxfile_list[i].line2); + } + // 5.
Encode back into a single message + let mut processed_messages: Vec = vec![]; + let message = ChatMessage { + role: "context_file".to_string(), + content: serde_json::to_string(&merged).unwrap(), + }; + processed_messages.push(message); + processed_messages +} + +pub async fn run_at_commands( + global_context: Arc>, + post: &mut ChatPost, + top_n: usize, + has_vecdb: &mut HasVecdbResults, +) { + // TODO: don't operate on `post`, return a copy of the messages + let context = AtCommandsContext::new(global_context.clone()).await; + let mut query = post.messages.last().unwrap().content.clone(); // latest_msg_cont + let valid_commands = crate::at_commands::utils::find_valid_at_commands_in_query(&mut query, &context).await; + + let mut messages_for_postprocessing = vec![]; + for cmd in valid_commands { + match cmd.command.lock().await.execute(&query, &cmd.args, top_n, &context).await { + Ok(msg) => { + messages_for_postprocessing.push(msg); + }, + Err(_) => {} + } + } + let max_bytes = 5*1024; + let processed = postprocess_at_results( + messages_for_postprocessing, + max_bytes + ); + for msg in processed { + post.messages.push(msg.clone()); + has_vecdb.push_in_json(json!(msg)); + } + let msg = ChatMessage { + role: "user".to_string(), + content: query, // stream back to the user, without commands + }; + post.messages.push(msg.clone()); + has_vecdb.push_in_json(json!(msg)); +} + +pub struct HasVecdbResults { + pub was_sent: bool, + pub in_json: Vec, +} + +impl HasVecdbResults { + pub fn new() -> Self { + HasVecdbResults { + was_sent: false, + in_json: vec![], + } + } +} + +impl HasVecdbResults { + pub fn push_in_json(&mut self, value: Value) { + self.in_json.push(value); + } + + pub fn response_streaming(&mut self) -> Result, String> { + if self.was_sent == true || self.in_json.is_empty() { + return Ok(vec![]); + } + self.was_sent = true; + Ok(self.in_json.clone()) + } +} diff --git a/src/scratchpads/completion_single_file_fim.rs b/src/scratchpads/completion_single_file_fim.rs index 814798224..1b7720dac 100644 --- a/src/scratchpads/completion_single_file_fim.rs +++ b/src/scratchpads/completion_single_file_fim.rs @@ -4,21 +4,25 @@ use crate::call_validation::CodeCompletionPost; use crate::call_validation::SamplingParameters; use std::sync::Arc; use std::sync::RwLock as StdRwLock; +use std::vec; +use tokio::sync::Mutex as AMutex; // use ropey::RopeSlice; use tokenizers::Tokenizer; use ropey::Rope; use tracing::info; use async_trait::async_trait; +use serde_json::Value; use crate::completion_cache; use crate::telemetry::telemetry_structs; use crate::telemetry::snippets_collection; +use crate::vecdb::structs::VecdbSearch; const DEBUG: bool = false; #[derive(Debug)] -pub struct SingleFileFIM { +pub struct SingleFileFIM { pub t: HasTokenizerAndEot, pub post: CodeCompletionPost, pub order: String, @@ -27,19 +31,24 @@ pub struct SingleFileFIM { pub fim_middle: String, pub data4cache: completion_cache::CompletionSaveToCache, pub data4snippet: snippets_collection::SaveSnippet, + pub vecdb_search: Arc>>, } -impl SingleFileFIM { +impl SingleFileFIM { pub fn new( tokenizer: Arc>, post: CodeCompletionPost, order: String, cache_arc: Arc>, tele_storage: Arc>, - ) -> Self { + vecdb_search: Arc>>, + ) -> Self where T: VecdbSearch + Send { let data4cache = completion_cache::CompletionSaveToCache::new(cache_arc, &post); let data4snippet = snippets_collection::SaveSnippet::new(tele_storage, &post); - SingleFileFIM { t: HasTokenizerAndEot::new(tokenizer), post, order, fim_prefix: String::new(), fim_suffix: 
String::new(), fim_middle: String::new(), data4cache, data4snippet } + SingleFileFIM { t: HasTokenizerAndEot::new(tokenizer), post, order, fim_prefix: String::new(), + fim_suffix: String::new(), fim_middle: String::new(), data4cache, data4snippet, + vecdb_search + } } fn cleanup_prompt(&mut self, text: &String) -> String { @@ -53,7 +62,7 @@ impl SingleFileFIM { #[async_trait] -impl ScratchpadAbstract for SingleFileFIM { +impl ScratchpadAbstract for SingleFileFIM { fn apply_model_adaptation_patch( &mut self, patch: &serde_json::Value, @@ -74,6 +83,7 @@ impl ScratchpadAbstract for SingleFileFIM { Ok(()) } + async fn prompt( &mut self, context_size: usize, @@ -98,6 +108,23 @@ impl ScratchpadAbstract for SingleFileFIM { let pos = &self.post.inputs.cursor; let mut before_iter = text.lines_at(pos.line as usize).reversed(); let mut after_iter = text.lines_at(pos.line as usize + 1); + let (extra_context, mut tokens_used) = match *self.vecdb_search.lock().await { + Some(ref db) => { + match self.post.no_cache || self.post.inputs.multiline { + true => { + let text_near_cursor = get_context_near_cursor(&text, pos.line as usize, 20); + search_vecdb( + db, + self.t.clone(), + text_near_cursor, + (limit as f32 * 0.5) as usize + ).await + } + false => (String::new(), 0) + } + } + None => (String::new(), 0) + }; let mut before_line = before_iter.next(); @@ -117,7 +144,7 @@ let mut before = vec![]; let mut after = String::new(); - let mut tokens_used = self.t.count_tokens( + tokens_used += self.t.count_tokens( (cursor_line1.clone() + &cursor_line2).as_str() )?; while before_line.is_some() || after_line.is_some() { @@ -146,9 +173,10 @@ let prompt: String; if self.order == "PSM" { prompt = format!( - "{}{}{}{}{}{}{}{}", + "{}{}{}{}{}{}{}{}{}", self.t.eos, self.fim_prefix, + extra_context, before.into_iter().rev().collect::<Vec<_>>().join(""), cursor_line1, self.fim_suffix, @@ -158,9 +186,10 @@ ); } else if self.order == "SPM" { prompt = format!( - "{}{}{}{}{}{}{}{}", + "{}{}{}{}{}{}{}{}{}", self.t.eos, self.fim_suffix, + extra_context, cursor_line2, after, self.fim_prefix, @@ -174,7 +203,7 @@ if DEBUG { info!("cursor position\n{:?}", self.post.inputs.cursor); info!("prompt\n{}", prompt); - info!("re-encode whole prompt again gives {} tokes", self.t.count_tokens(prompt.as_str())?); + info!("re-encode whole prompt again gives {} tokens", self.t.count_tokens(prompt.as_str())?); } Ok(prompt) } @@ -257,6 +286,75 @@ }); Ok((ans, finished)) } + + fn response_spontaneous(&mut self) -> Result<Vec<Value>, String> { + return Err("".to_string()); + } +} + +fn get_context_near_cursor(text: &Rope, line_pos: usize, max_lines_count: usize) -> String { + let mut before_iter = text.lines_at(line_pos).reversed(); + let mut after_iter = text.lines_at(line_pos + 1); + + let mut before = vec![]; + let mut after = vec![]; + let mut before_line = before_iter.next(); + let mut after_line = after_iter.next(); + + while (before.len() + after.len() < max_lines_count) && (before_line.is_some() || after_line.is_some()) { + if let Some(line) = before_line { + before.push(line.as_str().unwrap_or("")); + } + if let Some(line) = after_line { + after.push(line.as_str().unwrap_or("")); + } + before_line = before_iter.next(); + after_line = after_iter.next(); + } + + let before_str =
before.into_iter().rev().collect::>().join(""); + let after_str = after.join(""); + + return format!("{}{}", before_str, after_str); +} + +async fn search_vecdb( + vecdb_search: &T, + tokenizer: HasTokenizerAndEot, + text_near_cursor: String, + max_context_size: usize +) -> (String, i32) where T: VecdbSearch + Send { + let search_result = vecdb_search.search(text_near_cursor, 20).await; + + let init_cfc_text = "Here are some relevant code fragments from other files of the repo:\n\n"; + let mut tokens_used = tokenizer.count_tokens(init_cfc_text).expect( + "Tokenization has failed" + ); + + match search_result { + Ok(res) => { + if res.results.is_empty() { + return ("".to_string(), tokens_used); + } + + let mut final_text_vec: Vec = vec![init_cfc_text.to_string()]; + for res in res.results { + let text: String = format!( + "The below code fragment is found in {}\n{}\n\n", + res.file_path.to_str().unwrap_or(""), res.window_text + ); + tokens_used += tokenizer.count_tokens(&text).expect( + "Tokenization has failed" + ); + final_text_vec.push(text); + if tokens_used > max_context_size as i32 { + break + } + } + (final_text_vec.join(""), tokens_used) + } + Err(_) => ("".to_string(), tokens_used) + } } diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index a0bab2d30..f379efccb 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -10,6 +10,7 @@ pub mod chat_llama2; pub mod chat_passthrough; pub mod chat_utils_deltadelta; pub mod chat_utils_limit_history; +pub mod chat_utils_rag; use crate::call_validation::CodeCompletionPost; use crate::call_validation::ChatPost; @@ -18,14 +19,14 @@ use crate::caps::CodeAssistantCaps; use crate::scratchpad_abstract::ScratchpadAbstract; use crate::completion_cache; use crate::telemetry::telemetry_structs; -use crate::vecdb_search; use crate::cached_tokenizers; +use crate::vecdb::structs::VecdbSearch; fn verify_has_send(_x: &T) {} -pub async fn create_code_completion_scratchpad( +pub async fn create_code_completion_scratchpad( global_context: Arc>, caps: Arc>, model_name_for_tokenizer: String, @@ -34,13 +35,15 @@ pub async fn create_code_completion_scratchpad( scratchpad_patch: &serde_json::Value, cache_arc: Arc>, tele_storage: Arc>, -) -> Result, String> { + vecdb_search: Arc>>, +) -> Result, String> + where T: VecdbSearch + 'static + Sync { let mut result: Box; let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, global_context, model_name_for_tokenizer).await?; if scratchpad_name == "FIM-PSM" { - result = Box::new(completion_single_file_fim::SingleFileFIM::new(tokenizer_arc, post, "PSM".to_string(), cache_arc, tele_storage)); + result = Box::new(completion_single_file_fim::SingleFileFIM::new(tokenizer_arc, post, "PSM".to_string(), cache_arc, tele_storage, vecdb_search)); } else if scratchpad_name == "FIM-SPM" { - result = Box::new(completion_single_file_fim::SingleFileFIM::new(tokenizer_arc, post, "SPM".to_string(), cache_arc, tele_storage)); + result = Box::new(completion_single_file_fim::SingleFileFIM::new(tokenizer_arc, post, "SPM".to_string(), cache_arc, tele_storage, vecdb_search)); } else { return Err(format!("This rust binary doesn't have code completion scratchpad \"{}\" compiled in", scratchpad_name)); } @@ -56,17 +59,16 @@ pub async fn create_chat_scratchpad( post: ChatPost, scratchpad_name: &str, scratchpad_patch: &serde_json::Value, - vecdb_search: Arc>>, ) -> Result, String> { let mut result: Box; if scratchpad_name == "CHAT-GENERIC" { - let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, 
global_context, model_name_for_tokenizer).await?; - result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post, vecdb_search)); + let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, global_context.clone(), model_name_for_tokenizer).await?; + result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post, global_context.clone())); } else if scratchpad_name == "CHAT-LLAMA2" { - let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, global_context, model_name_for_tokenizer).await?; - result = Box::new(chat_llama2::ChatLlama2::new(tokenizer_arc, post, vecdb_search)); + let tokenizer_arc: Arc> = cached_tokenizers::cached_tokenizer(caps, global_context.clone(), model_name_for_tokenizer).await?; + result = Box::new(chat_llama2::ChatLlama2::new(tokenizer_arc, post, global_context.clone())); } else if scratchpad_name == "PASSTHROUGH" { - result = Box::new(chat_passthrough::ChatPassthrough::new(post, vecdb_search)); + result = Box::new(chat_passthrough::ChatPassthrough::new(post, global_context.clone())); } else { return Err(format!("This rust binary doesn't have chat scratchpad \"{}\" compiled in", scratchpad_name)); } diff --git a/src/vecdb/file_filter.rs b/src/vecdb/file_filter.rs new file mode 100644 index 000000000..f5f68abfd --- /dev/null +++ b/src/vecdb/file_filter.rs @@ -0,0 +1,132 @@ +use std::fs; +use std::io::Read; +use std::path::PathBuf; + +use async_process::Command; +use walkdir::WalkDir; +use which::which; + +const LARGE_FILE_SIZE_THRESHOLD: u64 = 1_000_000; +// 1 MB +const SMALL_FILE_SIZE_THRESHOLD: u64 = 10; // 10 Bytes + +const SOURCE_FILE_EXTENSIONS: &[&str] = &[ + "c", "cpp", "cc", "h", "hpp", "cs", "java", "py", "rb", "go", "rs", "swift", + "php", "js", "jsx", "ts", "tsx", "lua", "pl", "r", "sh", "bat", "cmd", "ps1", + "m", "kt", "kts", "groovy", "dart", "fs", "fsx", "fsi", "html", "htm", "css", + "scss", "sass", "less", "json", "xml", "yml", "yaml", "md", "sql", "db", "sqlite", + "mdf", "cfg", "conf", "ini", "toml", "dockerfile", "ipynb", "rmd", "swift", "java", + "xml", "kt", "xaml", "unity", "gd", "uproject", "uasset", "asm", "s", "tex", + "makefile", "mk", "cmake", "gradle", +]; + +pub fn is_valid_file(path: &PathBuf) -> bool { + // Check if the path points to a file + if !path.is_file() { + return false; + } + + // Check if the file is in a hidden directory + if path.ancestors().any(|ancestor| { + ancestor.file_name() + .map(|name| name.to_string_lossy().starts_with('.')) + .unwrap_or(false) + }) { + return false; + } + + // Check if the file is a source file + if let Some(extension) = path.extension() { + if !SOURCE_FILE_EXTENSIONS.contains(&extension.to_str().unwrap_or_default()) { + return false; + } + } else { + // No extension, not a source file + return false; + } + + // Check file size + if let Ok(metadata) = fs::metadata(path) { + let file_size = metadata.len(); + if file_size < SMALL_FILE_SIZE_THRESHOLD || file_size > LARGE_FILE_SIZE_THRESHOLD { + return false; + } + } else { + // Unable to access file metadata + return false; + } + + // Check for read permissions + if fs::read(&path).is_err() { + return false; + } + + // Check if the file is not UTF-8 + let mut file = match fs::File::open(&path) { + Ok(file) => file, + Err(_) => return false, + }; + let mut buffer = Vec::new(); + if file.read_to_end(&mut buffer).is_err() { + return false; + } + if String::from_utf8(buffer).is_err() { + return false; + } + + // All checks passed + true +} + +pub async fn get_control_version_files(path: &PathBuf) -> 
Option> { + if path.join(".git").exists() && which("git").is_ok() { + // Git repository + run_command("git", &["ls-files"], path).await + } else if path.join(".hg").exists() && which("hg").is_ok() { + // Mercurial repository + run_command("hg", &["status", "-c"], path).await + } else if path.join(".svn").exists() && which("svn").is_ok() { + // SVN repository + run_command("svn", &["list", "-R"], path).await + } else { + None + } +} + +async fn run_command(cmd: &str, args: &[&str], path: &PathBuf) -> Option> { + let output = Command::new(cmd) + .args(args) + .current_dir(path) + .output() + .await + .ok()?; + + if !output.status.success() { + return None; + } + + String::from_utf8(output.stdout) + .ok() + .map(|s| s.lines().map(|line| path.join(line)).collect()) +} + + +pub async fn retrieve_files_by_proj_folders(proj_folders: Vec) -> Vec { + let mut all_files: Vec = Vec::new(); + for proj_folder in proj_folders { + let maybe_files = get_control_version_files(&proj_folder).await; + if let Some(files) = maybe_files { + all_files.extend(files); + } else { + let files: Vec = WalkDir::new(proj_folder) + .into_iter() + .filter_map(|e| e.ok()) + .filter(|e| !e.path().is_dir()) + .filter(|e| is_valid_file(&e.path().to_path_buf())) + .map(|e| e.path().to_path_buf()) + .collect::>(); + all_files.extend(files); + } + } + all_files +} diff --git a/src/vecdb/file_splitter.rs b/src/vecdb/file_splitter.rs new file mode 100644 index 000000000..04f31f02c --- /dev/null +++ b/src/vecdb/file_splitter.rs @@ -0,0 +1,85 @@ +use std::path::PathBuf; +use tokio::fs::read_to_string; +use crate::vecdb::structs::SplitResult; +use md5; + +fn str_hash(s: &String) -> String { + let digest = md5::compute(s); + format!("{:x}", digest) +} + +pub struct FileSplitter { + soft_window: usize, + hard_window: usize, +} + +impl FileSplitter { + pub fn new(window_size: usize, soft_limit: usize) -> Self { + Self { + soft_window: window_size, + hard_window: window_size + soft_limit, + } + } + + pub async fn split(&self, file_path: &PathBuf) -> Result, String> { + let text = match read_to_string(file_path).await { + Ok(s) => s, + Err(e) => return Err(e.to_string()) + }; + + let mut chunks = Vec::new(); + let mut batch = Vec::new(); + let mut batch_size = 0; + let mut soft_batch = Vec::new(); + let mut current_line_number: u64 = 0; + for line in text.lines() { + batch_size += line.len(); + if batch_size > self.soft_window { + soft_batch.push(line.to_string()); + } else { + batch.push(line.to_string()); + } + + if batch_size >= self.hard_window { + let best_break_line_n = soft_batch.iter() + .rposition(|l| l.trim().is_empty()) + .unwrap_or(soft_batch.len()); + + let (remaining, to_next_batch) = soft_batch.split_at(best_break_line_n); + batch.extend_from_slice(remaining); + + let start_line = current_line_number; + let end_line = start_line + batch.len() as u64; + current_line_number += batch.len() as u64; + + chunks.push(SplitResult { + file_path: file_path.clone(), + window_text: batch.join("\n"), + window_text_hash: str_hash(&batch.join("\n")), + start_line, + end_line, + }); + + batch = to_next_batch.to_vec(); + soft_batch.clear(); + batch_size = batch.iter().map(|s| s.len()).sum(); + } + } + + if !batch.is_empty() || !soft_batch.is_empty() { + batch.extend(soft_batch); + let start_line = current_line_number; + let end_line = start_line + batch.len() as u64; + + chunks.push(SplitResult { + file_path: file_path.clone(), + window_text: batch.join("\n"), + window_text_hash: str_hash(&batch.join("\n")), + start_line, + end_line, + }); + 
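// Final flush: lines that never reached hard_window, plus any overflow left in soft_batch, + // become one last (possibly short) chunk so trailing lines are not lost. +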
} + + Ok(chunks) + } +} diff --git a/src/vecdb/file_watcher_service.rs b/src/vecdb/file_watcher_service.rs new file mode 100644 index 000000000..c651a53ff --- /dev/null +++ b/src/vecdb/file_watcher_service.rs @@ -0,0 +1,113 @@ +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use futures::{ + channel::mpsc::{channel, Receiver}, + SinkExt, StreamExt, +}; +use log::{info, error}; +use notify::{Config, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; +use serde_json::Value; +use tokio::fs::File; +use tokio::io::AsyncBufReadExt; +use tokio::io::BufReader; +use tokio::sync::RwLock as ARwLock; + +use crate::global_context::GlobalContext; +use crate::vecdb::vecdb::VecDb; + +fn make_async_watcher() -> notify::Result<(RecommendedWatcher, Receiver>)> { + let (mut tx, rx) = channel(1); + + let watcher = RecommendedWatcher::new( + move |res| { + futures::executor::block_on(async { + tx.send(res).await.unwrap(); + }) + }, + Config::default(), + )?; + + Ok((watcher, rx)) +} + + +async fn parse_jsonl(path: &PathBuf) -> Result, String> { + let file = File::open(path).await.map_err(|_| format!("File not found: {:?}", path))?; + let reader = BufReader::new(file); + let base_path = path.parent().or(Some(Path::new("/"))).unwrap().to_path_buf(); + + let mut lines = reader.lines(); + let mut paths = Vec::new(); + + while let Some(line) = lines.next_line().await.transpose() { + let line = line.map_err(|_| "Error reading line".to_string())?; + if let Ok(value) = serde_json::from_str::(&line) { + if value.is_object() { + if let Some(filename) = value.get("path").and_then(|v| v.as_str()) { + paths.push(base_path.join(filename)); + } + } + } + } + + Ok(paths) +} + +pub async fn read_and_load_jsonl( + files_set_path: &PathBuf, + vec_db: &VecDb +) { + let filenames_vec = match parse_jsonl(&files_set_path).await { + Ok(data) => data, + Err(_) => { + info!("invalid jsonl file: {:?}", files_set_path); + vec![] + } + }; + vec_db.add_or_update_files(filenames_vec, true).await; +} + +pub async fn file_watcher_task( + files_set_path: PathBuf, + global_context: Arc>, +) -> () { + let (mut watcher, mut rx) = make_async_watcher().expect("Failed to make file watcher"); + if files_set_path.to_str().unwrap_or("").is_empty() { + info!("files_set_path is empty. Exiting."); + return; + } + let load_data = || async { + match *global_context.read().await.vec_db.lock().await { + Some(ref mut db) => read_and_load_jsonl(&files_set_path, db).await, + None => {} + } + }; + + if watcher.watch(&files_set_path, RecursiveMode::Recursive).is_err() { + error!("file watcher {:?} failed to start watching", files_set_path); + return; + } + while let Some(res) = rx.next().await { + match res { + Ok(event) => { + match event.kind { + EventKind::Any => {} + EventKind::Access(_) => {} + EventKind::Create(_) => { + info!("file {:?} was created", files_set_path) + } + EventKind::Modify(_) => { + load_data().await; + } + EventKind::Remove(_) => { + info!("file {:?} was removed", files_set_path) + // TODO: should we remove everything inside the database? 
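+ // Note this fires when the watched .jsonl listing itself is removed, not when an + // individual workspace file disappears, so dropping database rows here would be a policy choice.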
+ } + EventKind::Other => {} + } + } + Err(e) => info!("file watch error: {:?}", e), + } + } +} diff --git a/src/vecdb/handler.rs b/src/vecdb/handler.rs new file mode 100644 index 000000000..bd2da7101 --- /dev/null +++ b/src/vecdb/handler.rs @@ -0,0 +1,716 @@ +use std::any::Any; +use std::cmp::min; +use std::collections::HashSet; +use std::fmt::{Debug, Formatter}; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; +use std::time::SystemTime; +use tokio::sync::Mutex as AMutex; +use arrow::array::ArrayData; +use arrow::buffer::Buffer; +use arrow::compute::concat_batches; +use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray, UInt64Array}; +use arrow_array::cast::{as_fixed_size_list_array, as_primitive_array, as_string_array}; +use arrow_array::types::{Float32Type, UInt64Type}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use itertools::Itertools; +use futures_util::TryStreamExt; +use lance::dataset::{WriteMode, WriteParams}; +use lance_linalg::distance::MetricType; +use lance_index::vector::ivf::IvfBuildParams; +use lance::index::vector::pq::PQBuildParams; +use log::info; +use rusqlite::{OpenFlags, params, Result}; +use tempfile::{tempdir, TempDir}; +use tokio::fs; +use tokio_rusqlite::Connection; +use tracing::error; +use vectordb::database::Database; +use vectordb::index::vector::IvfPQIndexBuilder; +use vectordb::table::Table; + +use crate::vecdb::structs::{Record, SplitResult}; + + +impl Debug for VecDBHandler { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "VecDBHandler: {:?}", self.cache_database.type_id()) + } +} + +pub struct VecDBHandler { + cache_database: Arc>, + _data_database_temp_dir: TempDir, + data_table: Table, + schema: SchemaRef, + data_table_hashes: HashSet, + embedding_size: i32, + indexed_file_paths: Arc>>, +} + +fn cosine_similarity(vec1: &Vec, vec2: &Vec) -> f32 { + let dot_product: f32 = vec1.iter().zip(vec2).map(|(x, y)| x * y).sum(); + let magnitude_vec1: f32 = vec1.iter().map(|x| x.powi(2)).sum::().sqrt(); + let magnitude_vec2: f32 = vec2.iter().map(|x| x.powi(2)).sum::().sqrt(); + dot_product / (magnitude_vec1 * magnitude_vec2) +} + +fn cosine_distance(vec1: &Vec, vec2: &Vec) -> f32 { + 1.0 - cosine_similarity(vec1, vec2) +} + +const TWO_WEEKS: i32 = 2 * 7 * 24 * 3600; +const ONE_MONTH: i32 = 30 * 24 * 3600; +const MIN_LIKES: i32 = 3; + +impl VecDBHandler { + pub async fn init(cache_dir: &PathBuf, model_name: &String, embedding_size: i32) -> Result { + let cache_dir_str = match cache_dir.join("refact_vecdb_cache") + .join(format!("model_{}_esize_{}.sqlite", + model_name.replace("/", "_"), + embedding_size + )).to_str() { + + Some(dir) => dir.to_string(), + None => { + return Err(format!("{:?}", "Cache directory is not a valid path")); + } + }; + let data_database_temp_dir = match tempdir() { + Ok(dir) => dir, + Err(_) => return Err(format!("{:?}", "Error creating temp dir")), + }; + let data_database_temp_dir_str = match data_database_temp_dir.path().to_str() { + Some(path) => path, + None => return Err(format!("{:?}", "Temp directory is not a valid path")), + }; + + if !cache_dir.join("refact_vecdb_cache").exists() { + match fs::create_dir_all(cache_dir.join("refact_vecdb_cache")).await { + Ok(_) => {} + Err(e) => return Err(format!("{:?}", e)), + } + } + let cache_database = match Connection::open_with_flags( + cache_dir_str, OpenFlags::SQLITE_OPEN_READ_WRITE + | OpenFlags::SQLITE_OPEN_CREATE + | OpenFlags::SQLITE_OPEN_NO_MUTEX + | 
OpenFlags::SQLITE_OPEN_URI).await { + Ok(db) => Arc::new(AMutex::new(db)), + Err(err) => return Err(format!("{:?}", err)) + }; + let temp_database = match Database::connect(data_database_temp_dir_str).await { + Ok(db) => db, + Err(err) => return Err(format!("{:?}", err)) + }; + + let vec_trait = Arc::new(Field::new("item", DataType::Float32, true)); + let schema = Arc::new(Schema::new(vec![ + Field::new("vector", DataType::FixedSizeList(vec_trait, embedding_size), true), + Field::new("window_text", DataType::Utf8, true), + Field::new("window_text_hash", DataType::Utf8, true), + Field::new("file_path", DataType::Utf8, true), + Field::new("start_line", DataType::UInt64, true), + Field::new("end_line", DataType::UInt64, true), + Field::new("time_added", DataType::UInt64, true), + Field::new("time_last_used", DataType::UInt64, true), + Field::new("model_name", DataType::Utf8, true), + Field::new("used_counter", DataType::UInt64, true), + ])); + match cache_database.lock().await.call(|conn| { + conn.execute( + "CREATE TABLE IF NOT EXISTS data ( + vector BLOB, + window_text TEXT NOT NULL, + window_text_hash TEXT NOT NULL, + file_path TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + time_added INTEGER NOT NULL, + time_last_used INTEGER NOT NULL, + model_name TEXT NOT NULL, + used_counter INTEGER NOT NULL + )", [], + )?; + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_window_text_hash ON data (window_text_hash)", + [], + )?; + Ok(()) + }).await { + Ok(_) => {} + Err(err) => return Err(format!("{:?}", err)) + } + + let batches_iter = RecordBatchIterator::new(vec![].into_iter().map(Ok), schema.clone()); + let data_table = match temp_database.create_table("data", batches_iter, Option::from(WriteParams::default())).await { + Ok(table) => table, + Err(err) => return Err(format!("{:?}", err)) + }; + + Ok(VecDBHandler { + cache_database, + _data_database_temp_dir: data_database_temp_dir, + schema, + data_table, + data_table_hashes: HashSet::new(), + embedding_size, + indexed_file_paths: Arc::new(AMutex::new(vec![])), + }) + } + + async fn checkout(&mut self) { + match self.data_table.checkout_latest().await { + Ok(table) => { self.data_table = table } + Err(err) => error!("Error while checking out the data table: {:?}", err) + } + match self.cache_database.lock().await.call(|connection| { + connection.cache_flush()?; + Ok({}) + }).await { + Ok(_) => {} + Err(err) => error!("Error while flushing cache: {:?}", err) + } + } + + async fn get_records_from_cache(&mut self, hashes: Vec) -> Result<(Vec, Vec), String> { + let mut hashes_set: HashSet = HashSet::from_iter(hashes.iter().cloned()); + let placeholders: String = hashes.iter().map(|_| "?").collect::>().join(","); + let query = format!("SELECT * FROM data WHERE window_text_hash IN ({})", placeholders); + + let records = match self.cache_database.lock().await.call(move |connection| { + let mut statement = connection.prepare(&query)?; + let params = rusqlite::params_from_iter(hashes.iter()); + let records = statement.query_map(params, |row| { + let vector_blob: Vec = row.get(0)?; + let vector: Vec = vector_blob + .chunks_exact(4) + .map(|b| f32::from_ne_bytes(b.try_into().unwrap())) + .collect(); + + let file_path_str: String = row.get(3)?; + let file_path = PathBuf::from(file_path_str); + + let time_added_timestamp: i64 = row.get(6)?; + let time_added = SystemTime::UNIX_EPOCH + Duration::from_secs(time_added_timestamp as u64); + + let time_last_used_timestamp: i64 = row.get(7)?; + let time_last_used = SystemTime::UNIX_EPOCH + 
Duration::from_secs(time_last_used_timestamp as u64); + + Ok(Record { + vector: Some(vector), + window_text: row.get(1)?, + window_text_hash: row.get(2)?, + file_path: file_path, + start_line: row.get(4)?, + end_line: row.get(5)?, + time_added: time_added, + time_last_used: time_last_used, + model_name: row.get(8)?, + used_counter: row.get(9)?, + distance: -1.0, + }) + })? + .filter_map(|row| row.ok()) + .collect::>(); + Ok(records) + }).await { + Ok(records) => records, + Err(err) => return Err(format!("{:?}", err)) + }; + + for r in &records { + hashes_set.remove(&r.window_text_hash); + } + Ok((records, hashes_set.iter().map(|x| x.clone()).collect())) + } + + async fn insert_records_to_cache(&mut self, records: Vec) -> Result<(), String> { + match self.cache_database.lock().await.call(|connection| { + let transaction = connection.transaction()?; + for record in records { + let time_added = record.time_added.duration_since( + SystemTime::UNIX_EPOCH + ).unwrap_or(Duration::ZERO) + .as_secs(); + + let time_last_used = record.time_last_used.duration_since( + SystemTime::UNIX_EPOCH + ).unwrap_or(Duration::ZERO) + .as_secs(); + + let vector_as_bytes: Vec = record.vector.expect( + "An attempt to push vector-less data to cache DB" + ).iter() + .flat_map(|&num| num.to_ne_bytes()) + .collect(); + + match transaction.execute( + "INSERT INTO data (vector, window_text, window_text_hash, \ + file_path, start_line, end_line, time_added, \ + time_last_used, model_name, used_counter) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", + rusqlite::params![ + vector_as_bytes, + record.window_text, + record.window_text_hash, + record.file_path.to_str(), + record.start_line, + record.end_line, + time_added as i64, + time_last_used as i64, + record.model_name, + record.used_counter, + ], + ) { + Ok(_) => {} + Err(err) => { + info!("Error while inserting record to cache: {:?}", err); + continue; + } + } + } + match transaction.commit() { + Ok(_) => Ok(()), + Err(err) => Err(err.into()) + } + }).await { + Ok(_) => Ok(()), + Err(err) => Err(format!("{:?}", err)) + } + } + + async fn remove_records_from_cache(&mut self, file_path: String) -> Result<(), String> { + match self.cache_database.lock().await.call(move |connection| { + match connection.execute( + "DELETE FROM data WHERE file_path = ?1", + params![file_path], + ) { + Ok(_) => Ok(()), + Err(err) => Err(err.into()) + } + }).await { + Ok(_) => Ok(()), + Err(err) => Err(format!("{:?}", err)) + } + } + + async fn update_cache_records(&mut self, records: Vec) -> Result<(), String> { + let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_secs(); + match self.cache_database.lock().await.call(move |connection| { + let transaction = connection.transaction()?; + for record in records { + match transaction.execute( + "UPDATE data SET time_last_used = ?, used_counter = ? 
WHERE window_text_hash = ?", + params![ + now, + record.used_counter, + record.window_text_hash, + ], + ) { + Ok(_) => {} + Err(_) => { + continue; + } + }; + } + match transaction.commit() { + Ok(_) => Ok(()), + Err(err) => Err(err.into()) + } + }).await { + Ok(_) => Ok(()), + Err(err) => Err(format!("{:?}", err)) + } + } + + async fn delete_old_records_from_cache(&mut self) -> Result<(), String> { + let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_secs(); + + match self.cache_database.lock().await.call(move |connection| { + let transaction = connection.transaction()?; + + transaction.execute( + "DELETE FROM data WHERE (?1 - time_last_used > ?2) AND (used_counter < ?3)", + params![now, TWO_WEEKS, MIN_LIKES], + )?; + + transaction.execute( + "DELETE FROM data WHERE (?1 - time_last_used > ?2)", + params![now, ONE_MONTH], + )?; + + transaction.commit()?; + Ok({}) + }).await { + Ok(_) => Ok(()), + Err(err) => Err(format!("{:?}", err)) + } + } + + pub async fn size(&self) -> Result { + match self.data_table.count_rows().await { + Ok(size) => Ok(size), + Err(err) => Err(format!("{:?}", err)) + } + } + + pub async fn cache_size(&self) -> Result { + self.cache_database.lock().await.call(move |connection| { + let mut stmt = connection.prepare("SELECT COUNT(*) FROM data")?; + let count: usize = stmt.query_row([], |row| row.get(0))?; + Ok(count) + }).await + .map_err(|e| { + e.to_string() + }) + } + + pub async fn select_all_file_paths(&self) -> Vec { + let mut file_paths: HashSet = HashSet::new(); + let records: Vec = self.data_table + .filter(format!("file_path in (select file_path from data)")) + .execute() + .await.unwrap() + .try_collect::>() + .await.unwrap(); + + for rec_batch in records { + for record in VecDBHandler::parse_table_iter(rec_batch, false, None).unwrap() { + file_paths.insert(record.file_path.clone()); + } + } + return file_paths.into_iter().collect(); + } + + pub async fn get_file_orig_text(&mut self, file_path: String) -> String{ + let batches: Vec = self.data_table + .filter(format!("file_path == '{}'", file_path)) + .execute() + .await.unwrap() + .try_collect::>() + .await.unwrap(); + + let mut records = vec![]; + for rec_batch in batches { + for record in VecDBHandler::parse_table_iter(rec_batch, false, None).unwrap() { + records.push((record.start_line, record.end_line, record.window_text)); + } + } + records.sort_by(|a, b| a.1.cmp(&b.1)); + let text: String = records.into_iter().map(|rec| rec.2).collect::>().join("\n"); + text + } + + pub async fn update_indexed_file_paths(&mut self) { + let res = self.select_all_file_paths().await; + self.indexed_file_paths = Arc::new(AMutex::new(res)); + } + + pub async fn get_indexed_file_paths(&self) -> Arc>> { + return self.indexed_file_paths.clone(); + } + + pub async fn try_add_from_cache(&mut self, data: Vec) -> Vec { + if data.is_empty() { + return vec![]; + } + + let hashes = data.iter().map(|x| x.window_text_hash.clone()).collect(); + let (found_records, left_hashes) = match self.get_records_from_cache(hashes).await { + Ok(records) => records, + Err(err) => { + info!("Error while getting values from cache: {:?}", err); + return vec![]; + } + }; + let left_results: Vec = + data.into_iter().filter(|x| left_hashes.contains(&x.window_text_hash)).collect(); + + match self.add_or_update(found_records, false).await { + Ok(_) => {} + Err(err) => info!("Error while adding values from cache: {:?}", err), + }; + left_results + } + + pub async fn add_or_update(&mut self, records: Vec, 
+    pub async fn add_or_update(&mut self, records: Vec<Record>, add_to_cache: bool) -> Result<(), String> {
+        fn make_emb_data(records: &Vec<Record>, embedding_size: i32) -> Result<ArrayData, String> {
+            let vec_trait = Arc::new(Field::new("item", DataType::Float32, true));
+            let mut emb_builder: Vec<f32> = vec![];
+
+            for record in records {
+                emb_builder.append(&mut record.vector.clone().expect("No embedding is provided"));
+            }
+
+            let emb_data_res = ArrayData::builder(DataType::Float32)
+                .add_buffer(Buffer::from_vec(emb_builder))
+                .len(records.len() * embedding_size as usize)
+                .build();
+            let emb_data = match emb_data_res {
+                Ok(res) => res,
+                Err(err) => { return Err(format!("{:?}", err)); }
+            };
+
+            match ArrayData::builder(DataType::FixedSizeList(vec_trait.clone(), embedding_size))
+                .len(records.len())
+                .add_child_data(emb_data.clone())
+                .build() {
+                Ok(res) => Ok(res),
+                Err(err) => Err(format!("{:?}", err))
+            }
+        }
+
+        if records.is_empty() {
+            return Ok(());
+        }
+
+        let vectors: ArrayData = match make_emb_data(&records, self.embedding_size) {
+            Ok(res) => res,
+            Err(err) => return Err(format!("{:?}", err))
+        };
+        let window_texts: Vec<String> = records.iter().map(|x| x.window_text.clone()).collect();
+        let window_text_hashes: Vec<String> = records.iter().map(|x| x.window_text_hash.clone()).collect();
+        let file_paths: Vec<String> = records.iter().map(|x| x.file_path.to_str().unwrap_or("No filename").to_string()).collect();
+        let start_lines: Vec<u64> = records.iter().map(|x| x.start_line).collect();
+        let end_lines: Vec<u64> = records.iter().map(|x| x.end_line).collect();
+        let time_adds: Vec<u64> = records.iter().map(|x| x.time_added.duration_since(std::time::UNIX_EPOCH).unwrap_or(
+            Duration::from_secs(0)
+        ).as_secs()).collect();
+        let time_last_used: Vec<u64> = records.iter().map(|x| x.time_last_used.duration_since(std::time::UNIX_EPOCH).unwrap_or(
+            Duration::from_secs(0)
+        ).as_secs()).collect();
+        let model_names: Vec<String> = records.iter().map(|x| x.model_name.clone()).collect();
+        let used_counters: Vec<u64> = records.iter().map(|x| x.used_counter).collect();
+        let data_batches_iter = RecordBatchIterator::new(
+            vec![RecordBatch::try_new(
+                self.schema.clone(),
+                vec![
+                    Arc::new(FixedSizeListArray::from(vectors)),
+                    Arc::new(StringArray::from(window_texts)),
+                    Arc::new(StringArray::from(window_text_hashes.clone())),
+                    Arc::new(StringArray::from(file_paths)),
+                    Arc::new(UInt64Array::from(start_lines)),
+                    Arc::new(UInt64Array::from(end_lines)),
+                    Arc::new(UInt64Array::from(time_adds)),
+                    Arc::new(UInt64Array::from(time_last_used)),
+                    Arc::new(StringArray::from(model_names)),
+                    Arc::new(UInt64Array::from(used_counters)),
+                ],
+            )],
+            self.schema.clone(),
+        );
+
+        if add_to_cache {
+            match self.insert_records_to_cache(records).await {
+                Ok(_) => {}
+                Err(err) => return Err(format!("{:?}", err))
+            };
+        }
+
+        let data_res = self.data_table.add(
+            data_batches_iter, Option::from(WriteParams {
+                mode: WriteMode::Append,
+                ..Default::default()
+            }),
+        );
+
+        self.data_table_hashes.extend(window_text_hashes);
+        match data_res.await {
+            Ok(_) => Ok(()),
+            Err(err) => Err(format!("{:?}", err))
+        }
+    }
+
+    pub async fn remove(&mut self, file_path: &PathBuf) {
+        let file_path_str = match file_path.to_str() {
+            None => {
+                info!("File path is not a string");
+                return;
+            }
+            Some(res) => res
+        };
+
+        match self.remove_records_from_cache(file_path_str.to_string()).await {
+            Ok(_) => {}
+            Err(err) => {
+                info!("Error while deleting from cache table: {:?}", err);
+            }
+        }
+        // valerii: I found no way in the documentation to escape strings, so SQL injection is possible here
+        match self.data_table.delete(
+            format!("(file_path = \"{}\")", file_path_str).as_str()  // TODO: prevent the possible SQL injection
+        ).await {
+            Ok(_) => {}
+            Err(err) => {
+                info!("Error while deleting from data table: {:?}", err);
+            }
+        }
+    }
+
+    pub async fn create_index(&mut self) -> vectordb::error::Result<()> {
+        let size = self.size().await.unwrap_or(0);
+        if size == 0 {
+            return Ok(());
+        }
+        self.data_table.create_index(
+            IvfPQIndexBuilder::default()
+                .column("vector".to_owned())
+                .index_name("index".to_owned())
+                .metric_type(MetricType::Cosine)
+                .ivf_params(IvfBuildParams {
+                    num_partitions: min(size, 512),
+                    ..IvfBuildParams::default()
+                })
+                .replace(true)
+        ).await
+    }
+
+    pub fn contains(&self, hash: &str) -> bool {
+        self.data_table_hashes.contains(hash)
+    }
+
+    fn parse_table_iter(
+        record_batch: RecordBatch,
+        include_embedding: bool,
+        embedding_to_compare: Option<&Vec<f32>>,
+    ) -> vectordb::error::Result<Vec<Record>> {
+        (0..record_batch.num_rows()).map(|idx| {
+            let gathered_vec = as_primitive_array::<Float32Type>(
+                &as_fixed_size_list_array(record_batch.column_by_name("vector").unwrap())
+                    .iter()
+                    .map(|x| x.unwrap())
+                    .collect::<Vec<_>>()[idx]
+            )
+                .iter()
+                .map(|x| x.unwrap()).collect();
+            let distance = match embedding_to_compare {
+                None => { -1.0 }
+                Some(embedding) => { cosine_distance(&embedding, &gathered_vec) }
+            };
+            let embedding = match include_embedding {
+                true => Some(gathered_vec),
+                false => None
+            };
+
+            Ok(Record {
+                vector: embedding,
+                window_text: as_string_array(record_batch.column_by_name("window_text")
+                    .expect("Missing column 'window_text'"))
+                    .value(idx)
+                    .to_string(),
+                window_text_hash: as_string_array(record_batch.column_by_name("window_text_hash")
+                    .expect("Missing column 'window_text_hash'"))
+                    .value(idx)
+                    .to_string(),
+                file_path: PathBuf::from(as_string_array(record_batch.column_by_name("file_path")
+                    .expect("Missing column 'file_path'"))
+                    .value(idx)
+                    .to_string()),
+                start_line: as_primitive_array::<UInt64Type>(record_batch.column_by_name("start_line")
+                    .expect("Missing column 'start_line'"))
+                    .value(idx),
+                end_line: as_primitive_array::<UInt64Type>(record_batch.column_by_name("end_line")
+                    .expect("Missing column 'end_line'"))
+                    .value(idx),
+                time_added: std::time::UNIX_EPOCH + Duration::from_secs(
+                    as_primitive_array::<UInt64Type>(
+                        record_batch.column_by_name("time_added")
+                            .expect("Missing column 'time_added'"))
+                        .value(idx)
+                ),
+                time_last_used: std::time::UNIX_EPOCH + Duration::from_secs(
+                    as_primitive_array::<UInt64Type>(
+                        record_batch.column_by_name("time_last_used")
+                            .expect("Missing column 'time_last_used'"))
+                        .value(idx)
+                ),
+                model_name: as_string_array(record_batch.column_by_name("model_name")
+                    .expect("Missing column 'model_name'"))
+                    .value(idx)
+                    .to_string(),
+                used_counter: as_primitive_array::<UInt64Type>(record_batch.column_by_name("used_counter")
+                    .expect("Missing column 'used_counter'"))
+                    .value(idx),
+                distance,
+            })
+        }).collect()
+    }
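`parse_table_iter` scores every row against the optional query embedding with `cosine_distance`, presumably the helper imported from lance-linalg (the import is outside this hunk). For reference, a minimal equivalent, assuming non-zero vectors of equal length:

    // 1 - cos(a, b); 0.0 means identical direction, 2.0 means opposite.
    fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        1.0 - dot / (norm_a * norm_b)
    }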
+    pub async fn search(
+        &mut self,
+        embedding: Vec<f32>,
+        top_n: usize
+    ) -> vectordb::error::Result<Vec<Record>> {
+        let query = self
+            .data_table
+            .clone()
+            .search(Some(Float32Array::from(embedding.clone())))
+            .limit(top_n)
+            .use_index(true)
+            .execute()
+            .await?
+            .try_collect::<Vec<_>>()
+            .await?;
+        let record_batch = concat_batches(&self.schema, &query)?;
+        match VecDBHandler::parse_table_iter(record_batch, false, Some(&embedding)) {
+            Ok(records) => {
+                let sorted: Vec<Record> = records
+                    .into_iter()
+                    .sorted_unstable_by(|a, b| {
+                        a.distance
+                            .partial_cmp(&b.distance)
+                            .unwrap_or(std::cmp::Ordering::Equal)
+                    })
+                    .collect();
+                Ok(sorted)
+            }
+            Err(err) => Err(err),
+        }
+    }
+
+    pub async fn update_record_statistic(&mut self, records: Vec<Record>) {
+        match self.update_cache_records(records).await {
+            Ok(_) => {}
+            Err(err) => {
+                info!("Error while updating record statistics: {:?}", err);
+            }
+        }
+    }
+
+    pub async fn cleanup_old_records(&mut self) -> Result<(), String> {
+        info!("VECDB: Cleaning up old records");
+
+        let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap();
+        let q = format!("{} - time_last_used > {TWO_WEEKS} AND used_counter < {MIN_LIKES}", now.as_secs());
+        self.data_table.delete(&*q).await.expect("could not delete old records");
+
+        let q = format!("{} - time_last_used > {ONE_MONTH}", now.as_secs());
+        self.data_table.delete(&*q).await.expect("could not delete old records");
+
+        self.delete_old_records_from_cache().await.expect("could not delete old records");
+        self.checkout().await;
+        Ok(())
+    }
+}
diff --git a/src/vecdb/mod.rs b/src/vecdb/mod.rs
new file mode 100644
index 000000000..8286f73a3
--- /dev/null
+++ b/src/vecdb/mod.rs
@@ -0,0 +1,8 @@
+mod file_splitter;
+mod handler;
+mod vectorizer_service;
+pub mod file_watcher_service;
+pub mod file_filter;
+pub mod structs;
+pub mod vecdb;
+pub mod vecdb_remote;
diff --git a/src/vecdb/structs.rs b/src/vecdb/structs.rs
new file mode 100644
index 000000000..320db6922
--- /dev/null
+++ b/src/vecdb/structs.rs
@@ -0,0 +1,70 @@
+use std::fmt::Debug;
+use std::path::PathBuf;
+use std::sync::Arc;
+use std::time::SystemTime;
+
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+use tokio::sync::Mutex;
+
+#[async_trait]
+pub trait VecdbSearch: Send {
+    async fn search(
+        &self,
+        query: String,
+        top_n: usize,
+    ) -> Result<SearchResult, String>;
+}
+
+#[derive(Debug, Clone)]
+pub struct VecdbConstants {
+    // "constant" in the sense that it cannot be changed without creating a new db
+    pub model_name: String,
+    pub embedding_size: i32,
+    pub endpoint_embeddings_template: String,
+    pub endpoint_embeddings_style: String,
+    pub cooldown_secs: u64,
+    pub splitter_window_size: usize,
+    pub splitter_soft_limit: usize,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct VecDbStatus {
+    pub unprocessed_files_count: usize,
+    pub requests_made_since_start: usize,
+    pub db_size: usize,
+    pub db_cache_size: usize,
+}
+
+pub type VecDbStatusRef = Arc<Mutex<VecDbStatus>>;
+
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct Record {
+    pub vector: Option<Vec<f32>>,
+    pub window_text: String,
+    pub window_text_hash: String,
+    pub file_path: PathBuf,
+    pub start_line: u64,
+    pub end_line: u64,
+    pub time_added: SystemTime,
+    pub time_last_used: SystemTime,
+    pub model_name: String,
+    pub used_counter: u64,
+    pub distance: f32,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
+pub struct SplitResult {
+    pub file_path: PathBuf,
+    pub window_text: String,
+    pub window_text_hash: String,
+    pub start_line: u64,
+    pub end_line: u64,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct SearchResult {
+    pub query_text: String,
+    pub results: Vec<Record>,
+}
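Since `VecdbSearch` takes `&self` and returns a plain `Result`, it is straightforward to stub out in tests. A hypothetical in-memory implementation (not part of this patch) could look like:

    use async_trait::async_trait;

    // A test double that echoes the query back with canned results.
    struct StaticSearch {
        canned: SearchResult,
    }

    #[async_trait]
    impl VecdbSearch for StaticSearch {
        async fn search(&self, query: String, _top_n: usize) -> Result<SearchResult, String> {
            let mut r = self.canned.clone();
            r.query_text = query;
            Ok(r)
        }
    }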
diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs
new file mode 100644
index 000000000..26928d150
--- /dev/null
+++ b/src/vecdb/vecdb.rs
@@ -0,0 +1,307 @@
+use std::path::PathBuf;
+use std::sync::Arc;
+use std::sync::RwLock as StdRwLock;
+use tokio::sync::RwLock as ARwLock;
+use tokio::sync::Mutex as AMutex;
+use tracing::{info, error};
+
+use async_trait::async_trait;
+use serde::Serialize;
+use tokio::task::JoinHandle;
+use crate::global_context::{CommandLine, GlobalContext};
+use tower_lsp::lsp_types::WorkspaceFolder;
+use crate::background_tasks::BackgroundTasksHolder;
+
+use crate::fetch_embedding;
+use crate::vecdb;
+use crate::vecdb::file_filter;
+use crate::vecdb::file_watcher_service::read_and_load_jsonl;
+use crate::vecdb::handler::VecDBHandler;
+use crate::vecdb::vectorizer_service::FileVectorizerService;
+use crate::vecdb::structs::{SearchResult, VecdbSearch, VecDbStatus, VecdbConstants};
+
+fn vecdb_constants(
+    caps: Arc<StdRwLock<crate::caps::CodeAssistantCaps>>,
+) -> VecdbConstants {
+    let caps_locked = caps.read().unwrap();
+    VecdbConstants {
+        model_name: caps_locked.default_embeddings_model.clone(),
+        embedding_size: caps_locked.size_embeddings.clone(),
+        endpoint_embeddings_template: caps_locked.endpoint_embeddings_template.clone(),
+        endpoint_embeddings_style: caps_locked.endpoint_embeddings_style.clone(),
+        cooldown_secs: 20,
+        splitter_window_size: 512,
+        splitter_soft_limit: 1024,
+    }
+}
+
+pub struct VecDb {
+    vecdb_emb_client: Arc<AMutex<reqwest::Client>>,
+    vecdb_handler: Arc<AMutex<VecDBHandler>>,
+    retriever_service: Arc<AMutex<FileVectorizerService>>,
+    cmdline: CommandLine,
+    constants: VecdbConstants,
+}
+
+#[derive(Debug, Serialize, Clone)]
+pub struct FileSearchResult {
+    pub file_path: String,
+    pub file_text: String,
+}
+
+#[derive(Debug, Serialize)]
+pub struct VecDbCaps {
+    functions: Vec<String>,
+}
+
+async fn vecdb_test_request(
+    vecdb: &VecDb
+) -> Result<(), String> {
+    let search_result = vecdb.search("test query".to_string(), 3).await;
+    match search_result {
+        Ok(_) => {
+            Ok(())
+        }
+        Err(_) => {
+            error!("vecdb: test search failed");
+            Err("vecdb: test search failed".to_string())
+        }
+    }
+}
+
+async fn create_vecdb(
+    global_context: Arc<ARwLock<GlobalContext>>,
+    background_tasks: &mut BackgroundTasksHolder,
+    constants: VecdbConstants,
+) -> Result<(), String> {
+    info!("vecdb: attempting to launch");
+
+    let (cache_dir, cmdline) = {
+        let gcx_locked = global_context.read().await;
+        (gcx_locked.cache_dir.clone(), gcx_locked.cmdline.clone())
+    };
+    let base_dir: PathBuf = match cmdline.vecdb_forced_path.as_str() {
+        "" => cache_dir,
+        path => PathBuf::from(path),
+    };
+    let vec_db_mb = match VecDb::init(
+        &base_dir,
+        cmdline.clone(),
+        constants,
+    ).await {
+        Ok(res) => Some(res),
+        Err(err) => {
+            error!("Oops, the database is broken!
+                Last error message: {}
+                You can report this issue here:
+                https://github.com/smallcloudai/refact-lsp/issues
+                Also, you can run this to erase your db:
+                `rm -rf ~/.cache/refact/refact_vecdb_cache`
+                After that, restart this LSP server or your IDE.", err);
+            return Err(err);
+        }
+    };
+    let vec_db = vec_db_mb.unwrap();
+
+    match vecdb_test_request(&vec_db).await {
+        Ok(_) => {},
+        Err(s) => { return Err(s); }
+    }
+    info!("vecdb: test request complete");
+
+    {
+        let mut gcx_locked = global_context.write().await;
+
+        if let Some(folders) = gcx_locked.lsp_backend_document_state.workspace_folders.clone().read().await.clone() {
+            let mut vec_db_lock = gcx_locked.vec_db.lock().await;
+            if let Some(ref mut db) = *vec_db_lock {
+                db.init_folders(folders).await;
+            }
+        }
+        let files_set_path = PathBuf::from(gcx_locked.cmdline.files_set_path.clone());
+        read_and_load_jsonl(&files_set_path, &vec_db).await;
+
+        let mut tasks = vec_db.start_background_tasks().await;
+        gcx_locked.vec_db = Arc::new(AMutex::new(Some(vec_db)));
+
+        tasks.extend(vec![
+            tokio::spawn(vecdb::file_watcher_service::file_watcher_task(files_set_path, global_context.clone()))
+        ]);
+        background_tasks.extend(tasks);
+    }
+    Ok(())
+}
+
+async fn do_i_need_to_reload_vecdb(
+    global_context: Arc<ARwLock<GlobalContext>>,
+) -> (bool, Option<VecdbConstants>) {
+    let caps = match crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0).await {
+        Ok(caps) => caps,
+        Err(e) => {
+            // This branch makes the caps error disappear, unless we print it right here:
+            info!("vecdb: no caps, will not start or reload vecdb, the error was: {}", e);
+            return (false, None)
+        }
+    };
+    let consts = vecdb_constants(caps);
+
+    if consts.model_name.is_empty() || consts.endpoint_embeddings_template.is_empty() {
+        error!("vecdb launch failed: default_embeddings_model.is_empty() || endpoint_embeddings_template.is_empty()");
+        return (false, None);
+    }
+
+    match *global_context.write().await.vec_db.lock().await {
+        None => {}
+        Some(ref db) => {
+            if db.constants.model_name == consts.model_name &&
+                db.constants.endpoint_embeddings_template == consts.endpoint_embeddings_template &&
+                db.constants.endpoint_embeddings_style == consts.endpoint_embeddings_style
+            {
+                return (false, None);
+            }
+        }
+    }
+
+    return (true, Some(consts));
+}
+
+pub async fn vecdb_background_reload(
+    global_context: Arc<ARwLock<GlobalContext>>,
+) {
+    let cmd_line = global_context.read().await.cmdline.clone();
+    if !cmd_line.vecdb {
+        return;
+    }
+    let mut background_tasks = BackgroundTasksHolder::new(vec![]);
+    loop {
+        let (need_reload, consts) = do_i_need_to_reload_vecdb(global_context.clone()).await;
+        if need_reload && consts.is_some() {
+            background_tasks.abort().await;
+            background_tasks = BackgroundTasksHolder::new(vec![]);
+            match create_vecdb(
+                global_context.clone(),
+                &mut background_tasks,
+                consts.unwrap(),
+            ).await {
+                Ok(_) => {}
+                Err(err) => {
+                    error!("vecdb: init failed: {}", err);
+                    // global_context.vec_db stays None, the rest of the system continues working
+                }
+            }
+        }
+        tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
+    }
+}
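`vecdb_background_reload` polls once a minute and rebuilds the whole VecDb when the embedding model or endpoint reported by caps changes. The same cadence could be written with `tokio::time::interval`; an illustrative alternative sketch (the patch itself uses `sleep` at the end of the loop):

    use std::time::Duration;

    #[tokio::main]
    async fn main() {
        let mut ticker = tokio::time::interval(Duration::from_secs(60));
        loop {
            ticker.tick().await;  // the first tick fires immediately
            // re-read caps, compare VecdbConstants, recreate VecDb on change
        }
    }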
+
+impl VecDb {
+    pub async fn init(
+        cache_dir: &PathBuf,
+        cmdline: CommandLine,
+        constants: VecdbConstants,
+    ) -> Result<VecDb, String> {
+        let handler = match VecDBHandler::init(cache_dir, &constants.model_name, constants.embedding_size).await {
+            Ok(res) => res,
+            Err(err) => { return Err(err) }
+        };
+        let vecdb_handler = Arc::new(AMutex::new(handler));
+        let retriever_service = Arc::new(AMutex::new(FileVectorizerService::new(
+            vecdb_handler.clone(),
+            constants.clone(),
+            cmdline.api_key.clone(),
+        ).await));
+        Ok(VecDb {
+            vecdb_emb_client: Arc::new(AMutex::new(reqwest::Client::new())),
+            vecdb_handler,
+            retriever_service,
+            cmdline: cmdline.clone(),
+            constants: constants.clone(),
+        })
+    }
+
+    pub async fn start_background_tasks(&self) -> Vec<JoinHandle<()>> {
+        info!("vecdb: start_background_tasks");
+        return self.retriever_service.lock().await.start_background_tasks(self.vecdb_emb_client.clone()).await;
+    }
+
+    pub async fn add_or_update_file(&self, file_path: PathBuf, force: bool) {
+        self.retriever_service.lock().await.process_file(file_path, force).await;
+    }
+
+    pub async fn add_or_update_files(&self, file_paths: Vec<PathBuf>, force: bool) {
+        self.retriever_service.lock().await.process_files(file_paths, force).await;
+    }
+
+    pub async fn remove_file(&self, file_path: &PathBuf) {
+        self.vecdb_handler.lock().await.remove(file_path).await;
+    }
+
+    pub async fn get_status(&self) -> Result<VecDbStatus, String> {
+        self.retriever_service.lock().await.status().await
+    }
+
+    pub async fn init_folders(&self, folders: Vec<WorkspaceFolder>) {
+        let files = file_filter::retrieve_files_by_proj_folders(
+            folders.iter().map(|x| PathBuf::from(x.uri.path())).collect()
+        ).await;
+        self.add_or_update_files(files, true).await;
+        info!("init_folders complete");
+    }
+
+    pub async fn get_indexed_file_paths(&self) -> Arc<AMutex<Vec<PathBuf>>> {
+        return self.vecdb_handler.lock().await.get_indexed_file_paths().await;
+    }
+
+    pub async fn get_file_orig_text(&self, file_path: String) -> FileSearchResult {
+        let text = self.vecdb_handler.lock().await.get_file_orig_text(file_path.clone()).await;
+        FileSearchResult { file_path, file_text: text }
+    }
+
+    pub async fn caps(&self) -> VecDbCaps {
+        VecDbCaps {
+            functions: vec!["@workspace".to_string()],
+        }
+    }
+}
+
+
+#[async_trait]
+impl VecdbSearch for VecDb {
+    async fn search(&self, query: String, top_n: usize) -> Result<SearchResult, String> {
+        let t0 = std::time::Instant::now();
+        let embedding_mb = fetch_embedding::try_get_embedding(
+            self.vecdb_emb_client.clone(),
+            &self.constants.endpoint_embeddings_style,
+            &self.constants.model_name,
+            &self.constants.endpoint_embeddings_template,
+            query.clone(),
+            &self.cmdline.api_key,
+            5
+        ).await;
+        if embedding_mb.is_err() {
+            return Err("Failed to get embedding".to_string());
+        }
+        info!("search query {:?}, it took {:.3}s to vectorize the query", query, t0.elapsed().as_secs_f64());
+
+        let mut handler_locked = self.vecdb_handler.lock().await;
+        let t1 = std::time::Instant::now();
+        let results = match handler_locked.search(embedding_mb.unwrap(), top_n).await {
+            Ok(res) => res,
+            Err(_) => { return Err("an error occurred during search".to_string()) },
+        };
+        info!("search itself {:.3}s", t1.elapsed().as_secs_f64());
+        for rec in results.iter() {
+            let last_30_chars = crate::nicer_logs::last_n_chars(&rec.file_path.display().to_string(), 30);
+            info!("distance {:.3}, found {}:{}-{}", rec.distance, last_30_chars, rec.start_line, rec.end_line);
+        }
+        let t2 = std::time::Instant::now();
+        handler_locked.update_record_statistic(results.clone()).await;
+        info!("update_record_statistic {:.3}s", t2.elapsed().as_secs_f64());
+        Ok(
+            SearchResult {
+                query_text: query,
+                results,
+            }
+        )
+    }
+}
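A caller-side sketch of the search path above (assumes an initialized `VecDb` and `use crate::vecdb::structs::VecdbSearch;` in scope; illustrative only):

    async fn demo(vecdb: &VecDb) -> Result<(), String> {
        let res = vecdb.search("definition of DeltaDeltaChatStreamer".to_string(), 5).await?;
        for r in res.results {
            // distance is the cosine distance filled in by the handler; smaller is better
            println!("{} {}:{} d={:.3}", r.file_path.display(), r.start_line, r.end_line, r.distance);
        }
        Ok(())
    }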
diff --git a/src/vecdb/vecdb_remote.rs b/src/vecdb/vecdb_remote.rs
new file mode 100644
index 000000000..c63e44b35
--- /dev/null
+++ b/src/vecdb/vecdb_remote.rs
@@ -0,0 +1,46 @@
+use async_trait::async_trait;
+use reqwest::header::CONTENT_TYPE;
+use reqwest::header::HeaderMap;
+use reqwest::header::HeaderValue;
+use serde_json::json;
+
+use crate::vecdb::structs::{SearchResult, VecdbSearch};
+
+#[derive(Debug)]
+pub struct VecDbRemote {}
+
+#[async_trait]
+impl VecdbSearch for VecDbRemote {
+    async fn search(
+        &self,
+        query: String,
+        top_n: usize,
+    ) -> Result<SearchResult, String> {
+        let url = "http://127.0.0.1:8008/v1/vdb-search".to_string();
+        let mut headers = HeaderMap::new();
+        // headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap());
+        headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
+        let body = json!({
+            "text": query,
+            "top_n": top_n
+        });
+        let res = reqwest::Client::new()
+            .post(&url)
+            .headers(headers)
+            .body(body.to_string())
+            .send()
+            .await.map_err(|e| format!("Vecdb search HTTP error (1): {}", e))?;
+
+        let body = res.text().await.map_err(|e| format!("Vecdb search HTTP error (2): {}", e))?;
+        // info!("Vecdb search result: {:?}", &body);
+        let result: Vec<SearchResult> = serde_json::from_str(&body).map_err(|e| {
+            format!("vecdb JSON problem: {}", e)
+        })?;
+        if result.is_empty() {
+            return Err("Vecdb search result is empty".to_string());
+        }
+        let result0 = result[0].clone();
+        // info!("Vecdb search result: {:?}", &result0);
+        Ok(result0)
+    }
+}
diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs
new file mode 100644
index 000000000..1e8f418ae
--- /dev/null
+++ b/src/vecdb/vectorizer_service.rs
@@ -0,0 +1,326 @@
+use std::collections::{HashMap, VecDeque};
+use std::path::PathBuf;
+use std::time::SystemTime;
+use std::ops::Div;
+use std::io::Write;
+use std::sync::Arc;
+use tokio::sync::Mutex as AMutex;
+use tokio::sync::Semaphore;
+use tokio::task::JoinHandle;
+use tracing::info;
+
+use crate::vecdb::file_splitter::FileSplitter;
+use crate::vecdb::handler::VecDBHandler;
+use crate::fetch_embedding::try_get_embedding;
+use crate::vecdb::structs::{Record, SplitResult, VecDbStatus, VecDbStatusRef, VecdbConstants};
+
+#[derive(Debug)]
+pub struct FileVectorizerService {
+    update_request_queue: Arc<AMutex<VecDeque<PathBuf>>>,
+    output_queue: Arc<AMutex<VecDeque<PathBuf>>>,
+    vecdb_handler: Arc<AMutex<VecDBHandler>>,
+    status: VecDbStatusRef,
+    constants: VecdbConstants,
+    api_key: String,
+}
+
+async fn cooldown_queue_thread(
+    update_request_queue: Arc<AMutex<VecDeque<PathBuf>>>,
+    out_queue: Arc<AMutex<VecDeque<PathBuf>>>,
+    _status: VecDbStatusRef,
+    cooldown_secs: u64,
+) {
+    // This function delays vectorization of a file until its mtime is at least cooldown_secs old.
+    let mut last_updated: HashMap<PathBuf, SystemTime> = HashMap::new();
+    loop {
+        let (path_maybe, _unprocessed_files_count) = {
+            let mut queue_locked = update_request_queue.lock().await;
+            let queue_len = queue_locked.len();
+            if !queue_locked.is_empty() {
+                (Some(queue_locked.pop_front().unwrap()), queue_len)
+            } else {
+                (None, 0)
+            }
+        };
+
+        if let Some(path) = path_maybe {
+            last_updated.insert(path, SystemTime::now());
+        }
+
+        let mut paths_to_process: Vec<PathBuf> = Vec::new();
+        let mut stat_too_new = 0;
+        let mut stat_proceed = 0;
+        for (path, time) in &last_updated {
+            if time.elapsed().unwrap().as_secs() > cooldown_secs {
+                paths_to_process.push(path.clone());
+                stat_proceed += 1;
+            } else {
+                stat_too_new += 1;
+            }
+        }
+        if stat_proceed > 0 || stat_too_new > 0 {
+            info!("cooldown_queue_thread: {} files to process, {} files too new", stat_proceed, stat_too_new);
+        }
+        for path in paths_to_process {
+            last_updated.remove(&path);
+            out_queue.lock().await.push_back(path);
+        }
+        tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
+    }
+}
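The cooldown queue is a debounce: a path is released for vectorization only after it has sat untouched for `cooldown_secs`, so a file that is being saved repeatedly gets embedded once. The core of that logic, extracted into a standalone form (illustrative names, not from the patch):

    use std::collections::HashMap;
    use std::time::{Duration, SystemTime};

    // Drain and return the paths whose last touch is older than `cooldown`.
    fn ready_paths(last_touched: &mut HashMap<String, SystemTime>, cooldown: Duration) -> Vec<String> {
        let ready: Vec<String> = last_touched.iter()
            .filter(|(_, t)| t.elapsed().unwrap_or(Duration::ZERO) >= cooldown)
            .map(|(p, _)| p.clone())
            .collect();
        for p in &ready {
            last_touched.remove(p);
        }
        ready
    }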
+
+
+async fn vectorize_thread(
+    client: Arc<AMutex<reqwest::Client>>,
+    queue: Arc<AMutex<VecDeque<PathBuf>>>,
+    vecdb_handler_ref: Arc<AMutex<VecDBHandler>>,
+    status: VecDbStatusRef,
+    constants: VecdbConstants,
+    api_key: String,
+    max_concurrent_tasks: usize,
+) {
+    let file_splitter = FileSplitter::new(constants.splitter_window_size, constants.splitter_soft_limit);
+    let semaphore = Arc::new(Semaphore::new(max_concurrent_tasks));
+    let mut reported_unprocessed: usize = 0;
+    let mut reported_vecdb_complete: bool = false;
+
+    loop {
+        let (path_maybe, unprocessed_files_count) = {
+            let mut queue_locked = queue.lock().await;
+            let queue_len = queue_locked.len();
+            if queue_len > 0 {
+                (Some(queue_locked.pop_front().unwrap()), queue_len)
+            } else {
+                (None, 0)
+            }
+        };
+        if (unprocessed_files_count + 99).div(100) != (reported_unprocessed + 99).div(100) {
+            info!("have {} unprocessed files", unprocessed_files_count);
+            reported_unprocessed = unprocessed_files_count;
+        }
+        status.lock().await.unprocessed_files_count = unprocessed_files_count;
+        reported_vecdb_complete &= unprocessed_files_count == 0;
+        let path = {
+            match path_maybe {
+                Some(path) => path,
+                None => {
+                    // No files left to process
+                    if !reported_vecdb_complete {
+                        let t0 = std::time::Instant::now();
+                        vecdb_handler_ref.lock().await.update_indexed_file_paths().await;
+                        info!("update_indexed_file_paths: it took {:.3}s", t0.elapsed().as_secs_f64());
+
+                        reported_vecdb_complete = true;
+                        // For now we do not create an index because it hurts the quality of retrieval
+                        // info!("VECDB Creating index");
+                        // match vecdb_handler_ref.lock().await.create_index().await {
+                        //     Ok(_) => info!("VECDB CREATED INDEX"),
+                        //     Err(err) => info!("VECDB Error creating index: {}", err)
+                        // }
+                        write!(std::io::stderr(), "VECDB COMPLETE\n").unwrap();
+                        info!("VECDB COMPLETE");  // stderr sometimes shows "VECDB COMPLETE" before the logs do
+                    }
+                    tokio::time::sleep(tokio::time::Duration::from_millis(5000)).await;
+                    continue;
+                }
+            }
+        };
+
+        let split_data = match file_splitter.split(&path).await {
+            Ok(data) => data,
+            Err(_) => { continue }
+        };
+
+        let mut vecdb_handler = vecdb_handler_ref.lock().await;
+        let mut split_data_filtered: Vec<SplitResult> = split_data
+            .iter()
+            .filter(|x| !vecdb_handler.contains(&x.window_text_hash))
+            .cloned()  // clone to avoid borrowing issues
+            .collect();
+        split_data_filtered = vecdb_handler.try_add_from_cache(split_data_filtered).await;
+        drop(vecdb_handler);
+
+        let last_30_chars: String = path.display().to_string().chars().rev().take(30).collect::<String>().chars().rev().collect();
+        info!("...{} embeddings todo/total {}/{}", last_30_chars, split_data_filtered.len(), split_data.len());
+
+        // TODO: replace with a batched call?
+        let join_handles: Vec<_> = split_data_filtered.into_iter().map(|x| {
+            let model_name_clone = constants.model_name.clone();
+            let api_key_clone = api_key.clone();
+            let endpoint_embeddings_style_clone = constants.endpoint_embeddings_style.clone();
+            let endpoint_template_clone = constants.endpoint_embeddings_template.clone();
+            let status_clone = Arc::clone(&status);
+            let client_clone = Arc::clone(&client);
+            let semaphore_clone = Arc::clone(&semaphore);
+            tokio::spawn(async move {
+                let _permit = match semaphore_clone.acquire().await {
+                    Ok(permit) => permit,
+                    Err(_) => {
+                        return None;
+                    }
+                };
+
+                let result = try_get_embedding(
+                    client_clone,
+                    &endpoint_embeddings_style_clone,
+                    &model_name_clone,
+                    &endpoint_template_clone,
+                    x.window_text.clone(),
+                    &api_key_clone,
+                    1,
+                ).await;
+                status_clone.lock().await.requests_made_since_start += 1;
+
+                drop(_permit);
+                Some((x, result))
+            })
+        }).collect();
+
+        let mut records = vec![];
+
+        for handle in join_handles {
+            if let Some((data_res, result_mb)) = handle.await.unwrap() {
+                match result_mb {
+                    Ok(result) => {
+                        let now = SystemTime::now();
+                        records.push(
+                            Record {
+                                vector: Some(result),
+                                window_text: data_res.window_text,
+                                window_text_hash: data_res.window_text_hash,
+                                file_path: data_res.file_path,
+                                start_line: data_res.start_line,
+                                end_line: data_res.end_line,
+                                time_added: now,
+                                model_name: constants.model_name.clone(),
+                                distance: -1.0,
+                                used_counter: 0,
+                                time_last_used: now,
+                            }
+                        );
+                    }
+                    Err(e) => {
+                        info!("Error retrieving embeddings for {}: {}", data_res.file_path.to_str().unwrap(), e);
+                        // queue.lock().await.push_back(data_res.file_path);  // push it back again
+                    }
+                }
+            }
+        }
+        match vecdb_handler_ref.lock().await.add_or_update(records, true).await {
+            Err(e) => {
+                info!("Error adding/updating records in VecDB: {}", e);
+            }
+            _ => {}
+        }
+    }
+}
+
+async fn cleanup_thread(vecdb_handler: Arc<AMutex<VecDBHandler>>) {
+    loop {
+        {
+            let mut vecdb = vecdb_handler.lock().await;
+            let _ = vecdb.cleanup_old_records().await;
+            // For now we do not create an index because it hurts the quality of retrieval
+            // let _ = vecdb.create_index().await;
+        }
+        tokio::time::sleep(tokio::time::Duration::from_secs(2 * 3600)).await;
+    }
+}
+
+impl FileVectorizerService {
+    pub async fn new(
+        vecdb_handler: Arc<AMutex<VecDBHandler>>,
+        constants: VecdbConstants,
+        api_key: String,
+    ) -> Self {
+        let update_request_queue = Arc::new(AMutex::new(VecDeque::new()));
+        let output_queue = Arc::new(AMutex::new(VecDeque::new()));
+        let status = Arc::new(AMutex::new(
+            VecDbStatus {
+                unprocessed_files_count: 0,
+                requests_made_since_start: 0,
+                db_size: 0,
+                db_cache_size: 0,
+            }
+        ));
+        FileVectorizerService {
+            update_request_queue: update_request_queue.clone(),
+            output_queue: output_queue.clone(),
+            vecdb_handler: vecdb_handler.clone(),
+            status: status.clone(),
+            constants,
+            api_key,
+        }
+    }
+
+    pub async fn start_background_tasks(&self, vecdb_client: Arc<AMutex<reqwest::Client>>) -> Vec<JoinHandle<()>> {
+        let cooldown_queue_join_handle = tokio::spawn(
+            cooldown_queue_thread(
+                self.update_request_queue.clone(),
+                self.output_queue.clone(),
+                self.status.clone(),
+                self.constants.cooldown_secs,
+            )
+        );
+
+        let retrieve_thread_handle = tokio::spawn(
+            vectorize_thread(
+                vecdb_client.clone(),
+                self.output_queue.clone(),
+                self.vecdb_handler.clone(),
+                self.status.clone(),
+                self.constants.clone(),
+                self.api_key.clone(),
+                4,
+            )
+        );
+
+        let cleanup_thread_handle = tokio::spawn(
+            cleanup_thread(
+                self.vecdb_handler.clone()
+            )
+        );
+
+        return vec![cooldown_queue_join_handle, retrieve_thread_handle, cleanup_thread_handle];
+    }
+
+    pub async fn process_file(&self, path: PathBuf, force: bool) {
+        info!("adding single file");
+        if !force {
+            self.update_request_queue.lock().await.push_back(path);
+        } else {
+            self.output_queue.lock().await.push_back(path);
+        }
+    }
+
+    pub async fn process_files(&self, paths: Vec<PathBuf>, force: bool) {
+        info!("adding {} files", paths.len());
+        if !force {
+            self.update_request_queue.lock().await.extend(paths);
+        } else {
+            self.output_queue.lock().await.extend(paths);
+        }
+    }
+
+    pub async fn status(&self) -> Result<VecDbStatus, String> {
+        let mut status = self.status.lock().await.clone();
+        status.db_size = match self.vecdb_handler.lock().await.size().await {
+            Ok(res) => res,
+            Err(err) => return Err(err)
+        };
+        status.db_cache_size = match self.vecdb_handler.lock().await.cache_size().await {
+            Ok(res) => res,
+            Err(err) => return Err(err.to_string())
+        };
+        Ok(status)
+    }
+
+    pub async fn get_all_file_paths(&self) -> Arc<AMutex<Vec<PathBuf>>> {
+        return self.vecdb_handler.lock().await.get_indexed_file_paths().await;
+    }
+
+    pub async fn get_file_orig_text(&self, file_path: String) -> String {
+        return self.vecdb_handler.lock().await.get_file_orig_text(file_path).await;
+    }
+}
diff --git a/src/vecdb_search.rs b/src/vecdb_search.rs
deleted file mode 100644
index bdad0a4ab..000000000
--- a/src/vecdb_search.rs
+++ /dev/null
@@ -1,133 +0,0 @@
-use crate::call_validation::{ChatMessage, ChatPost};
-// use reqwest::header::AUTHORIZATION;
-use reqwest::header::CONTENT_TYPE;
-use reqwest::header::HeaderMap;
-use reqwest::header::HeaderValue;
-use serde::{Deserialize, Serialize};
-use serde_json::json;
-
-use std::sync::Arc;
-use tokio::sync::Mutex as AMutex;
-use async_trait::async_trait;
-
-
-#[derive(Serialize, Deserialize, Clone, Debug, Default)]
-pub struct VecdbResultRec {
-    pub file_name: String,
-    pub text: String,
-    pub score: String,
-}
-
-#[derive(Serialize, Deserialize, Clone, Debug, Default)]
-pub struct VecdbResult {
-    pub results: Vec<VecdbResultRec>,
-}
-
-// FIXME: bad idea
-pub async fn embed_vecdb_results(
-    vecdb_search: Arc<AMutex<Box<dyn VecdbSearch>>>,
-    post: &mut ChatPost,
-    limit_examples_cnt: usize,
-) {
-    let my_vdb = vecdb_search.clone();
-    let latest_msg_cont = &post.messages.last().unwrap().content;
-    let mut vecdb_locked = my_vdb.lock().await;
-    let vdb_resp = vecdb_locked.search(&latest_msg_cont).await;
-    let vdb_cont = vecdb_resp_to_prompt(&vdb_resp, limit_examples_cnt);
-    if vdb_cont.len() > 0 {
-        post.messages = [
-            &post.messages[..post.messages.len() - 1],
-            &[ChatMessage {
-                role: "user".to_string(),
-                content: vdb_cont,
-            }],
-            &post.messages[post.messages.len() - 1..],
-        ].concat();
-    }
-}
-
-// FIXME: bad idea
-fn vecdb_resp_to_prompt(
-    resp: &Result<VecdbResult, String>,
-    limit_examples_cnt: usize,
-) -> String {
-    let mut cont = "".to_string();
-    match resp {
-        Ok(resp) => {
-            cont.push_str("CONTEXT:\n");
-            for i in 0..limit_examples_cnt {
-                if i >= resp.results.len() {
-                    break;
-                }
-                cont.push_str("FILENAME:\n");
-                cont.push_str(resp.results[i].file_name.clone().as_str());
-                cont.push_str("\nTEXT:");
-                cont.push_str(resp.results[i].text.clone().as_str());
-                cont.push_str("\n");
-            }
-            cont.push_str("\nRefer to the context to answer my next question.\n");
-            cont
-        }
-        Err(e) => {
-            format!("Vecdb error: {}", e);
-            cont
-        }
-    }
-}
-
-#[async_trait]
-pub trait VecdbSearch: Send {
-    async fn search(
-        &mut self,
-        query: &str,
-    ) -> Result<VecdbResult, String>;
-}
-
-#[derive(Debug, Clone)]
-pub struct VecdbSearchTest {
-}
-
-impl VecdbSearchTest {
-    pub fn new() -> Self {
-        VecdbSearchTest {
-        }
-    }
-}
-
-// unsafe impl Send for VecdbSearchTest {}
-
-#[async_trait]
-impl VecdbSearch for VecdbSearchTest {
-    async fn search(
-        &mut self,
-        query: &str,
-    ) -> Result<VecdbResult, String> {
-        let url = "http://127.0.0.1:8008/v1/vdb-search".to_string();
-        let mut headers = HeaderMap::new();
-        // headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap());
-        headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
-        let body = json!({
-            "texts": [query],
-            "account": "XXX",
-            "top_k": 3,
-        });
-        let res = reqwest::Client::new()
-            .post(&url)
-            .headers(headers)
-            .body(body.to_string())
-            .send()
-            .await.map_err(|e| format!("Vecdb search HTTP error (1): {}", e))?;
-
-        let body = res.text().await.map_err(|e| format!("Vecdb search HTTP error (2): {}", e))?;
-        // info!("Vecdb search result: {:?}", &body);
-        let result: Vec<VecdbResult> = serde_json::from_str(&body).map_err(|e| {
-            format!("vecdb JSON problem: {}", e)
-        })?;
-        if result.len() == 0 {
-            return Err("Vecdb search result is empty".to_string());
-        }
-        let result0 = result[0].clone();
-        // info!("Vecdb search result: {:?}", &result0);
-        Ok(result0)
-    }
-}
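The deleted module and the new `VecDbRemote` speak slightly different wire formats: the old code posted `{"texts": [query], "account": ..., "top_k": ...}` and deserialized `Vec<VecdbResult>`, while the replacement posts `{"text": query, "top_n": ...}` and expects `Vec<SearchResult>`. A minimal check of the new request shape (assumes serde_json; illustrative only):

    use serde_json::json;

    fn main() {
        let body = json!({ "text": "definition of DeltaDeltaChatStreamer", "top_n": 3 });
        assert_eq!(body["top_n"], 3);
        println!("{}", body.to_string());
    }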