Skip to content

Commit

Permalink
added endpoint_embeddings_style
Browse files Browse the repository at this point in the history
  • Loading branch information
valaises committed Jan 4, 2024
1 parent a730ca1 commit f953d1a
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 18 deletions.
3 changes: 3 additions & 0 deletions src/caps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub struct CodeAssistantCaps {
pub default_embeddings_model: String,
#[serde(default)]
pub endpoint_embeddings_template: String,
#[serde(default)]
pub endpoint_embeddings_style: String,
pub running_models: Vec<String>,
#[serde(default)]
pub caps_version: i64, // need to reload if it increases on server, that happens when server configuration changes
Expand Down Expand Up @@ -279,6 +281,7 @@ const HF_DEFAULT_CAPS: &str = r#"
"default_embeddings_model": "BAAI/bge-small-en-v1.5",
"endpoint_embeddings_template": "https://api-inference.huggingface.co/models/$MODEL",
"endpoint_embeddings_style": "hf",
"tokenizer_path_template": "https://huggingface.co/$MODEL/resolve/main/tokenizer.json",
"tokenizer_rewrite_path": {
Expand Down
12 changes: 5 additions & 7 deletions src/fetch_embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,18 @@ use crate::forward_to_openai_endpoint::get_embedding_openai_style;


/// Fetches an embedding vector for `text` by dispatching to a style-specific backend.
///
/// The selector string is lowercased before matching, so the comparison is
/// case-insensitive:
/// * `"hf"` forwards to `get_embedding_hf_style`
/// * `"openai"` forwards to `get_embedding_openai_style`
/// * anything else is logged via `error!` and returned as an `Err` string
///
/// `text`, `endpoint_template`, `model_name` and `api_key` are passed through
/// unchanged to the selected backend; a backend error is propagated with `?`.
///
/// Returns the embedding as `Vec<f32>` on success, or an error message `String`.
///
/// NOTE(review): this listing appears to be a rendered diff without +/- markers,
/// so both the `address_url`-based lines and the `endpoint_embeddings_style`-based
/// lines are shown together. Presumably the `address_url` variants (including the
/// `"refact"` / `starts_with("http")` arm) are the removed side of the diff —
/// confirm against the repository before relying on this text as compilable code.
pub async fn get_embedding(
address_url: &String,
endpoint_embeddings_style: &String,
model_name: &String,
endpoint_template: &String,
text: String,
api_key: &String,
) -> Result<Vec<f32>, String> {
match address_url.to_lowercase().as_str() {
// Dispatch on the lowercased style name rather than on the address URL.
match endpoint_embeddings_style.to_lowercase().as_str() {
"hf" => Ok(get_embedding_hf_style(text, endpoint_template, model_name, api_key).await?),
url if url == "refact" || url.starts_with("http") => {
Ok(get_embedding_openai_style(text, endpoint_template, model_name, api_key).await?)
}
"openai" => Ok(get_embedding_openai_style(text, endpoint_template, model_name, api_key).await?),
// Unknown selector: log it and report the failure to the caller.
_ => {
error!("Invalid address_url: {}", address_url);
Err("Invalid address_url".to_string())
error!("Invalid endpoint_embeddings_style: {}", endpoint_embeddings_style);
Err("Invalid endpoint_embeddings_style".to_string())
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/lsp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ impl Backend {
#[tower_lsp::async_trait]
impl LanguageServer for Backend {
async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
// TODO: vecdb can be initialized after LSP, so retrieve_files_by_proj_folders won't start
info!("LSP client_info {:?}", params.client_info);
{
let gcx_locked = self.gcx.write().await;
Expand Down
16 changes: 12 additions & 4 deletions src/vecdb/vecdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ pub async fn vecdb_background_reload(
continue;
}

let (default_embeddings_model, endpoint_embeddings_template) = {
let (default_embeddings_model, endpoint_embeddings_template, endpoint_embeddings_style) = {
let caps = caps_mb.unwrap();
let caps_locked = caps.read().unwrap();
(
caps_locked.default_embeddings_model.clone(),
caps_locked.endpoint_embeddings_template.clone()
caps_locked.endpoint_embeddings_template.clone(),
caps_locked.endpoint_embeddings_style.clone(),
)
};

Expand All @@ -84,7 +85,11 @@ pub async fn vecdb_background_reload(
}

let vecdb_mb = create_vecdb_if_caps_present(
default_embeddings_model, endpoint_embeddings_template, cmdline, cache_dir
default_embeddings_model,
endpoint_embeddings_template,
endpoint_embeddings_style,
cmdline,
cache_dir
).await;

if vecdb_mb.is_none() {
Expand All @@ -106,6 +111,7 @@ pub async fn vecdb_background_reload(
pub async fn create_vecdb_if_caps_present(
default_embeddings_model: String,
endpoint_embeddings_template: String,
endpoint_embeddings_style: String,

cmdline: &CommandLine,
cache_dir: &PathBuf,
Expand All @@ -115,6 +121,7 @@ pub async fn create_vecdb_if_caps_present(
384, 60, 512, 1024,
default_embeddings_model.clone(),
endpoint_embeddings_template.clone(),
endpoint_embeddings_style.clone(),
).await {
Ok(res) => Some(res),
Err(err) => {
Expand Down Expand Up @@ -143,6 +150,7 @@ impl VecDb {

model_name: String,
endpoint_template: String,
endpoint_embeddings_style: String,
) -> Result<VecDb, String> {
let handler = match VecDBHandler::init(cache_dir, embedding_size).await {
Ok(res) => res,
Expand All @@ -157,7 +165,7 @@ impl VecDb {

model_name.clone(),
cmdline.api_key.clone(),
cmdline.address_url.clone(),
endpoint_embeddings_style.clone(),
endpoint_template.clone(),
).await));

Expand Down
14 changes: 7 additions & 7 deletions src/vecdb/vectorizer_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct FileVectorizerService {

model_name: String,
api_key: String,
address_url: String,
endpoint_embeddings_style: String,
endpoint_template: String,
}

Expand Down Expand Up @@ -73,7 +73,7 @@ async fn vectorize_thread(

model_name: String,
api_key: String,
address_url: String,
endpoint_embeddings_style: String,
endpoint_template: String,

max_concurrent_tasks: usize,
Expand Down Expand Up @@ -120,7 +120,7 @@ async fn vectorize_thread(
let join_handles: Vec<_> = split_data_filtered.into_iter().map(|x| {
let model_name_clone = model_name.clone();
let api_key_clone = api_key.clone();
let address_url_clone = address_url.clone();
let endpoint_embeddings_style_clone = endpoint_embeddings_style.clone();
let endpoint_template_clone = endpoint_template.clone();

let semaphore_clone = Arc::clone(&semaphore);
Expand All @@ -133,7 +133,7 @@ async fn vectorize_thread(
};

let result = get_embedding(
&address_url_clone,
&endpoint_embeddings_style_clone,
&model_name_clone,
&endpoint_template_clone,
x.window_text.clone(),
Expand Down Expand Up @@ -203,7 +203,7 @@ impl FileVectorizerService {

model_name: String,
api_key: String,
address_url: String,
endpoint_embeddings_style: String,
endpoint_template: String,
) -> Self {
let update_request_queue = Arc::new(Mutex::new(VecDeque::new()));
Expand All @@ -227,7 +227,7 @@ impl FileVectorizerService {

model_name,
api_key,
address_url,
endpoint_embeddings_style,
endpoint_template,
}
}
Expand All @@ -252,7 +252,7 @@ impl FileVectorizerService {

self.model_name.clone(),
self.api_key.clone(),
self.address_url.clone(),
self.endpoint_embeddings_style.clone(),
self.endpoint_template.clone(),

4,
Expand Down

0 comments on commit f953d1a

Please sign in to comment.