vecdb caps support #53

Merged · 2 commits · Jan 12, 2024

Changes from all commits
2 changes: 1 addition & 1 deletion src/background_tasks.rs
@@ -46,6 +46,6 @@ pub fn start_background_tasks(global_context: Arc<ARwLock<GlobalContext>>) -> Ba
         tokio::spawn(global_context::caps_background_reload(global_context.clone())),
         tokio::spawn(basic_transmit::telemetry_background_task(global_context.clone())),
         tokio::spawn(snippets_transmit::tele_snip_background_task(global_context.clone())),
-        tokio::spawn(vecdb::file_watcher_service::file_watcher_task(global_context.clone())),
+        tokio::spawn(vecdb::vecdb::vecdb_background_reload(global_context.clone())),
    ])
}
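The vecdb lifecycle now belongs to a single vecdb_background_reload task instead of being wired up eagerly in main.rs (see the src/global_context.rs and src/main.rs diffs below). The task's body is not part of this diff, so the following is only a sketch of the plausible pattern — a loop that waits until caps carrying embeddings settings have been loaded before (re)initializing the vector DB. The four Caps fields mirror the src/caps.rs additions; every other name here is hypothetical.

use std::sync::Arc;
use tokio::sync::RwLock as ARwLock;

// Fields mirror the additions to src/caps.rs in this PR;
// everything else in this sketch is hypothetical.
struct Caps {
    default_embeddings_model: String,
    endpoint_embeddings_template: String,
    endpoint_embeddings_style: String,
    size_embeddings: i32,
}

async fn vecdb_background_reload_sketch(caps: Arc<ARwLock<Option<Caps>>>) {
    loop {
        if let Some(c) = caps.read().await.as_ref() {
            // Caps have arrived and advertise an embeddings model: this is
            // the point where the real task could (re)initialize the vector DB.
            if !c.default_embeddings_model.is_empty() {
                println!("would init vecdb: {} ({} dims, style {})",
                         c.default_embeddings_model, c.size_embeddings,
                         c.endpoint_embeddings_style);
            }
        }
        tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
    }
}

#[tokio::main]
async fn main() {
    let caps: Arc<ARwLock<Option<Caps>>> = Arc::new(ARwLock::new(None));
    // In refact-lsp this would run alongside caps_background_reload.
    tokio::spawn(vecdb_background_reload_sketch(caps));
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}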
16 changes: 16 additions & 0 deletions src/caps.rs
@@ -44,6 +44,14 @@ pub struct CodeAssistantCaps {
     #[serde(default)]
     pub code_chat_models: HashMap<String, ModelRecord>,
     pub code_chat_default_model: String,
+    #[serde(default)]
+    pub default_embeddings_model: String,
+    #[serde(default)]
+    pub endpoint_embeddings_template: String,
+    #[serde(default)]
+    pub endpoint_embeddings_style: String,
+    #[serde(default)]
+    pub size_embeddings: i32,
     pub running_models: Vec<String>,
     #[serde(default)]
     pub caps_version: i64,  // need to reload if it increases on server, that happens when server configuration changes
@@ -272,6 +280,12 @@ const HF_DEFAULT_CAPS: &str = r#"
     "cloud_name": "Hugging Face",
     "endpoint_template": "https://api-inference.huggingface.co/models/$MODEL",
     "endpoint_style": "hf",
+
+    "default_embeddings_model": "BAAI/bge-small-en-v1.5",
+    "endpoint_embeddings_template": "https://api-inference.huggingface.co/models/$MODEL",
+    "endpoint_embeddings_style": "hf",
+    "size_embeddings": 384,
+
     "tokenizer_path_template": "https://huggingface.co/$MODEL/resolve/main/tokenizer.json",
     "tokenizer_rewrite_path": {
         "meta-llama/Llama-2-70b-chat-hf": "TheBloke/Llama-2-70B-fp16"
@@ -344,13 +358,15 @@ pub async fn load_caps(
     r1.endpoint_chat_passthrough = relative_to_full_url(&caps_url, &r1.endpoint_chat_passthrough)?;
     r1.telemetry_basic_dest = relative_to_full_url(&caps_url, &r1.telemetry_basic_dest)?;
     r1.telemetry_corrected_snippets_dest = relative_to_full_url(&caps_url, &r1.telemetry_corrected_snippets_dest)?;
+    r1.endpoint_embeddings_template = relative_to_full_url(&caps_url, &r1.endpoint_embeddings_template)?;
     info!("caps {} completion models", r1.code_completion_models.len());
     info!("caps default completion model: \"{}\"", r1.code_completion_default_model);
     info!("caps {} chat models", r1.code_chat_models.len());
     info!("caps default chat model: \"{}\"", r1.code_chat_default_model);
     Ok(Arc::new(StdRwLock::new(r1)))
 }

 fn relative_to_full_url(
     caps_url: &String,
     maybe_relative_url: &str,
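All four new fields are marked #[serde(default)], so caps served by an older server simply deserialize to empty strings and a zero size instead of failing. A minimal standalone sketch of that behavior — the EmbeddingsCaps struct is illustrative, mirroring only the new fields, not the full CodeAssistantCaps:

use serde::Deserialize;

// Illustrative struct mirroring only the four new caps fields.
#[derive(Deserialize, Debug)]
struct EmbeddingsCaps {
    #[serde(default)]
    default_embeddings_model: String,
    #[serde(default)]
    endpoint_embeddings_template: String,
    #[serde(default)]
    endpoint_embeddings_style: String,
    #[serde(default)]
    size_embeddings: i32,
}

fn main() {
    // A server that supports this PR sends the fields...
    let new_server: EmbeddingsCaps = serde_json::from_str(r#"{
        "default_embeddings_model": "BAAI/bge-small-en-v1.5",
        "endpoint_embeddings_template": "https://api-inference.huggingface.co/models/$MODEL",
        "endpoint_embeddings_style": "hf",
        "size_embeddings": 384
    }"#).unwrap();
    // ...an older server omits them and deserialization still succeeds.
    let old_server: EmbeddingsCaps = serde_json::from_str("{}").unwrap();
    println!("{:?}\n{:?}", new_server, old_server);
}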
102 changes: 44 additions & 58 deletions src/vecdb/req_client.rs → src/fetch_embedding.rs
@@ -1,72 +1,58 @@
-use std::time::Duration;
-use reqwest;
-use serde::Serialize;
-use tokio::task::JoinHandle;
-use tokio::time::sleep;
-
-#[derive(Serialize)]
-struct Payload {
-    pub inputs: String,
-}
-
-fn get_base_url() -> String {
-    #[cfg(test)]
-    {
-        // When running tests, return the mockito server URL
-        mockito::server_url()
-    }
-
-    #[cfg(not(test))]
-    {
-        // In a regular run, return the actual URL
-        "https://api-inference.huggingface.co".to_string()
-    }
-}
-
-pub fn get_embedding(
-    text: String,
-    model_name: &String,
-    api_key: String,
-) -> JoinHandle<Result<Vec<f32>, String>> {
-    let url = format!("{}/models/{}", get_base_url(), model_name);
-    let client = reqwest::Client::new();
-    let payload = Payload { inputs: text };
-
-    tokio::spawn(async move {
-        let mut attempts = 0;
-        let max_attempts = 3;
-        let delay = Duration::from_secs(5);
-
-        while attempts < max_attempts {
-            let maybe_response = client.post(&url)
-                .bearer_auth("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".clone())
-                .json(&payload)
-                .send()
-                .await;
-
-            match maybe_response {
-                Ok(response) => {
-                    if response.status().is_success() {
-                        match response.json::<Vec<f32>>().await {
-                            Ok(embedding) => return Ok(embedding),
-                            Err(err) => return Err(format!("Failed to parse the response: {:?}", err)),
-                        }
-                    } else if response.status().is_server_error() {
-                        // Retry in case of 5xx server errors
-                        attempts += 1;
-                        sleep(delay).await;
-                        continue;
-                    } else {
-                        return Err(format!("Failed to get a response: {:?}", response.status()));
-                    }
-                },
-                Err(err) => return Err(format!("Failed to send a request: {:?}", err))
-            }
-        }
-
-        Err("Exceeded maximum attempts to reach the server".to_string())
-    })
-}
+use tracing::error;
+
+use crate::forward_to_hf_endpoint::get_embedding_hf_style;
+use crate::forward_to_openai_endpoint::get_embedding_openai_style;
+
+pub async fn get_embedding(
+    endpoint_embeddings_style: &String,
+    model_name: &String,
+    endpoint_template: &String,
+    text: String,
+    api_key: &String,
+) -> Result<Vec<f32>, String> {
+    match endpoint_embeddings_style.to_lowercase().as_str() {
+        "hf" => get_embedding_hf_style(text, endpoint_template, model_name, api_key).await,
+        "openai" => get_embedding_openai_style(text, endpoint_template, model_name, api_key).await,
+        _ => {
+            error!("Invalid endpoint_embeddings_style: {}", endpoint_embeddings_style);
+            Err("Invalid endpoint_embeddings_style".to_string())
+        }
+    }
+}
+
+// HF often returns 500 errors for no reason
+pub async fn try_get_embedding(
+    endpoint_embeddings_style: &String,
+    model_name: &String,
+    endpoint_template: &String,
+    text: String,
+    api_key: &String,
+    max_retries: usize,
+) -> Result<Vec<f32>, String> {
+    let sleep_on_failure_ms = 300;
+    let mut retries = 0;
+    loop {
+        retries += 1;
+        match get_embedding(
+            endpoint_embeddings_style,
+            model_name,
+            endpoint_template,
+            text.clone(),
+            api_key,
+        ).await {
+            Ok(embedding) => return Ok(embedding),
+            Err(e) => {
+                tokio::time::sleep(tokio::time::Duration::from_millis(sleep_on_failure_ms)).await;
+                if retries > max_retries {
+                    return Err(e);
+                }
+            }
+        }
+    }
+}
 
 #[cfg(test)]
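A hedged sketch of how a caller inside the crate might use the new retry wrapper; the style, model, endpoint, and retry values below are placeholders for illustration, not taken from this PR:

// Hypothetical call site inside refact-lsp; values are placeholders.
use crate::fetch_embedding::try_get_embedding;

async fn embed_one_chunk(api_key: &String) -> Result<Vec<f32>, String> {
    try_get_embedding(
        &"openai".to_string(),                                // endpoint_embeddings_style from caps
        &"text-embedding-ada-002".to_string(),                // model name, assumed
        &"https://api.openai.com/v1/embeddings".to_string(),  // endpoint template, assumed
        "fn main() { println!(\"hello\"); }".to_string(),     // text to embed
        api_key,
        5,                                                    // max_retries
    ).await
}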
44 changes: 44 additions & 0 deletions src/forward_to_hf_endpoint.rs
@@ -3,6 +3,7 @@ use reqwest::header::CONTENT_TYPE;
 use reqwest::header::HeaderMap;
 use reqwest::header::HeaderValue;
 use reqwest_eventsource::EventSource;
+use serde::Serialize;
 use serde_json::json;
 use crate::call_validation::SamplingParameters;
@@ -85,3 +86,46 @@ pub async fn forward_to_hf_style_endpoint_streaming(
     )?;
     Ok(event_source)
 }
+
+
+#[derive(Serialize)]
+struct EmbeddingsPayloadHF {
+    pub inputs: String,
+}
+
+
+pub async fn get_embedding_hf_style(
+    text: String,
+    endpoint_template: &String,
+    model_name: &String,
+    api_key: &String,
+) -> Result<Vec<f32>, String> {
+    let client = reqwest::Client::new();
+    let payload = EmbeddingsPayloadHF { inputs: text };
+    let url = endpoint_template.clone().replace("$MODEL", &model_name);
+    let api_key_clone = api_key.clone();
+
+    let join_handle = tokio::task::spawn(async move {
+        let maybe_response = client
+            .post(&url)
+            .bearer_auth(api_key_clone.clone())
+            .json(&payload)
+            .send()
+            .await;
+
+        return match maybe_response {
+            Ok(response) => {
+                if response.status().is_success() {
+                    match response.json::<Vec<f32>>().await {
+                        Ok(embedding) => Ok(embedding),
+                        Err(err) => Err(format!("Failed to parse the response: {:?}", err)),
+                    }
+                } else {
+                    Err(format!("Failed to get a response: {:?}", response.status()))
+                }
+            },
+            Err(err) => Err(format!("Failed to send a request: {:?}", err)),
+        }
+    });
+    join_handle.await.unwrap_or_else(|_| Err("Task join error".to_string()))
+}
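Note that get_embedding_hf_style deserializes the body directly as Vec<f32>, so the endpoint is expected to return a flat JSON array of floats for a single input. A standalone sketch of that assumption (sample values are made up):

fn main() {
    // What get_embedding_hf_style expects the HF endpoint to return
    // for one input: a flat array of floats. Sample values made up.
    let body = "[0.031, -0.217, 0.098, 0.554]";
    let embedding: Vec<f32> = serde_json::from_str(body).unwrap();
    assert_eq!(embedding.len(), 4); // would be 384 for BAAI/bge-small-en-v1.5
}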
60 changes: 60 additions & 0 deletions src/forward_to_openai_endpoint.rs
@@ -3,6 +3,7 @@ use reqwest::header::CONTENT_TYPE;
 use reqwest::header::HeaderMap;
 use reqwest::header::HeaderValue;
 use reqwest_eventsource::EventSource;
+use serde::Serialize;
 use serde_json::json;
 use crate::call_validation;
 use crate::call_validation::SamplingParameters;
@@ -103,3 +104,62 @@ fn _passthrough_messages_to_json(
     let messages: Vec<call_validation::ChatMessage> = serde_json::from_str(&messages_str).unwrap();
     data["messages"] = serde_json::json!(messages);
 }
+
+
+#[derive(Serialize)]
+struct EmbeddingsPayloadOpenAI {
+    pub input: String,
+    pub model: String,
+}
+
+
+pub async fn get_embedding_openai_style(
+    text: String,
+    endpoint_template: &String,
+    model_name: &String,
+    api_key: &String,
+) -> Result<Vec<f32>, String> {
+    let client = reqwest::Client::new();
+    let payload = EmbeddingsPayloadOpenAI {
+        input: text,
+        model: model_name.clone(),
+    };
+    let url = endpoint_template.clone();
+    let api_key_clone = api_key.clone();
+
+    let join_handle = tokio::spawn(async move {
+        let maybe_response = client
+            .post(&url)
+            .bearer_auth(api_key_clone.clone())
+            .json(&payload)
+            .send()
+            .await;
+
+        return match maybe_response {
+            Ok(response) => {
+                if response.status().is_success() {
+                    let response_json = response.json::<serde_json::Value>().await;
+
+                    match response_json {
+                        Ok(json) => match &json["data"][0]["embedding"] {
+                            serde_json::Value::Array(embedding) => {
+                                let embedding_values: Result<Vec<f32>, _> =
+                                    serde_json::from_value(serde_json::Value::Array(embedding.clone()));
+                                embedding_values.map_err(|err| {
+                                    format!("Failed to parse the response: {:?}", err)
+                                })
+                            }
+                            _ => Err("Response is missing 'data[0].embedding' field or it's not an array".to_string()),
+                        },
+                        Err(err) => Err(format!("Failed to parse the response: {:?}", err)),
+                    }
+                } else {
+                    Err(format!("Failed to get a response: {:?}", response.status()))
+                }
+            }
+            Err(err) => Err(format!("Failed to send a request: {:?}", err)),
+        }
+    });
+
+    join_handle.await.unwrap_or_else(|_| Err("Task join error".to_string()))
+}
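The OpenAI-style path instead expects the usual embeddings envelope and digs the vector out of data[0].embedding. A standalone sketch of the same extraction on a made-up response:

use serde_json::json;

fn main() {
    // The OpenAI-style envelope that get_embedding_openai_style parses;
    // sample values are made up.
    let body = json!({
        "object": "list",
        "data": [ { "object": "embedding", "index": 0, "embedding": [0.013, -0.094, 0.072] } ],
        "model": "text-embedding-ada-002"
    });
    if let serde_json::Value::Array(arr) = &body["data"][0]["embedding"] {
        let embedding: Vec<f32> =
            serde_json::from_value(serde_json::Value::Array(arr.clone())).unwrap();
        println!("{} dims", embedding.len());
    }
}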
22 changes: 1 addition & 21 deletions src/global_context.rs
@@ -165,26 +165,6 @@ pub async fn create_global_context(
         http_client_builder = http_client_builder.danger_accept_invalid_certs(true)
     }
     let http_client = http_client_builder.build().unwrap();
-    let mut vec_db: Option<VecDb> = None;
-    if cmdline.vecdb {
-        vec_db = match VecDb::init(
-            cache_dir.clone(), cmdline.clone(),
-            384, 60, 512, 1024,
-            "BAAI/bge-small-en-v1.5".to_string(),
-        ).await {
-            Ok(res) => Some(res),
-            Err(err) => {
-                error!("Ooops database is broken!
-Last error message: {}
-You can report this issue here:
-https://github.com/smallcloudai/refact-lsp/issues
-Also you can run this to erase your db:
-`rm -rf ~/.cache/refact/refact_vecdb_cache`
-After that restart this LSP server or your IDE.", err);
-                None
-            }
-        };
-    }
 
     let cx = GlobalContext {
         cmdline: cmdline.clone(),
@@ -196,7 +176,7 @@
         tokenizer_map: HashMap::new(),
         completions_cache: Arc::new(StdRwLock::new(CompletionCache::new())),
         telemetry: Arc::new(StdRwLock::new(telemetry_structs::Storage::new())),
-        vec_db: Arc::new(AMutex::new(vec_db)),
+        vec_db: Arc::new(AMutex::new(None)),
         ask_shutdown_sender: Arc::new(Mutex::new(ask_shutdown_sender)),
         lsp_backend_document_state: LSPBackendDocumentState {
             document_map: Arc::new(ARwLock::new(HashMap::new())),
9 changes: 2 additions & 7 deletions src/lsp.rs
@@ -179,14 +179,9 @@ impl LanguageServer for Backend {
         }
 
         if let Some(folders) = params.workspace_folders {
-            let files = retrieve_files_by_proj_folders(
-                folders.iter().map(|x| PathBuf::from(x.uri.path())).collect()
-            ).await;
             match *self.gcx.read().await.vec_db.lock().await {
-                Some(ref db) => db.add_or_update_files(files, true).await,
-                None => {
-                    info!("LSP no vec_db");
-                }
+                Some(ref db) => db.init_folders(folders).await,
+                None => {},
             };
         }
 
7 changes: 3 additions & 4 deletions src/main.rs
@@ -2,6 +2,8 @@ use std::io::Write;
 
 use tracing::{error, info};
 use tracing_appender;
+use std::sync::Arc;
+use tokio::sync::Mutex as AMutex;
 
 use crate::background_tasks::start_background_tasks;
 use crate::lsp::spawn_lsp_task;
@@ -24,6 +26,7 @@ mod http;
 mod background_tasks;
 mod receive_workspace_changes;
 mod vecdb;
+mod fetch_embedding;
 
 
 #[tokio::main]
@@ -63,10 +66,6 @@ async fn main() {
     if lsp_task.is_some() {
         background_tasks.push_back(lsp_task.unwrap())
     }
-    background_tasks.extend(match *gcx.read().await.vec_db.lock().await {
-        Some(ref db) => db.start_background_tasks().await,
-        None => vec![]
-    });
 
     let gcx_clone = gcx.clone();
     let server = http::start_server(gcx_clone, ask_shutdown_receiver);