From f953d1afe89be59e717792ba9af0c81ab7052810 Mon Sep 17 00:00:00 2001 From: Valeryi Date: Thu, 4 Jan 2024 12:21:07 +0000 Subject: [PATCH] added endpoint_embeddings_style --- src/caps.rs | 3 +++ src/fetch_embedding.rs | 12 +++++------- src/lsp.rs | 1 + src/vecdb/vecdb.rs | 16 ++++++++++++---- src/vecdb/vectorizer_service.rs | 14 +++++++------- 5 files changed, 28 insertions(+), 18 deletions(-) diff --git a/src/caps.rs b/src/caps.rs index e8463d8df..db7ef7571 100644 --- a/src/caps.rs +++ b/src/caps.rs @@ -48,6 +48,8 @@ pub struct CodeAssistantCaps { pub default_embeddings_model: String, #[serde(default)] pub endpoint_embeddings_template: String, + #[serde(default)] + pub endpoint_embeddings_style: String, pub running_models: Vec, #[serde(default)] pub caps_version: i64, // need to reload if it increases on server, that happens when server configuration changes @@ -279,6 +281,7 @@ const HF_DEFAULT_CAPS: &str = r#" "default_embeddings_model": "BAAI/bge-small-en-v1.5", "endpoint_embeddings_template": "https://api-inference.huggingface.co/models/$MODEL", + "endpoint_embeddings_style": "hf", "tokenizer_path_template": "https://huggingface.co/$MODEL/resolve/main/tokenizer.json", "tokenizer_rewrite_path": { diff --git a/src/fetch_embedding.rs b/src/fetch_embedding.rs index 5f59f08f4..8c493d7bc 100644 --- a/src/fetch_embedding.rs +++ b/src/fetch_embedding.rs @@ -7,20 +7,18 @@ use crate::forward_to_openai_endpoint::get_embedding_openai_style; pub async fn get_embedding( - address_url: &String, + endpoint_embeddings_style: &String, model_name: &String, endpoint_template: &String, text: String, api_key: &String, ) -> Result, String> { - match address_url.to_lowercase().as_str() { + match endpoint_embeddings_style.to_lowercase().as_str() { "hf" => Ok(get_embedding_hf_style(text, endpoint_template, model_name, api_key).await?), - url if url == "refact" || url.starts_with("http") => { - Ok(get_embedding_openai_style(text, endpoint_template, model_name, api_key).await?) - } + "openai" => Ok(get_embedding_openai_style(text, endpoint_template, model_name, api_key).await?), _ => { - error!("Invalid address_url: {}", address_url); - Err("Invalid address_url".to_string()) + error!("Invalid endpoint_embeddings_style: {}", endpoint_embeddings_style); + Err("Invalid endpoint_embeddings_style".to_string()) } } } diff --git a/src/lsp.rs b/src/lsp.rs index b75683af5..1a069afeb 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -171,6 +171,7 @@ impl Backend { #[tower_lsp::async_trait] impl LanguageServer for Backend { async fn initialize(&self, params: InitializeParams) -> Result { + // TODO: vecdb can be initialized after LSP, so retrieve_files_by_proj_folders won't start info!("LSP client_info {:?}", params.client_info); { let gcx_locked = self.gcx.write().await; diff --git a/src/vecdb/vecdb.rs b/src/vecdb/vecdb.rs index d74567fa2..b1a2e4e24 100644 --- a/src/vecdb/vecdb.rs +++ b/src/vecdb/vecdb.rs @@ -69,12 +69,13 @@ pub async fn vecdb_background_reload( continue; } - let (default_embeddings_model, endpoint_embeddings_template) = { + let (default_embeddings_model, endpoint_embeddings_template, endpoint_embeddings_style) = { let caps = caps_mb.unwrap(); let caps_locked = caps.read().unwrap(); ( caps_locked.default_embeddings_model.clone(), - caps_locked.endpoint_embeddings_template.clone() + caps_locked.endpoint_embeddings_template.clone(), + caps_locked.endpoint_embeddings_style.clone(), ) }; @@ -84,7 +85,11 @@ pub async fn vecdb_background_reload( } let vecdb_mb = create_vecdb_if_caps_present( - default_embeddings_model, endpoint_embeddings_template, cmdline, cache_dir + default_embeddings_model, + endpoint_embeddings_template, + endpoint_embeddings_style, + cmdline, + cache_dir ).await; if vecdb_mb.is_none() { @@ -106,6 +111,7 @@ pub async fn vecdb_background_reload( pub async fn create_vecdb_if_caps_present( default_embeddings_model: String, endpoint_embeddings_template: String, + endpoint_embeddings_style: String, cmdline: &CommandLine, cache_dir: &PathBuf, @@ -115,6 +121,7 @@ pub async fn create_vecdb_if_caps_present( 384, 60, 512, 1024, default_embeddings_model.clone(), endpoint_embeddings_template.clone(), + endpoint_embeddings_style.clone(), ).await { Ok(res) => Some(res), Err(err) => { @@ -143,6 +150,7 @@ impl VecDb { model_name: String, endpoint_template: String, + endpoint_embeddings_style: String, ) -> Result { let handler = match VecDBHandler::init(cache_dir, embedding_size).await { Ok(res) => res, @@ -157,7 +165,7 @@ impl VecDb { model_name.clone(), cmdline.api_key.clone(), - cmdline.address_url.clone(), + endpoint_embeddings_style.clone(), endpoint_template.clone(), ).await)); diff --git a/src/vecdb/vectorizer_service.rs b/src/vecdb/vectorizer_service.rs index 906d1a9ee..4322b5f43 100644 --- a/src/vecdb/vectorizer_service.rs +++ b/src/vecdb/vectorizer_service.rs @@ -24,7 +24,7 @@ pub struct FileVectorizerService { model_name: String, api_key: String, - address_url: String, + endpoint_embeddings_style: String, endpoint_template: String, } @@ -73,7 +73,7 @@ async fn vectorize_thread( model_name: String, api_key: String, - address_url: String, + endpoint_embeddings_style: String, endpoint_template: String, max_concurrent_tasks: usize, @@ -120,7 +120,7 @@ async fn vectorize_thread( let join_handles: Vec<_> = split_data_filtered.into_iter().map(|x| { let model_name_clone = model_name.clone(); let api_key_clone = api_key.clone(); - let address_url_clone = address_url.clone(); + let endpoint_embeddings_style_clone = endpoint_embeddings_style.clone(); let endpoint_template_clone = endpoint_template.clone(); let semaphore_clone = Arc::clone(&semaphore); @@ -133,7 +133,7 @@ async fn vectorize_thread( }; let result = get_embedding( - &address_url_clone, + &endpoint_embeddings_style_clone, &model_name_clone, &endpoint_template_clone, x.window_text.clone(), @@ -203,7 +203,7 @@ impl FileVectorizerService { model_name: String, api_key: String, - address_url: String, + endpoint_embeddings_style: String, endpoint_template: String, ) -> Self { let update_request_queue = Arc::new(Mutex::new(VecDeque::new())); @@ -227,7 +227,7 @@ impl FileVectorizerService { model_name, api_key, - address_url, + endpoint_embeddings_style, endpoint_template, } } @@ -252,7 +252,7 @@ impl FileVectorizerService { self.model_name.clone(), self.api_key.clone(), - self.address_url.clone(), + self.endpoint_embeddings_style.clone(), self.endpoint_template.clone(), 4,