Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added batch embedding computing #86

Merged
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 101 additions & 47 deletions crates/llm-ls/src/retrieval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ use std::path::Path;
use std::{path::PathBuf, sync::Arc};
use tinyvec_embed::db::{Collection, Compare, Db, Embedding, FilterBuilder, SimilarityResult};
use tinyvec_embed::similarity::Distance;
use tokenizers::{Encoding, Tokenizer, TruncationDirection};
use tokenizers::{Encoding, Tokenizer, TruncationDirection, PaddingStrategy, PaddingDirection, PaddingParams};
use tokio::io::AsyncReadExt;
use tokio::task::spawn_blocking;
use tokio::task::{spawn_blocking};
use tokio::time::Instant;
use tower_lsp::lsp_types::notification::Progress;
use tower_lsp::lsp_types::{
NumberOrString, ProgressParams, ProgressParamsValue, Range, WorkDoneProgress,
WorkDoneProgressReport,
};
use std::iter::zip;
use tower_lsp::Client;
use tracing::{debug, error, warn};

Expand Down Expand Up @@ -156,9 +157,14 @@ async fn build_model_and_tokenizer(
let config = tokio::fs::read_to_string(config_filename).await?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer: Tokenizer = Tokenizer::from_file(tokenizer_filename)?;
tokenizer.with_padding(None);
tokenizer.with_padding(Some(PaddingParams { strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Right,
pad_to_multiple_of: Some(8),
// TODO: use values provided in model config
pad_id: 0,
pad_type_id: 0,
pad_token: "<pad>".to_string()}));
McPatate marked this conversation as resolved.
Show resolved Hide resolved
tokenizer.with_truncation(None)?;

let vb = VarBuilder::from_pth(&weights_filename, DTYPE, &device)?;
let model = BertModel::load(vb, &config)?;
debug!(
Expand Down Expand Up @@ -191,6 +197,8 @@ fn device(cpu: bool) -> Result<Device> {
pub(crate) struct Snippet {
pub(crate) file_url: String,
pub(crate) code: String,
pub(crate) start_line: usize,
pub(crate) end_line: usize,
}

impl TryFrom<&SimilarityResult> for Snippet {
Expand All @@ -210,7 +218,15 @@ impl TryFrom<&SimilarityResult> for Snippet {
.get("snippet")
.ok_or_else(|| Error::MalformattedEmbeddingMetadata("snippet".to_owned()))?
.inner_string()?;
Ok(Snippet { file_url, code })
let start_line = meta
.get("start_line_no")
.ok_or_else(|| Error::MalformattedEmbeddingMetadata("snippet".to_owned()))?
.inner_value()?;
let end_line= meta
.get("start_line_no")
.ok_or_else(|| Error::MalformattedEmbeddingMetadata("snippet".to_owned()))?
.inner_value()?;
Ok(Snippet { file_url, code, start_line, end_line })
}
}

Expand Down Expand Up @@ -280,6 +296,7 @@ impl SnippetRetriever {
workspace_root: &str,
) -> Result<()> {
debug!("building workspace snippets");
let start = Instant::now();
let workspace_root = PathBuf::from(workspace_root);
if self.db.is_none() {
self.initialise_database(&format!(
Expand Down Expand Up @@ -360,7 +377,7 @@ impl SnippetRetriever {
})
.await;
}

debug!("Built workspace snippets in {} ms", start.elapsed().as_millis());
Ok(())
}

Expand All @@ -384,15 +401,15 @@ impl SnippetRetriever {
snippet: String,
strategy: BuildFrom,
) -> Result<Vec<f32>> {
match strategy {
let result = match strategy {
BuildFrom::Start => {
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
encoding.truncate(
self.model_config.max_input_size,
1,
TruncationDirection::Right,
);
self.generate_embedding(encoding, self.model.clone()).await
self.generate_embeddings(vec![encoding], self.model.clone()).await?
}
BuildFrom::Cursor { cursor_position } => {
let (before, after) = snippet.split_at(cursor_position);
Expand All @@ -404,8 +421,8 @@ impl SnippetRetriever {
before_encoding.take_overflowing();
after_encoding.take_overflowing();
before_encoding.merge_with(after_encoding, false);
self.generate_embedding(before_encoding, self.model.clone())
.await
self.generate_embeddings(vec![before_encoding], self.model.clone())
.await?
}
BuildFrom::End => {
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
Expand All @@ -414,9 +431,14 @@ impl SnippetRetriever {
1,
TruncationDirection::Left,
);
self.generate_embedding(encoding, self.model.clone()).await
self.generate_embeddings(vec![encoding], self.model.clone()).await?
}
}
};
let first_embedding = match result.first() {
Some(n) => n.clone(),
_ => vec![]
};
McPatate marked this conversation as resolved.
Show resolved Hide resolved
Ok(first_embedding)
}

pub(crate) async fn search(
Expand All @@ -428,7 +450,7 @@ impl SnippetRetriever {
Some(db) => db.clone(),
None => return Err(Error::UninitialisedDatabase),
};
let col = db.get_collection(&self.collection_name).await?;
let col = db.get_collection("code-slices").await?;
McPatate marked this conversation as resolved.
Show resolved Hide resolved
let result = col
.read()
.await
Expand Down Expand Up @@ -477,21 +499,26 @@ impl SnippetRetriever {

impl SnippetRetriever {
// TODO: handle overflowing in Encoding
async fn generate_embedding(
/// Embedding order is preserved and stays the same as encoding input
async fn generate_embeddings(
McPatate marked this conversation as resolved.
Show resolved Hide resolved
&self,
encoding: Encoding,
encodings: Vec<Encoding>,
model: Arc<BertModel>,
) -> Result<Vec<f32>> {
) -> Result<Vec<Vec<f32>>> {
let start = Instant::now();
let embedding = spawn_blocking(move || -> Result<Vec<f32>> {
let tokens = encoding.get_ids().to_vec();
let token_ids = Tensor::new(&tokens[..], &model.device)?.unsqueeze(0)?;
let embedding = spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
let tokens = encodings
.iter()
.map(|elem| {
Ok(Tensor::new(elem.get_ids().to_vec(), &model.device)?)
} )
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&tokens, 0)?;
let token_type_ids = token_ids.zeros_like()?;
let embedding = model.forward(&token_ids, &token_type_ids)?;
let (_n_sentence, n_tokens, _hidden_size) = embedding.dims3()?;
let embedding = (embedding.sum(1)? / (n_tokens as f64))?;
let embedding = embedding.get(0)?.to_vec1::<f32>()?;
Ok(embedding)
Ok(embedding.to_vec2::<f32>()?)
})
.await?;
debug!("embedding generated in {} ms", start.elapsed().as_millis());
Expand All @@ -512,6 +539,8 @@ impl SnippetRetriever {
let file = tokio::fs::read_to_string(&file_url).await?;
let lines = file.split('\n').collect::<Vec<_>>();
let end = end.unwrap_or(lines.len()).min(lines.len());
let mut snippets: Vec<Snippet> = Vec::new();
debug!("Building embeddings for {file_url}");
for start_line in (start..end).step_by(self.window_step) {
let end_line = (start_line + self.window_size - 1).min(lines.len());
if !col
Expand All @@ -538,35 +567,60 @@ impl SnippetRetriever {
let window = lines[start_line..end_line].to_vec();
let snippet = window.join("\n");
if snippet.is_empty() {
continue;
debug!("snippet {file_url}[{start_line}, {end_line}] empty");
continue;
}
snippets.push(Snippet{ file_url: file_url.clone().into(), code: snippet, start_line, end_line });
}
{
let nb_snippets = snippets.len();
let steps = self.window_step;
debug!("Build {nb_snippets} snippets for {file_url}: {start}, {end}, {steps}");
}

let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
encoding.truncate(
self.model_config.max_input_size,
1,
TruncationDirection::Right,
);
let result = self.generate_embedding(encoding, self.model.clone()).await;
let embedding = match result {
Ok(e) => e,
Err(err) => {
error!(
"error generating embedding for {file_url}[{start_line}, {end_line}]: {err}",
);
continue;
}
};
col.write().await.insert(Embedding::new(
embedding,
Some(HashMap::from([
("file_url".to_owned(), file_url.clone().into()),
("start_line_no".to_owned(), start_line.into()),
("end_line_no".to_owned(), end_line.into()),
("snippet".to_owned(), snippet.clone().into()),
])),
))?;
// Group by length to reduce padding effect
let snippets = spawn_blocking(|| -> Result<Vec<Snippet>> {
snippets.sort_unstable_by(|first, second| first.code.len().cmp(&second.code.len()));
Ok(snippets)
}).await?;

// TODO: improvements to compute an efficient batch size:
// - batch size should be relative to the cumulative size of all elements in the batch,
// Set embedding_batch_size to 8 if device is GPU, use match
let embedding_batch_size = match self.model.device {
Device::Cpu => 2,
_ => 8,
};
for batch in snippets?.chunks(embedding_batch_size) {
let batch_code = batch
.iter()
.map(|snippet| snippet.code.clone())
.collect();
let encodings = self.tokenizer
.encode_batch(batch_code, true)?
.iter_mut()
.map(|encoding| {
encoding.truncate(512, 1, TruncationDirection::Right);
encoding.clone()
})
.collect();
let results = self.generate_embeddings(encodings, self.model.clone()).await?;
col.write().await.batch_insert(zip(results, batch).map(|item| {
Wats0ns marked this conversation as resolved.
Show resolved Hide resolved
let (embedding, snippet) = item;
Embedding::new(
embedding,
Some(HashMap::from([
("file_url".to_owned(), snippet.file_url.clone().into()),
("start_line_no".to_owned(), snippet.start_line.into()),
("end_line_no".to_owned(), snippet.end_line.into()),
("snippet".to_owned(), snippet.code.clone().into()),
])))
}).collect::<Vec<Embedding>>()
)?;
}
db
.save()
.await?;
Ok(())
}
}
Expand Down
17 changes: 17 additions & 0 deletions crates/tinyvec-embed/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,14 @@ impl Collection {
Ok(())
}

/// Insert multiple embeddings into the collection in one call.
///
/// Validates every embedding's vector length against the collection's
/// configured dimension before touching storage; on mismatch nothing is
/// inserted and `CollectionError::DimensionMismatch` is returned.
pub fn batch_insert(&mut self, embeddings: Vec<Embedding>) -> Result<()> {
    let dimensions_ok = embeddings
        .iter()
        .all(|embedding| embedding.vector.len() == self.dimension);
    if !dimensions_ok {
        return Err(CollectionError::DimensionMismatch.into());
    }
    self.embeddings.extend(embeddings);
    Ok(())
}

/// Remove values matching filter.
///
/// Empties the collection when `filter` is `None`.
Expand Down Expand Up @@ -224,13 +232,15 @@ impl Eq for Embedding {}
/// Metadata value attached to an [`Embedding`].
pub enum Value {
    /// Textual payload (e.g. a file URL or a code snippet).
    String(String),
    /// Floating-point numeric payload.
    Number(f32),
    /// Unsigned integer payload (e.g. line numbers).
    Usize(usize),
}

impl Display for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::String(s) => write!(f, "{s}"),
Self::Number(n) => write!(f, "{n}"),
Self::Usize(u) => write!(f, "{u}"),
}
}
}
Expand All @@ -242,6 +252,13 @@ impl Value {
_ => Err(Error::ValueNotString(self.to_owned())),
}
}

pub fn inner_value(&self) -> Result<usize> {
match self {
Self::Usize(s) => Ok(s.to_owned()),
_ => Err(Error::ValueNotString(self.to_owned())),
}
}
McPatate marked this conversation as resolved.
Show resolved Hide resolved
}

impl From<usize> for Value {
Expand Down
2 changes: 2 additions & 0 deletions crates/tinyvec-embed/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ pub enum Error {
InvalidFileName,
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("expected value to be a valid number, got: {0}")]
ValueNotNumber(Value),
#[error("expected value to be string, got: {0}")]
ValueNotString(Value),
}
Expand Down