From 38f945cda247a7d91d7ee4e76e8ed27747fac619 Mon Sep 17 00:00:00 2001
From: Akshay Ballal
Date: Thu, 17 Oct 2024 23:51:25 +0200
Subject: [PATCH] add sparse embeddings

---
 Cargo.lock                                    |   1 -
 Dockerfile                                    |  19 +-
 examples/splade.py                            |  49 ++++
 python/Cargo.toml                             |   3 +-
 .../python/embed_anything/_embed_anything.pyi |   1 +
 python/src/lib.rs                             |  43 ++-
 python/src/models/colpali.rs                  |   6 +-
 python/src/models/mod.rs                      |   2 +-
 rust/Cargo.toml                               |   3 +-
 rust/examples/bert_ort.rs                     |  16 +-
 rust/examples/cloud.rs                        |  12 +-
 rust/examples/colpali.rs                      |   4 +-
 rust/examples/splade.rs                       |  77 ++++++
 rust/src/chunkers/cumulative.rs               |   6 +-
 rust/src/chunkers/statistical.rs              |  11 +-
 rust/src/embeddings/cloud/cohere.rs           |   2 +-
 rust/src/embeddings/cloud/openai.rs           |   2 +-
 rust/src/embeddings/embed.rs                  |  76 ++---
 rust/src/embeddings/local/bert.rs             | 261 ++++++++++++------
 rust/src/embeddings/local/clip.rs             |  12 +-
 rust/src/embeddings/local/colpali.rs          |  53 ++--
 rust/src/embeddings/local/jina.rs             |  20 +-
 rust/src/embeddings/local/mod.rs              |   2 +-
 rust/src/embeddings/local/text_embedding.rs   |   3 -
 rust/src/embeddings/mod.rs                    |   1 +
 rust/src/embeddings/utils.rs                  |  41 +++
 rust/src/file_processor/website_processor.rs  |  13 +-
 rust/src/lib.rs                               |  28 +-
 rust/src/models/bert.rs                       |  97 +++++++
 rust/src/text_loader.rs                       |  12 +-
 tests/model_tests/test_openai.py              |   2 +-
 31 files changed, 619 insertions(+), 259 deletions(-)
 create mode 100644 examples/splade.py
 create mode 100644 rust/examples/splade.rs
 create mode 100644 rust/src/embeddings/utils.rs

diff --git a/Cargo.lock b/Cargo.lock
index 5a61fb0..42514a6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -812,7 +812,6 @@ dependencies = [
  "candle-nn",
  "candle-transformers",
  "chrono",
- "cudarc",
  "futures",
  "hf-hub",
  "image",
diff --git a/Dockerfile b/Dockerfile
index 603d848..f66abae 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,11 +11,14 @@ COPY --from=planner /app/recipe.json recipe.json
 RUN cargo chef cook --release --recipe-path recipe.json
 
 # Build application
-RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null \
-    && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list \
-    && apt-get update \
+# Download Intel GPG key and add repository
+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null \
+    && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
+
+# Install Intel MKL and extract libiomp5.so
+RUN apt-get update \
     && apt-get install -y intel-oneapi-mkl-devel \
-    && export LD_LIBRARY_PATH="/opt/intel/oneapi/compiler/2024.2/lib:$LD_LIBRARY_PATH"
+    && cp /opt/intel/oneapi/compiler/2024.2/lib/libiomp5.so /app/libiomp5.so
 
 RUN apt-get install libssl-dev pkg-config python3-full python3-pip -y
 RUN pip3 install maturin[patchelf] --break-system-packages
@@ -27,11 +30,15 @@ FROM python:3.11-slim
 
 WORKDIR /app
 
-COPY --from=builder /app/target/wheels .
+# Copy the extracted libiomp5.so from the builder stage
+COPY --from=builder /app/libiomp5.so /usr/lib/
+
+# Set the library path
+ENV LD_LIBRARY_PATH="/usr/lib:$LD_LIBRARY_PATH"
 
 COPY . .
 
-RUN pip install *.whl
+RUN pip install target/wheels/*.whl
 
 RUN pip install numpy pillow pytest
diff --git a/examples/splade.py b/examples/splade.py
new file mode 100644
index 0000000..19fa871
--- /dev/null
+++ b/examples/splade.py
@@ -0,0 +1,49 @@
+import embed_anything
+from embed_anything import EmbedData, EmbeddingModel, WhichModel, embed_query
+from embed_anything.vectordb import Adapter
+import os
+from time import time
+import numpy as np
+import heapq
+
+
+model = EmbeddingModel.from_pretrained_hf(
+    WhichModel.SparseBert, "prithivida/Splade_PP_en_v1"
+)
+
+sentences = [
+    "The cat sits outside",
+    "A man is playing guitar",
+    "I love pasta",
+    "The new movie is awesome",
+    "The cat plays in the garden",
+    "A woman watches TV",
+    "The new movie is so great",
+    "Do you like pizza?",
+]
+
+embeddings = embed_query(sentences, embeder=model)
+
+embed_vector = np.array([e.embedding for e in embeddings])
+
+similarities = np.matmul(embed_vector, embed_vector.T)
+
+# get top 5 similarities and show the two sentences and their similarity scores
+# Flatten the upper triangle of the similarity matrix, excluding the diagonal
+similarity_scores = [
+    (similarities[i, j], i, j)
+    for i in range(len(sentences))
+    for j in range(i + 1, len(sentences))
+]
+
+# Get the top 5 similarity scores
+top_5_similarities = heapq.nlargest(5, similarity_scores, key=lambda x: x[0])
+
+# Print the top 5 similarities with sentences
+for score, i, j in top_5_similarities:
+    print(f"Score: {score:.2} | {sentences[i]} | {sentences[j]}")
+
+
+
+
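A note on the example above: SPLADE vectors are vocabulary-sized, so each dimension maps to one wordpiece in the model's vocabulary. A quick sanity check is to decode the strongest activations back into tokens. The snippet below is a sketch, not part of the patch; it assumes the `transformers` package is installed and that the embedding dimension equals the tokenizer's vocabulary size (which holds for prithivida/Splade_PP_en_v1):

    # Sketch (hypothetical helper, not in this patch): decode the top SPLADE
    # activations of the first sentence back into wordpiece tokens.
    from transformers import AutoTokenizer
    import numpy as np

    tok = AutoTokenizer.from_pretrained("prithivida/Splade_PP_en_v1")
    vec = embed_vector[0]             # vector for "The cat sits outside"
    top = np.argsort(vec)[::-1][:10]  # indices of the 10 strongest dimensions
    for i in top:
        if vec[i] > 0:
            print(tok.convert_ids_to_tokens([int(i)])[0], round(float(vec[i]), 3))

Terms like "cat", "sits", and "outside" should dominate, with a tail of semantically related expansion tokens; that expansion is what makes SPLADE retrieval work.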
diff --git a/python/Cargo.toml b/python/Cargo.toml
index f5d252d..a13d4bc 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -8,7 +8,7 @@ name = "_embed_anything"
 crate-type = ["cdylib"]
 
 [dependencies]
-embed_anything = {path = "../rust"}
+embed_anything = {path = "../rust", features = ["load-dynamic"]}
 pyo3 = { version = "0.22.3"}
 tokio = { version = "1.39.0", features = ["rt-multi-thread"]}
 
@@ -18,5 +18,4 @@ mkl = ["embed_anything/mkl"]
 accelerate = ["embed_anything/accelerate"]
 cuda = ["embed_anything/cuda"]
 cudnn = ["embed_anything/cudnn"]
-load-dynamic = ["embed_anything/load-dynamic"]
diff --git a/python/python/embed_anything/_embed_anything.pyi b/python/python/embed_anything/_embed_anything.pyi
index 3d6a3a8..e7712ff 100644
--- a/python/python/embed_anything/_embed_anything.pyi
+++ b/python/python/embed_anything/_embed_anything.pyi
@@ -471,3 +471,4 @@ class WhichModel(Enum):
     Jina = ("Jina",)
     Clip = ("Clip",)
     Colpali = ("Colpali",)
+    SparseBert = ("SparseBert",)
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 8f4fdf5..bbf308c 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -6,7 +6,7 @@ use embed_anything::{
     self,
     config::TextEmbedConfig,
     emb_audio,
-    embeddings::embed::{EmbeddingResult, Embedder},
+    embeddings::embed::{Embedder, EmbeddingResult},
     file_processor::audio::audio_processor,
     text_loader::FileLoadingError,
 };
@@ -36,8 +36,8 @@ impl EmbedData {
         Python::with_gil(|py| {
             let embedding = self.inner.embedding.clone();
             match embedding {
-                EmbeddingResult::Dense(x) => PyList::new_bound(py, x).into(),
-                EmbeddingResult::Sparse(x) => {
+                EmbeddingResult::DenseVector(x) => PyList::new_bound(py, x).into(),
+                EmbeddingResult::MultiVector(x) => {
                     PyList::new_bound(py, x.iter().map(|inner| PyList::new_bound(py, inner))).into()
                 }
             }
@@ -84,6 +84,7 @@ pub enum WhichModel {
     OpenAI,
     Cohere,
     Bert,
+    SparseBert,
     Clip,
     Jina,
     Colpali,
@@ -99,6 +100,7 @@ impl From<&str> for WhichModel {
             "openai" | "OpenAI" => WhichModel::OpenAI,
             "cohere" | "Cohere" => WhichModel::Cohere,
             "bert" | "Bert" => WhichModel::Bert,
+            "sparse-bert" | "SparseBert" => WhichModel::SparseBert,
             "clip" | "Clip" => WhichModel::Clip,
             "jina" | "Jina" => WhichModel::Jina,
             "colpali" | "Colpali" => WhichModel::Colpali,
@@ -113,6 +115,7 @@ impl From<String> for WhichModel {
             "openai" | "OpenAI" => WhichModel::OpenAI,
             "cohere" | "Cohere" => WhichModel::Cohere,
             "bert" | "Bert" => WhichModel::Bert,
+            "sparse-bert" | "SparseBert" => WhichModel::SparseBert,
             "clip" | "Clip" => WhichModel::Clip,
             "jina" | "Jina" => WhichModel::Jina,
             "colpali" | "Colpali" => WhichModel::Colpali,
@@ -150,12 +153,25 @@ impl EmbeddingModel {
                     inner: Arc::new(model),
                 })
             }
+            WhichModel::SparseBert => {
+                let model_id = model_id.unwrap_or("prithivida/Splade_PP_en_v1");
+                let model = Embedder::Text(TextEmbedder::Bert(Box::new(
+                    embed_anything::embeddings::local::bert::SparseBertEmbedder::new(
+                        model_id.to_string(),
+                        revision.map(|s| s.to_string()),
+                    )
+                    .unwrap(),
+                )));
+                Ok(EmbeddingModel {
+                    inner: Arc::new(model),
+                })
+            }
             WhichModel::Clip => {
                 let model_id = model_id.unwrap_or("openai/clip-vit-base-patch32");
                 let model = Embedder::Vision(VisionEmbedder::Clip(
                     embed_anything::embeddings::local::clip::ClipEmbedder::new(
                         model_id.to_string(),
-                        revision.map(|s| s.to_string()),
+                        revision,
                     )
                     .map_err(|e| PyValueError::new_err(e.to_string()))?,
                 ));
@@ -166,11 +182,8 @@ impl EmbeddingModel {
             WhichModel::Jina => {
                 let model_id = model_id.unwrap_or("jinaai/jina-embeddings-v2-small-en");
                 let model = Embedder::Text(TextEmbedder::Jina(
-                    embed_anything::embeddings::local::jina::JinaEmbedder::new(
-                        model_id.to_string(),
-                        revision.map(|s| s.to_string()),
-                    )
-                    .unwrap(),
+                    embed_anything::embeddings::local::jina::JinaEmbedder::new(model_id, revision)
+                        .unwrap(),
                 ));
                 Ok(EmbeddingModel {
                     inner: Arc::new(model),
@@ -178,9 +191,9 @@ impl EmbeddingModel {
             }
             WhichModel::Colpali => {
                 let model_id = model_id.unwrap_or("vidore/colpali-v1.2-merged");
-                let model = Embedder::Vision(VisionEmbedder::ColPali(embed_anything::embeddings::local::colpali::ColPaliEmbedder::new(
-                    model_id,
-                    revision.map(|s| s),
+                let model = Embedder::Vision(VisionEmbedder::ColPali(
+                    embed_anything::embeddings::local::colpali::ColPaliEmbedder::new(
+                        model_id, revision,
                     )
                     .unwrap(),
                 ));
@@ -188,7 +201,7 @@ impl EmbeddingModel {
                     inner: Arc::new(model),
                 })
             }
-
+            _ => panic!("Invalid model"),
         }
     }
@@ -358,7 +371,7 @@ pub fn embed_file(
 
     let embeddings = rt
         .block_on(async {
-            embed_anything::embed_file(file_name, &embedding_model, config, adapter).await
+            embed_anything::embed_file(file_name, embedding_model, config, adapter).await
         })
         .map_err(|e| match e.downcast_ref::<FileLoadingError>() {
             Some(FileLoadingError::FileNotFound(file)) => {
@@ -390,7 +403,7 @@ pub fn embed_audio_file(
     let audio_decoder = &mut audio_decoder.inner;
     let rt = Builder::new_multi_thread().enable_all().build().unwrap();
     let data = rt.block_on(async {
-        emb_audio(audio_file, audio_decoder, &embedding_model, config)
+        emb_audio(audio_file, audio_decoder, embedding_model, config)
             .await
             .map_err(|e| PyValueError::new_err(e.to_string()))
             .unwrap()
diff --git a/python/src/models/colpali.rs b/python/src/models/colpali.rs
index 9519d80..28bd126 100644
--- a/python/src/models/colpali.rs
+++ b/python/src/models/colpali.rs
@@ -11,7 +11,6 @@ pub struct ColpaliModel {
 
 #[pymethods]
 impl ColpaliModel {
-
     #[new]
     #[pyo3(signature = (model_id, revision=None))]
     pub fn new(model_id: &str, revision: Option<&str>) -> PyResult<Self> {
@@ -40,7 +39,10 @@ impl ColpaliModel {
     }
 
     pub fn embed_query(&self, query: &str) -> PyResult<Vec<EmbedData>> {
-        let embed_data = self.model.embed_query(query).map_err(|e| PyValueError::new_err(e.to_string()))?;
+        let embed_data = self
+            .model
+            .embed_query(query)
+            .map_err(|e| PyValueError::new_err(e.to_string()))?;
         Ok(embed_data
             .into_iter()
             .map(|data| EmbedData { inner: data })
diff --git a/python/src/models/mod.rs b/python/src/models/mod.rs
index d5f34c6..b3fdcbd 100644
--- a/python/src/models/mod.rs
+++ b/python/src/models/mod.rs
@@ -1 +1 @@
-pub mod colpali;
\ No newline at end of file
+pub mod colpali;
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 883d3b2..bfce747 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -69,7 +69,6 @@ byteorder = "1.5.0"
 futures = "0.3.30"
 pdf-extract = {workspace = true}
 
-cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 ort = {workspace = true}
 ndarray = "0.16.1"
 ndarray-linalg = {version = "0.16.0"}
@@ -93,5 +92,5 @@ default = []
 mkl = ["dep:intel-mkl-src", "candle-nn/mkl", "candle-transformers/mkl", "candle-core/mkl"]
 accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
 cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-core/cuda"]
-cudnn = ["candle-core/cuda", "cudarc/cudnn"]
+cudnn = ["candle-core/cudnn"]
 load-dynamic = ["ort/load-dynamic"]
diff --git a/rust/examples/bert_ort.rs b/rust/examples/bert_ort.rs
index 83fbdc3..313ff12 100644
--- a/rust/examples/bert_ort.rs
+++ b/rust/examples/bert_ort.rs
@@ -11,7 +11,6 @@ use std::time::Instant;
 #[tokio::main]
 async fn main() -> Result<(), anyhow::Error> {
     let model =
-
         Arc::new(Embedder::from_pretrained_onnx("bert", ONNXModel::AllMiniLML6V2, None).unwrap());
 
     let config = TextEmbedConfig::new(
         Some(1000),
@@ -29,17 +28,13 @@ async fn main() -> Result<(), anyhow::Error> {
 
     let now = Instant::now();
 
-
     let futures = files
         .par_iter()
-        .map(|file| {
-            embed_file(file, &model, Some(&config), None::<fn(Vec<EmbedData>)>)
-        })
+        .map(|file| embed_file(file, &model, Some(&config), None::<fn(Vec<EmbedData>)>))
         .collect::<Vec<_>>();
 
     let _data = futures.into_iter().next().unwrap().await;
 
-
     let elapsed_time = now.elapsed();
     println!("Elapsed Time: {}", elapsed_time.as_secs_f32());
 
@@ -52,7 +47,6 @@ async fn main() -> Result<(), anyhow::Error> {
         "The dog is sitting in the park",
         "The window is broken",
         "pizza is the best",
-
     ]
     .iter()
     .map(|s| s.to_string())
@@ -71,7 +65,10 @@ async fn main() -> Result<(), anyhow::Error> {
             .map(|x| x.to_dense().unwrap())
             .flatten()
            .collect::<Vec<f32>>(),
-        (n_vectors, doc_embeddings[0].embedding.to_dense().unwrap().len()),
+        (
+            n_vectors,
+            doc_embeddings[0].embedding.to_dense().unwrap().len(),
+        ),
         &Device::Cpu,
     )
     .unwrap();
@@ -94,5 +91,4 @@ async fn main() -> Result<(), anyhow::Error> {
     }
 
     Ok(())
-
-}
\ No newline at end of file
+}
diff --git a/rust/examples/cloud.rs b/rust/examples/cloud.rs
index f0aba49..ec6d13d 100644
--- a/rust/examples/cloud.rs
+++ b/rust/examples/cloud.rs
@@ -4,22 +4,13 @@ use embed_anything::{
     config::TextEmbedConfig,
     embed_directory_stream, embed_file,
     embeddings::embed::{EmbedData, Embedder},
-    text_loader::SplittingStrategy,
 };
 
 use anyhow::Result;
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    let semantic_encoder =
-        Embedder::from_pretrained_cloud("openai", "text-embedding-3-small", None).unwrap();
-    let text_embed_config = TextEmbedConfig::new(
-        Some(1000),
-        Some(256),
-        Some(32),
-        None,
-        None,
-    );
+    let text_embed_config = TextEmbedConfig::new(Some(1000), Some(512), Some(512), None, None);
 
     let cohere_model =
         Embedder::from_pretrained_cloud("cohere", "embed-english-v3.0", None).unwrap();
     let openai_model =
@@ -44,7 +35,6 @@ async fn main() -> Result<()> {
         .await?
         .unwrap();
 
-
     let _cohere_embedding = embed_file(
         "test_files/attention.pdf",
         &cohere_model,
diff --git a/rust/examples/colpali.rs b/rust/examples/colpali.rs
index f2d54f4..e674198 100644
--- a/rust/examples/colpali.rs
+++ b/rust/examples/colpali.rs
@@ -1,8 +1,7 @@
 use embed_anything::embeddings::local::colpali::ColPaliEmbedder;
 
 fn main() -> Result<(), anyhow::Error> {
-
-    let colpali_model = ColPaliEmbedder::new( "vidore/colpali-v1.2-merged", None)?;
+    let colpali_model = ColPaliEmbedder::new("vidore/colpali-v1.2-merged", None)?;
     let file_path = "test_files/attention.pdf";
     let batch_size = 1;
     let embed_data = colpali_model.embed_file(file_path, batch_size)?;
@@ -12,5 +11,4 @@ fn main() -> Result<(), anyhow::Error> {
     let query_embeddings = colpali_model.embed_query(prompt)?;
     println!("{:?}", query_embeddings.len());
     Ok(())
-
 }
diff --git a/rust/examples/splade.rs b/rust/examples/splade.rs
new file mode 100644
index 0000000..1e9257a
--- /dev/null
+++ b/rust/examples/splade.rs
@@ -0,0 +1,77 @@
+use candle_core::{Device, Tensor};
+use embed_anything::{
+    config::TextEmbedConfig,
+    embed_query,
+    embeddings::embed::{Embedder, TextEmbedder},
+    text_loader::SplittingStrategy,
+};
+use std::sync::Arc;
+
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    let model = Arc::new(Embedder::Text(
+        TextEmbedder::from_pretrained_hf("sparse-bert", "prithivida/Splade_PP_en_v1", None)
+            .unwrap(),
+    ));
+
+    let config = TextEmbedConfig::new(
+        Some(256),
+        Some(32),
+        Some(32),
+        Some(SplittingStrategy::Sentence),
+        Some(model.clone()),
+    );
+
+    let sentences = [
+        "The cat sits outside",
+        "A man is playing guitar",
+        "I love pasta",
+        "The new movie is awesome",
+        "The cat plays in the garden",
+        "A woman watches TV",
+        "The new movie is so great",
+        "Do you like pizza?",
+    ]
+    .iter()
+    .map(|x| x.to_string())
+    .collect::<Vec<String>>();
+
+    let n_sentences = sentences.len();
+
+    let out = embed_query(sentences.clone(), &model, Some(&config))
+        .await
+        .unwrap();
+
+    let embeddings = out
+        .iter()
+        .map(|embed| embed.embedding.to_dense().unwrap())
+        .flatten()
+        .collect::<Vec<f32>>();
+
+    let embeddings_tensor = Tensor::from_vec(
+        embeddings.clone(),
+        (n_sentences, out[0].embedding.to_dense().unwrap().len()),
+        &Device::Cpu,
+    )
+    .unwrap();
+
+    let mut similarities = vec![];
+    for i in 0..n_sentences {
+        let e_i = embeddings_tensor.get(i).unwrap();
+        for j in (i + 1)..n_sentences {
+            let e_j = embeddings_tensor.get(j).unwrap();
+            let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+            let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
+            let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+            let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+            similarities.push((cosine_similarity, i, j))
+        }
+    }
+    println!("similarities: {:?}", similarities);
+    similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+    for &(score, i, j) in similarities[..5].iter() {
+        println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+    }
+
+    Ok(())
+}
diff --git a/rust/src/chunkers/cumulative.rs b/rust/src/chunkers/cumulative.rs
index cfc3d15..7856f70 100644
--- a/rust/src/chunkers/cumulative.rs
+++ b/rust/src/chunkers/cumulative.rs
@@ -148,8 +148,8 @@ impl CumulativeChunker {
 mod tests {
     use super::*;
 
-    #[test]
-    fn test_cumulative_chunker() {
+    #[tokio::test]
+    async fn test_cumulative_chunker() {
         let text = "
         The Bank of Elarian
 Nestled in the heart of the bustling city of Elarian, the Bank of Elarian stands as a beacon of financial stability and innovative banking. Founded over a century ago, this bank has grown from a modest community institution into one of the city's most trusted financial centers. Known for its majestic architecture, the building's facade is a blend of classical and modern design, featuring high pillars and sleek glass panels that reflect the city's skyline.
@@ -165,6 +165,6 @@ Elarian Freiseur also places a high emphasis on using eco-friendly and sustainab
         ";
 
         let chunker = CumulativeChunker::default();
-        chunker._chunk(text);
+        chunker._chunk(text).await;
     }
 }
diff --git a/rust/src/chunkers/statistical.rs b/rust/src/chunkers/statistical.rs
index 79443a6..901b6bc 100644
--- a/rust/src/chunkers/statistical.rs
+++ b/rust/src/chunkers/statistical.rs
@@ -1,6 +1,9 @@
 use std::{cmp::max, sync::Arc};
 
-use crate::embeddings::{embed::{Embedder, TextEmbedder}, local::jina::JinaEmbedder};
+use crate::embeddings::{
+    embed::{Embedder, TextEmbedder},
+    local::jina::JinaEmbedder,
+};
 use candle_core::Tensor;
 use itertools::{enumerate, Itertools};
 // use text_splitter::{ChunkConfig, TextSplitter};
@@ -128,11 +131,7 @@ impl StatisticalChunker {
                     .collect::<Vec<_>>();
             }
 
-            let encoded_splits = self
-                .encoder
-                .embed(&batch_splits, Some(16))
-                .await
-                .unwrap();
+            let encoded_splits = self.encoder.embed(&batch_splits, Some(16)).await.unwrap();
 
             let encoded_splits = encoded_splits
                 .into_iter()
                 .map(|x| x.to_dense().unwrap())
diff --git a/rust/src/embeddings/cloud/cohere.rs b/rust/src/embeddings/cloud/cohere.rs
index cb739be..def8f4a 100644
--- a/rust/src/embeddings/cloud/cohere.rs
+++ b/rust/src/embeddings/cloud/cohere.rs
@@ -77,7 +77,7 @@ impl CohereEmbedder {
 
         let encodings = encodings
             .iter()
-            .map(|embedding| EmbeddingResult::Dense(embedding.clone()))
+            .map(|embedding| EmbeddingResult::DenseVector(embedding.clone()))
            .collect::<Vec<_>>();
 
         Ok(encodings)
diff --git a/rust/src/embeddings/cloud/openai.rs b/rust/src/embeddings/cloud/openai.rs
index 92462e6..fca2d39 100644
--- a/rust/src/embeddings/cloud/openai.rs
+++ b/rust/src/embeddings/cloud/openai.rs
@@ -74,7 +74,7 @@ impl OpenAIEmbedder {
         let encodings = data
             .data
             .iter()
-            .map(|data| EmbeddingResult::Dense(data.embedding.clone()))
+            .map(|data| EmbeddingResult::DenseVector(data.embedding.clone()))
            .collect::<Vec<_>>();
 
         Ok(encodings)
diff --git a/rust/src/embeddings/embed.rs b/rust/src/embeddings/embed.rs
index 1fdc386..800210f 100644
--- a/rust/src/embeddings/embed.rs
+++ b/rust/src/embeddings/embed.rs
@@ -2,38 +2,38 @@ use crate::file_processor::audio::audio_processor::Segment;
 
 use super::cloud::cohere::CohereEmbedder;
 use super::cloud::openai::OpenAIEmbedder;
-use super::local::bert::{BertEmbed, BertEmbedder, OrtBertEmbedder};
+use super::local::bert::{BertEmbed, BertEmbedder, OrtBertEmbedder, SparseBertEmbedder};
 use super::local::clip::ClipEmbedder;
 use super::local::colpali::ColPaliEmbedder;
 use super::local::jina::JinaEmbedder;
 use super::local::text_embedding::ONNXModel;
+use anyhow::anyhow;
 use serde::Deserialize;
 use std::collections::HashMap;
-use anyhow::anyhow;
 
 #[derive(Deserialize, Debug, Clone)]
 pub enum EmbeddingResult {
-    Dense(Vec<f32>),
-    Sparse(Vec<Vec<f32>>),
+    DenseVector(Vec<f32>),
+    MultiVector(Vec<Vec<f32>>),
 }
 
 impl From<Vec<f32>> for EmbeddingResult {
     fn from(value: Vec<f32>) -> Self {
-        EmbeddingResult::Dense(value)
+        EmbeddingResult::DenseVector(value)
     }
 }
 
 impl From<Vec<Vec<f32>>> for EmbeddingResult {
     fn from(value: Vec<Vec<f32>>) -> Self {
-        EmbeddingResult::Sparse(value)
+        EmbeddingResult::MultiVector(value)
     }
 }
 
 impl EmbeddingResult {
     pub fn to_dense(&self) -> Result<Vec<f32>, anyhow::Error> {
         match self {
-            EmbeddingResult::Dense(x) => Ok(x.to_vec()),
-            EmbeddingResult::Sparse(_) => Err(anyhow!(
+            EmbeddingResult::DenseVector(x) => Ok(x.to_vec()),
+            EmbeddingResult::MultiVector(_) => Err(anyhow!(
                 "Sparse Embedding are not supported for this operation"
             )),
         }
@@ -41,8 +41,8 @@ impl EmbeddingResult {
 
     pub fn to_sparse(&self) -> Result<Vec<Vec<f32>>, anyhow::Error> {
         match self {
-            EmbeddingResult::Sparse(x) => Ok(x.to_vec()),
-            EmbeddingResult::Dense(_) => Err(anyhow!(
+            EmbeddingResult::MultiVector(x) => Ok(x.to_vec()),
+            EmbeddingResult::DenseVector(_) => Err(anyhow!(
                 "Dense Embedding are not supported for this operation"
             )),
         }
@@ -111,15 +111,15 @@ impl TextEmbedder {
         revision: Option<&str>,
     ) -> Result<Self, anyhow::Error> {
         match model {
-            "jina" | "Jina" => Ok(Self::Jina(JinaEmbedder::new(
-                model_id.to_string(),
-                revision.map(|s| s.to_string()),
-            )?)),
+            "jina" | "Jina" => Ok(Self::Jina(JinaEmbedder::new(model_id, revision)?)),
 
             "Bert" | "bert" => Ok(Self::Bert(Box::new(BertEmbedder::new(
                 model_id.to_string(),
                 revision.map(|s| s.to_string()),
             )?))),
+            "sparse-bert" | "SparseBert" | "SPARSE-BERT" => Ok(Self::Bert(Box::new(
+                SparseBertEmbedder::new(model_id.to_string(), revision.map(|s| s.to_string()))?,
+            ))),
             _ => Err(anyhow::anyhow!("Model not supported")),
         }
     }
@@ -175,7 +175,6 @@ impl TextEmbedder {
     }
 }
 
-
 pub enum VisionEmbedder {
     Clip(ClipEmbedder),
     ColPali(ColPaliEmbedder),
@@ -187,7 +186,7 @@ impl From<VisionEmbedder> for Embedder {
     }
 }
 
-impl From<Embedder> for VisionEmbedder{
+impl From<Embedder> for VisionEmbedder {
     fn from(value: Embedder) -> Self {
         match value {
             Embedder::Vision(value) => value,
@@ -196,7 +195,7 @@ impl From<Embedder> for VisionEmbedder {
     }
 }
 
-impl From<Embedder> for TextEmbedder{
+impl From<Embedder> for TextEmbedder {
     fn from(value: Embedder) -> Self {
         match value {
             Embedder::Text(value) => value,
@@ -214,12 +213,11 @@ impl VisionEmbedder {
         match model {
             "clip" | "Clip" | "CLIP" => Ok(Self::Clip(ClipEmbedder::new(
                 model_id.to_string(),
-                revision.map(|s| s.to_string()),
-            )?)),
-            "colpali" | "ColPali" | "COLPALI" => Ok(Self::ColPali(ColPaliEmbedder::new(
-                model_id,
-                revision.map(|s| s),
+                revision,
             )?)),
+            "colpali" | "ColPali" | "COLPALI" => {
+                Ok(Self::ColPali(ColPaliEmbedder::new(model_id, revision)?))
+            }
             _ => Err(anyhow::anyhow!("Model not supported")),
         }
     }
@@ -248,10 +246,18 @@ impl Embedder {
         revision: Option<&str>,
     ) -> Result<Self, anyhow::Error> {
         match model {
-            "clip" | "Clip" | "CLIP" => Ok(Self::Vision(VisionEmbedder::from_pretrained_hf(model, model_id, revision)?)),
-            "colpali" | "ColPali" | "COLPALI" => Ok(Self::Vision(VisionEmbedder::from_pretrained_hf(model, model_id, revision)?)),
-            "bert" | "Bert" => Ok(Self::Text(TextEmbedder::from_pretrained_hf(model, model_id, revision)?)),
-            "jina" | "Jina" => Ok(Self::Text(TextEmbedder::from_pretrained_hf(model, model_id, revision)?)),
+            "clip" | "Clip" | "CLIP" => Ok(Self::Vision(VisionEmbedder::from_pretrained_hf(
+                model, model_id, revision,
+            )?)),
+            "colpali" | "ColPali" | "COLPALI" => Ok(Self::Vision(
+                VisionEmbedder::from_pretrained_hf(model, model_id, revision)?,
+            )),
+            "bert" | "Bert" => Ok(Self::Text(TextEmbedder::from_pretrained_hf(
+                model, model_id, revision,
+            )?)),
+            "jina" | "Jina" => Ok(Self::Text(TextEmbedder::from_pretrained_hf(
+                model, model_id, revision,
+            )?)),
             _ => Err(anyhow::anyhow!("Model not supported")),
         }
     }
@@ -262,8 +268,12 @@ impl Embedder {
         api_key: Option<String>,
     ) -> Result<Self, anyhow::Error> {
         match model {
-            "openai" | "OpenAI" => Ok(Self::Text(TextEmbedder::from_pretrained_cloud(model, model_id, api_key)?)),
-            "cohere" | "Cohere" => Ok(Self::Text(TextEmbedder::from_pretrained_cloud(model, model_id, api_key)?)),
+            "openai" | "OpenAI" => Ok(Self::Text(TextEmbedder::from_pretrained_cloud(
+                model, model_id, api_key,
+            )?)),
+            "cohere" | "Cohere" => Ok(Self::Text(TextEmbedder::from_pretrained_cloud(
+                model, model_id, api_key,
+            )?)),
             _ => Err(anyhow::anyhow!("Model not supported")),
         }
     }
@@ -273,7 +283,11 @@ impl Embedder {
         model_name: ONNXModel,
         revision: Option<&str>,
     ) -> Result<Self, anyhow::Error> {
-        Ok(Self::Text(TextEmbedder::from_pretrained_ort(model_architecture, model_name, revision)?))
+        Ok(Self::Text(TextEmbedder::from_pretrained_ort(
+            model_architecture,
+            model_name,
+            revision,
+        )?))
     }
 }
@@ -321,7 +335,6 @@ impl TextEmbed for VisionEmbedder {
     }
 }
 
-
 pub trait EmbedImage {
     fn embed_image<T: AsRef<std::path::Path>>(
         &self,
         image_path: T,
         metadata: Option<HashMap<String, String>>,
     ) -> anyhow::Result<EmbedData>;
     fn embed_image_batch<T: AsRef<std::path::Path>>(
         &self,
         image_paths: &[T],
     ) -> anyhow::Result<Vec<EmbedData>>;
-
 }
 
 impl EmbedImage for VisionEmbedder {
@@ -356,6 +368,4 @@ impl EmbedImage for VisionEmbedder {
             Self::ColPali(embeder) => embeder.embed_image_batch(image_paths),
         }
     }
-
-
 }
diff --git a/rust/src/embeddings/local/bert.rs b/rust/src/embeddings/local/bert.rs
index 3230c47..bde2875 100644
--- a/rust/src/embeddings/local/bert.rs
+++ b/rust/src/embeddings/local/bert.rs
@@ -7,9 +7,10 @@ extern crate accelerate_src;
 use crate::embeddings::embed::EmbeddingResult;
 use crate::embeddings::local::text_embedding::{get_model_info_by_hf_id, models_map};
 use crate::embeddings::normalize_l2;
-use crate::models::bert::{BertModel, Config, HiddenAct, DTYPE};
+use crate::embeddings::utils::{get_attention_mask, tokenize_batch};
+use crate::models::bert::{BertForMaskedLM, BertModel, Config, DTYPE};
 use anyhow::Error as E;
-use candle_core::{Device, Tensor};
+use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use hf_hub::{api::sync::Api, Repo};
 use ndarray::prelude::*;
@@ -43,9 +44,9 @@ pub struct OrtBertEmbedder {
 
 impl OrtBertEmbedder {
     pub fn new(model: ONNXModel, revision: Option<String>) -> Result<Self, E> {
-        let model_info = models_map().get(&model).ok_or_else(|| {
-            E::msg("ONNX model does not exist for the specified model")
-        })?;
+        let model_info = models_map()
+            .get(&model)
+            .ok_or_else(|| E::msg("ONNX model does not exist for the specified model"))?;
         let pooling = model_info
             .model
             .get_default_pooling_method()
@@ -74,14 +75,16 @@ impl OrtBertEmbedder {
         let tokenizer_config: TokenizerConfig = serde_json::from_str(&tokenizer_config)?;
 
         // Set max_length to the minimum of max_length and model_max_length if both are present
-        let max_length = match (tokenizer_config.max_length, tokenizer_config.model_max_length) {
+        let max_length = match (
+            tokenizer_config.max_length,
+            tokenizer_config.model_max_length,
+        ) {
             (Some(max_len), Some(model_max_len)) => std::cmp::min(max_len, model_max_len),
             (Some(max_len), None) => max_len,
             (None, Some(model_max_len)) => model_max_len,
             (None, None) => 128,
         };
 
-
         let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
         let pp = PaddingParams {
@@ -99,7 +102,7 @@ impl OrtBertEmbedder {
             .unwrap();
 
         let cuda = CUDAExecutionProvider::default();
-        
+
         if !cuda.is_available()? {
             eprintln!("CUDAExecutionProvider is not available");
         } else {
@@ -108,8 +111,7 @@ impl OrtBertEmbedder {
 
         let threads = std::thread::available_parallelism().unwrap().get();
         let model = Session::builder()?
-            .with_execution_providers([CUDAExecutionProvider::default()
-                .build()])?
+            .with_execution_providers([CUDAExecutionProvider::default().build()])?
             .with_optimization_level(GraphOptimizationLevel::Level3)?
             .with_intra_threads(threads)?
             .commit_from_file(weights_filename)?;
@@ -143,11 +145,46 @@ impl OrtBertEmbedder {
             .unwrap();
         Ok(token_ids_array)
     }
+}
 
-    // fn reshape_into_2d_vector(&self, raw_data: Vec<f32>, dim: Option<usize>) -> Vec<Vec<f32>> {
-    //     let dim = dim.expect("Dimension must be provided");
-    //     raw_data.chunks(dim).map(|chunk| chunk.to_vec()).collect()
-    // }
+impl BertEmbed for OrtBertEmbedder {
+    fn embed(
+        &self,
+        text_batch: &[String],
+        batch_size: Option<usize>,
+    ) -> Result<Vec<EmbeddingResult>, E> {
+        let batch_size = batch_size.unwrap_or(32);
+        let encodings = text_batch
+            .par_chunks(batch_size)
+            .flat_map(|mini_text_batch| -> Result<Vec<Vec<f32>>, E> {
+                let token_ids: Array2<i64> = self.tokenize_batch(mini_text_batch)?;
+                let token_type_ids: Array2<i64> = Array2::zeros(token_ids.raw_dim());
+                let attention_mask: Array2<i64> = Array2::ones(token_ids.raw_dim());
+                let outputs =
+                    self.model
+                        .run(ort::inputs![token_ids, token_type_ids, attention_mask]?)?;
+                let embeddings: Array3<f32> = outputs["last_hidden_state"]
+                    .try_extract_tensor::<f32>()?
+                    .to_owned()
+                    .into_dimensionality::<Ix3>()?;
+                let (_, _, _) = embeddings.dim();
+                let embeddings = self
+                    .pooling
+                    .pool(&ModelOutput::Array(embeddings))?
+                    .to_array()?;
+                let norms = embeddings.mapv(|x| x * x).sum_axis(Axis(1)).mapv(f32::sqrt);
+                let embeddings = &embeddings / &norms.insert_axis(Axis(1));
+
+                Ok(embeddings.outer_iter().map(|row| row.to_vec()).collect())
+            })
+            .flatten()
+            .collect::<Vec<_>>();
+
+        Ok(encodings
+            .iter()
+            .map(|x| EmbeddingResult::DenseVector(x.to_vec()))
+            .collect())
+    }
 }
 
 pub struct BertEmbedder {
@@ -165,7 +202,10 @@ impl BertEmbedder {
     pub fn new(model_id: String, revision: Option<String>) -> Result<Self, E> {
         let model_info = get_model_info_by_hf_id(&model_id);
         let pooling = match model_info {
-            Some(info) => info.model.get_default_pooling_method().unwrap_or(Pooling::Mean),
+            Some(info) => info
+                .model
+                .get_default_pooling_method()
+                .unwrap_or(Pooling::Mean),
             None => Pooling::Mean,
         };
 
@@ -196,7 +236,7 @@ impl BertEmbedder {
             (config, tokenizer, weights)
         };
         let config = std::fs::read_to_string(config_filename)?;
-        let mut config: Config = serde_json::from_str(&config)?;
+        let config: Config = serde_json::from_str(&config)?;
         let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
         let pp = PaddingParams {
@@ -218,16 +258,12 @@ impl BertEmbedder {
         let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
         let vb = if weights_filename.ends_with("model.safetensors") {
-            unsafe {
-                VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)?
-            }
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
         } else {
-            println!("Loading weights from pytorch_model.bin");
+            println!("Can't find model.safetensors, loading from pytorch_model.bin");
             VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
         };
 
-        config.hidden_act = HiddenAct::GeluApproximate;
-
         let model = BertModel::load(vb, &config)?;
         let tokenizer = tokenizer;
 
@@ -237,60 +273,6 @@ impl BertEmbedder {
             pooling,
         })
     }
-    fn tokenize_batch(&self, text_batch: &[String], device: &Device) -> anyhow::Result<Tensor> {
-        let tokens = self
-            .tokenizer
-            .encode_batch(text_batch.to_vec(), true)
-            .map_err(E::msg)?;
-        let token_ids = tokens
-            .iter()
-            .map(|tokens| {
-                let tokens = tokens.get_ids().to_vec();
-                Tensor::new(tokens.as_slice(), device)
-            })
-            .collect::<candle_core::Result<Vec<_>>>()?;
-
-        Ok(Tensor::stack(&token_ids, 0)?)
-    }
-}
-
-impl BertEmbed for OrtBertEmbedder {
-    fn embed(&self, text_batch: &[String], batch_size: Option<usize>) -> Result<Vec<EmbeddingResult>, E> {
-        let batch_size = batch_size.unwrap_or(32);
-        let encodings = text_batch
-            .par_chunks(batch_size)
-            .flat_map(|mini_text_batch| -> Result<Vec<Vec<f32>>, E> {
-                let token_ids: Array2<i64> = self.tokenize_batch(mini_text_batch)?;
-                let token_type_ids: Array2<i64> = Array2::zeros(token_ids.raw_dim());
-                let attention_mask: Array2<i64> = Array2::ones(token_ids.raw_dim());
-                let outputs =
-                    self.model
-                        .run(ort::inputs![token_ids, token_type_ids, attention_mask]?)?;
-                let embeddings: Array3<f32> = outputs["last_hidden_state"]
-                    .try_extract_tensor::<f32>()?
-                    .to_owned()
-                    .into_dimensionality::<Ix3>()?;
-                let (_, _, _) = embeddings.dim();
-                let embeddings = self
-                    .pooling
-                    .pool(&ModelOutput::Array(embeddings))?
-                    .to_array()?;
-                let norms = embeddings.mapv(|x| x * x).sum_axis(Axis(1)).mapv(f32::sqrt);
-                let embeddings = &embeddings / &norms.insert_axis(Axis(1));
-
-                Ok(embeddings
-                    .outer_iter()
-                    .map(|row| row.to_vec())
-                    .collect())
-            })
-            .flatten()
-            .collect::<Vec<_>>();
-
-        Ok(encodings
-            .iter()
-            .map(|x| EmbeddingResult::Dense(x.to_vec()))
-            .collect())
-    }
 }
 
 impl BertEmbed for BertEmbedder {
@@ -303,9 +285,8 @@ impl BertEmbed for BertEmbedder {
         let mut encodings: Vec<EmbeddingResult> = Vec::new();
 
         for mini_text_batch in text_batch.chunks(batch_size) {
-            let token_ids = self
-                .tokenize_batch(mini_text_batch, &self.model.device)
-                .unwrap();
+            let token_ids =
+                tokenize_batch(&self.tokenizer, mini_text_batch, &self.model.device).unwrap();
             let token_type_ids = token_ids.zeros_like().unwrap();
             let embeddings: Tensor = self
                 .model
@@ -319,9 +300,129 @@ impl BertEmbed for BertEmbedder {
             let embeddings = normalize_l2(&pooled_output).unwrap();
             let batch_encodings = embeddings.to_vec2::<f32>().unwrap();
 
-            encodings.extend(batch_encodings.iter().map(|x| EmbeddingResult::Dense(x.to_vec())));
+            encodings.extend(
+                batch_encodings
+                    .iter()
+                    .map(|x| EmbeddingResult::DenseVector(x.to_vec())),
+            );
         }
         Ok(encodings)
     }
 }
+
+pub struct SparseBertEmbedder {
+    pub tokenizer: Tokenizer,
+    pub model: BertForMaskedLM,
+    pub device: Device,
+    pub dtype: DType,
+}
+
+impl SparseBertEmbedder {
+    pub fn new(model_id: String, revision: Option<String>) -> Result<Self, E> {
+        let (config_filename, tokenizer_filename, weights_filename) = {
+            let api = Api::new().unwrap();
+            let api = match revision {
+                Some(rev) => api.repo(Repo::with_revision(model_id, hf_hub::RepoType::Model, rev)),
+                None => api.repo(hf_hub::Repo::new(
+                    model_id.to_string(),
+                    hf_hub::RepoType::Model,
+                )),
+            };
+            let config = api.get("config.json")?;
+            let tokenizer = api.get("tokenizer.json")?;
+            let weights = match api.get("model.safetensors") {
+                Ok(safetensors) => safetensors,
+                Err(_) => match api.get("pytorch_model.bin") {
+                    Ok(pytorch_model) => pytorch_model,
+                    Err(e) => {
+                        return Err(anyhow::Error::msg(format!(
+                            "Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}",
+                            e
+                        )));
+                    }
+                },
+            };
+
+            (config, tokenizer, weights)
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: Config = serde_json::from_str(&config)?;
+        let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+        let pp = PaddingParams {
+            strategy: tokenizers::PaddingStrategy::BatchLongest,
+            ..Default::default()
+        };
+        let trunc = TruncationParams {
+            strategy: tokenizers::TruncationStrategy::LongestFirst,
+            max_length: config.max_position_embeddings as usize,
+            ..Default::default()
+        };
+
+        tokenizer
+            .with_padding(Some(pp))
+            .with_truncation(Some(trunc))
+            .unwrap();
+
+        println!("Loading weights from {:?}", weights_filename);
+
+        let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
+        let vb = if weights_filename.ends_with("model.safetensors") {
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+        } else {
+            println!("Loading weights from pytorch_model.bin");
+            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+        };
+        let model = BertForMaskedLM::load(vb, &config)?;
+        let tokenizer = tokenizer;
+
+        Ok(SparseBertEmbedder {
+            model,
+            tokenizer,
+            device,
+            dtype: DTYPE,
+        })
+    }
+}
+
+impl BertEmbed for SparseBertEmbedder {
+    fn embed(
+        &self,
+        text_batch: &[String],
+        batch_size: Option<usize>,
+    ) -> Result<Vec<EmbeddingResult>, anyhow::Error> {
+        let batch_size = batch_size.unwrap_or(32);
+        let mut encodings: Vec<EmbeddingResult> = Vec::new();
+
+        for mini_text_batch in text_batch.chunks(batch_size) {
+            let token_ids = tokenize_batch(&self.tokenizer, mini_text_batch, &self.device).unwrap();
+            let token_type_ids = token_ids.zeros_like().unwrap();
+            let embeddings: Tensor = self
+                .model
+                .forward(&token_ids, &token_type_ids, None)
+                .unwrap();
+            let attention_mask =
+                get_attention_mask(&self.tokenizer, mini_text_batch, &self.device).unwrap();
+
+            let batch_encodings = Tensor::log(
+                &Tensor::try_from(1.0)?
+                    .to_dtype(self.dtype)?
+                    .to_device(&self.device)?
+                    .broadcast_add(&embeddings.relu()?)?,
+            )?;
+
+            let batch_encodings = batch_encodings
+                .broadcast_mul(&attention_mask.unsqueeze(2)?.to_dtype(self.dtype)?)?
+                .max(1)?;
+            let batch_encodings = normalize_l2(&batch_encodings)?;
+
+            encodings.extend(
+                batch_encodings
+                    .to_vec2::<f32>()?
+                    .into_iter()
+                    .map(|x| EmbeddingResult::DenseVector(x.to_vec())),
+            );
+        }
+        Ok(encodings)
+    }
+}
diff --git a/rust/src/embeddings/local/clip.rs b/rust/src/embeddings/local/clip.rs
index 14955df..5a1206d 100644
--- a/rust/src/embeddings/local/clip.rs
+++ b/rust/src/embeddings/local/clip.rs
@@ -28,14 +28,14 @@ impl Default for ClipEmbedder {
     fn default() -> Self {
         Self::new(
             "openai/clip-vit-base-patch32".to_string(),
-            Some("refs/pr/15".to_string()),
+            Some("refs/pr/15"),
         )
         .unwrap()
     }
 }
 
 impl ClipEmbedder {
-    pub fn new(model_id: String, revision: Option<String>) -> Result<Self, anyhow::Error> {
+    pub fn new(model_id: String, revision: Option<&str>) -> Result<Self, anyhow::Error> {
         let api = hf_hub::api::sync::Api::new()?;
 
         let api = match revision {
@@ -206,7 +206,7 @@ impl ClipEmbedder {
             encodings.extend(
                 batch_encodings
                     .iter()
-                    .map(|embedding| EmbeddingResult::Dense(embedding.to_vec())),
+                    .map(|embedding| EmbeddingResult::DenseVector(embedding.to_vec())),
             );
         }
 
@@ -250,7 +250,7 @@ impl EmbedImage for ClipEmbedder {
                 );
 
                 EmbedData::new(
-                    EmbeddingResult::Dense(data.to_vec()),
+                    EmbeddingResult::DenseVector(data.to_vec()),
                     Some(path.as_ref().to_str().unwrap().to_string()),
                     Some(metadata),
                 )
@@ -277,7 +277,7 @@ impl EmbedImage for ClipEmbedder {
             .to_vec2::<f32>()
             .unwrap()[0];
         Ok(EmbedData::new(
-            EmbeddingResult::Dense(encoding.to_vec()),
+            EmbeddingResult::DenseVector(encoding.to_vec()),
             None,
             metadata.clone(),
         ))
@@ -335,7 +335,7 @@ mod tests {
     // Tests the embed_image_batch method.
     #[test]
     fn test_embed_image_batch() {
-        let mut clip_embeder = ClipEmbedder::default();
+        let clip_embeder = ClipEmbedder::default();
         let embeddings = clip_embeder
             .embed_image_batch(&["test_files/clip/cat1.jpg", "test_files/clip/cat2.jpeg"])
             .unwrap();
diff --git a/rust/src/embeddings/local/colpali.rs b/rust/src/embeddings/local/colpali.rs
index 9ea5370..b530428 100644
--- a/rust/src/embeddings/local/colpali.rs
+++ b/rust/src/embeddings/local/colpali.rs
@@ -1,15 +1,15 @@
 use std::sync::RwLock;
 use std::{collections::HashMap, path::Path};
 
+use crate::embeddings::embed::{EmbedData, EmbedImage, EmbeddingResult};
 use anyhow::Error as E;
+use base64::Engine;
 use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::models::{colpali::Model, paligemma};
 use image::{DynamicImage, ImageFormat};
-use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
-use crate::embeddings::embed::{EmbedData, EmbedImage, EmbeddingResult};
 use pdf2image::{Pages, RenderOptionsBuilder, PDF};
-use base64::Engine;
+use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
 
 pub struct ColPaliEmbedder {
     pub model: RwLock<Model>,
     pub tokenizer: Tokenizer,
@@ -56,7 +56,7 @@ impl ColPaliEmbedder {
         };
         let trunc = TruncationParams {
             strategy: tokenizers::TruncationStrategy::LongestFirst,
-            max_length: config.text_config.max_position_embeddings as usize,
+            max_length: config.text_config.max_position_embeddings,
             ..Default::default()
         };
 
@@ -105,14 +105,10 @@ impl ColPaliEmbedder {
         let img = img.to_rgb8();
         let img = img.into_raw();
 
-        let img = Tensor::from_vec(
-            img,
-            (height, width, 3),
-            &self.device,
-        )?
-        .permute((2, 0, 1))?
-        .to_dtype(DType::F32)?
-        .affine(2. / 255., -1.)?;
+        let img = Tensor::from_vec(img, (height, width, 3), &self.device)?
+            .permute((2, 0, 1))?
+            .to_dtype(DType::F32)?
+            .affine(2. / 255., -1.)?;
 
         Ok(img)
     }
@@ -159,7 +155,7 @@ impl ColPaliEmbedder {
                 batch_encodings
                     .to_vec3::<f32>()?
                     .iter()
-                    .map(|x| EmbeddingResult::Sparse(x.to_vec())),
+                    .map(|x| EmbeddingResult::MultiVector(x.to_vec())),
             );
         }
         Ok(encodings)
@@ -222,11 +218,14 @@ impl ColPaliEmbedder {
                 .unwrap()
                 .forward_images(&page_images, &dummy_input)?
                 .to_dtype(DType::F32)?
-                .to_vec3::<f32>()?.into_iter().map(|x| EmbeddingResult::Sparse(x.to_vec()));
+                .to_vec3::<f32>()?
+                .into_iter()
+                .map(|x| EmbeddingResult::MultiVector(x.to_vec()));
 
             // zip the embeddings with the page numbers
             let embed_data_batch = image_embeddings
-                .zip(page_numbers.into_iter()).zip(batch.into_iter())
+                .zip(page_numbers.into_iter())
+                .zip(batch.iter())
                 .map(|((embedding, page_number), page_image)| {
                     let mut metadata = HashMap::new();
@@ -237,26 +236,20 @@ impl ColPaliEmbedder {
                     let base64_image = engine.encode(&buf);
 
                     metadata.insert("page_number".to_string(), page_number.to_string());
-                    metadata.insert("file_path".to_string(), file_path.as_ref().to_str().unwrap_or("").to_string());
+                    metadata.insert(
+                        "file_path".to_string(),
+                        file_path.as_ref().to_str().unwrap_or("").to_string(),
+                    );
                     metadata.insert("image".to_string(), base64_image);
 
-                    EmbedData::new(
-                        embedding,
-                        None,
-                        Some(metadata),
-                    )
+                    EmbedData::new(embedding, None, Some(metadata))
                 });
 
             embed_data.extend(embed_data_batch);
         }
 
         Ok(embed_data)
-
     }
 
     pub fn embed_query(&self, query: &str) -> anyhow::Result<Vec<EmbedData>> {
-        let input_ids = tokenize_batch(
-            &self.tokenizer,
-            vec![query],
-            &self.device,
-        )?;
+        let input_ids = tokenize_batch(&self.tokenizer, vec![query], &self.device)?;
 
         let encoding = self
             .model
@@ -266,7 +259,7 @@ impl ColPaliEmbedder {
             .to_dtype(DType::F32)?
             .to_vec3::<f32>()?
             .into_iter()
-            .map(|x| EmbeddingResult::Sparse(x.to_vec()));
+            .map(|x| EmbeddingResult::MultiVector(x.to_vec()));
 
         Ok(encoding
             .map(|x| EmbedData::new(x.clone(), None, None))
@@ -292,7 +285,7 @@ impl EmbedImage for ColPaliEmbedder {
             .to_dtype(DType::F32)?
             .to_vec3::<f32>()?
             .into_iter()
-            .map(|x| EmbeddingResult::Sparse(x.to_vec()))
+            .map(|x| EmbeddingResult::MultiVector(x.to_vec()))
             .collect::<Vec<_>>();
 
         Ok(EmbedData::new(encoding[0].clone(), None, metadata))
@@ -315,7 +308,7 @@ impl EmbedImage for ColPaliEmbedder {
 
         Ok(encodings
             .into_iter()
-            .map(|x| EmbedData::new(EmbeddingResult::Sparse(x), None, None))
+            .map(|x| EmbedData::new(EmbeddingResult::MultiVector(x), None, None))
             .collect::<Vec<_>>())
     }
 }
diff --git a/rust/src/embeddings/local/jina.rs b/rust/src/embeddings/local/jina.rs
index c8f26a2..5febfa2 100644
--- a/rust/src/embeddings/local/jina.rs
+++ b/rust/src/embeddings/local/jina.rs
@@ -30,15 +30,19 @@ pub struct JinaEmbedder {
 
 impl Default for JinaEmbedder {
     fn default() -> Self {
-        Self::new("jinaai/jina-embeddings-v2-small-en".to_string(), None).unwrap()
+        Self::new("jinaai/jina-embeddings-v2-small-en", None).unwrap()
     }
 }
 
 impl JinaEmbedder {
-    pub fn new(model_id: String, revision: Option<String>) -> Result<Self, E> {
+    pub fn new(model_id: &str, revision: Option<&str>) -> Result<Self, E> {
         let api = hf_hub::api::sync::Api::new()?;
         let api = match revision {
-            Some(rev) => api.repo(Repo::with_revision(model_id, hf_hub::RepoType::Model, rev)),
+            Some(rev) => api.repo(Repo::with_revision(
+                model_id.to_string(),
+                hf_hub::RepoType::Model,
+                rev.to_string(),
+            )),
             None => api.repo(Repo::new(model_id.to_string(), hf_hub::RepoType::Model)),
         };
 
@@ -105,10 +109,9 @@ impl JinaEmbedder {
             let embeddings = normalize_l2(&embeddings).unwrap();
 
             // Avoid using to_vec2() and instead work with the Tensor directly
-            encodings
-                .extend((0..embeddings.dim(0)?).map(|i| {
-                    EmbeddingResult::Dense(embeddings.get(i).unwrap().to_vec1().unwrap())
-                }));
+            encodings.extend((0..embeddings.dim(0)?).map(|i| {
+                EmbeddingResult::DenseVector(embeddings.get(i).unwrap().to_vec1().unwrap())
+            }));
         }
 
         Ok(encodings)
@@ -121,8 +124,7 @@ mod tests {
 
     #[test]
     fn test_embed() {
-        let embeder =
-            JinaEmbedder::new("jinaai/jina-embeddings-v2-small-en".to_string(), None).unwrap();
+        let embeder = JinaEmbedder::new("jinaai/jina-embeddings-v2-small-en", None).unwrap();
         let text_batch = vec!["Hello, world!".to_string()];
 
         let encodings = embeder.embed(&text_batch, None).unwrap();
diff --git a/rust/src/embeddings/local/mod.rs b/rust/src/embeddings/local/mod.rs
index 4643154..8beeff7 100644
--- a/rust/src/embeddings/local/mod.rs
+++ b/rust/src/embeddings/local/mod.rs
@@ -1,7 +1,7 @@
 pub mod bert;
 pub mod clip;
+pub mod colpali;
 pub mod jina;
 pub mod model_info;
 pub mod pooling;
 pub mod text_embedding;
-pub mod colpali;
\ No newline at end of file
diff --git a/rust/src/embeddings/local/text_embedding.rs b/rust/src/embeddings/local/text_embedding.rs
index e802c60..48ca266 100644
--- a/rust/src/embeddings/local/text_embedding.rs
+++ b/rust/src/embeddings/local/text_embedding.rs
@@ -72,8 +72,6 @@ pub enum ONNXModel {
     JINAV2BASEEN,
     /// jinaai/jina-embeddings-v2-large-en
     JINAV2LARGEEN,
-
-
 }
 
 // impl From<&str> for ONNXModel {
@@ -335,7 +333,6 @@ pub fn get_model_info_by_hf_id(hf_model_id: &str) -> Option<&ModelInfo<ONNXModel>>
-
diff --git a/rust/src/embeddings/mod.rs b/rust/src/embeddings/mod.rs
--- a/rust/src/embeddings/mod.rs
+++ b/rust/src/embeddings/mod.rs
+pub mod utils;
diff --git a/rust/src/embeddings/utils.rs b/rust/src/embeddings/utils.rs
new file mode 100644
--- /dev/null
+++ b/rust/src/embeddings/utils.rs
@@ -0,0 +1,41 @@
+use anyhow::Error as E;
+use candle_core::{Device, Tensor};
+use tokenizers::Tokenizer;
+
+pub fn tokenize_batch(
+    tokenizer: &Tokenizer,
+    text_batch: &[String],
+    device: &Device,
+) -> anyhow::Result<Tensor> {
+    let tokens = tokenizer
+        .encode_batch(text_batch.to_vec(), true)
+        .map_err(E::msg)?;
+    let token_ids = tokens
+        .iter()
+        .map(|tokens| {
+            let tokens = tokens.get_ids().to_vec();
+            Tensor::new(tokens.as_slice(), device)
+        })
+        .collect::<candle_core::Result<Vec<_>>>()?;
+
+    Ok(Tensor::stack(&token_ids, 0)?)
+}
+
+pub fn get_attention_mask(
+    tokenizer: &Tokenizer,
+    text_batch: &[String],
+    device: &Device,
+) -> anyhow::Result<Tensor> {
+    let tokens = tokenizer
+        .encode_batch(text_batch.to_vec(), true)
+        .map_err(E::msg)?;
+
+    let attention_mask = tokens
+        .iter()
+        .map(|tokens| {
+            let tokens = tokens.get_attention_mask().to_vec();
+            Tensor::new(tokens.as_slice(), device)
+        })
+        .collect::<candle_core::Result<Vec<_>>>()?;
+    Ok(Tensor::stack(&attention_mask, 0)?)
+}
diff --git a/rust/src/file_processor/website_processor.rs b/rust/src/file_processor/website_processor.rs
index 23defcc..80a1d39 100644
--- a/rust/src/file_processor/website_processor.rs
+++ b/rust/src/file_processor/website_processor.rs
@@ -1,4 +1,7 @@
-use std::{collections::{HashMap, HashSet}, rc::Rc};
+use std::{
+    collections::{HashMap, HashSet},
+    rc::Rc,
+};
 
 use anyhow::Result;
 use scraper::{Html, Selector};
@@ -95,11 +98,9 @@ impl WebPage {
             let metadata_hashmap: HashMap<String, String> = serde_json::from_value(metadata)?;
 
-            let encodings = embeder
-                .embed(&chunks, batch_size)
-                .await
-                .unwrap();
-            let embeddings = get_text_metadata(&Rc::new(encodings), &chunks, &Some(metadata_hashmap))?;
+            let encodings = embeder.embed(&chunks, batch_size).await.unwrap();
+            let embeddings =
+                get_text_metadata(&Rc::new(encodings), &chunks, &Some(metadata_hashmap))?;
 
             embed_data.extend(embeddings);
         }
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 1acd1d8..8727861 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -120,7 +120,7 @@ where
         Embedder::Text(embeder) => {
             emb_text(
                 file_name,
-                &embeder,
+                embeder,
                 Some(chunk_size),
                 batch_size,
                 Some(splitting_strategy),
@@ -129,7 +129,7 @@ where
             )
             .await
         }
-        Embedder::Vision( embeder) => Ok(Some(vec![emb_image(file_name, embeder).unwrap()])),
+        Embedder::Vision(embeder) => Ok(Some(vec![emb_image(file_name, embeder).unwrap()])),
     }
 }
 
@@ -507,13 +507,8 @@ where
             metadata_buffer.push(metadata);
 
             if chunk_buffer.len() == buffer_size {
-                match process_chunks(
-                    &chunk_buffer,
-                    &metadata_buffer,
-                    &embeder,
-                    batch_size,
-                )
-                .await
+                match process_chunks(&chunk_buffer, &metadata_buffer, &embeder, batch_size)
+                    .await
                 {
                     Ok(embeddings) => {
                         let files = embeddings
@@ -543,14 +538,7 @@ where
 
     // Process any remaining chunks
     if !chunk_buffer.is_empty() {
-        match process_chunks(
-            &chunk_buffer,
-            &metadata_buffer,
-            &embeder,
-            batch_size,
-        )
-        .await
-        {
+        match process_chunks(&chunk_buffer, &metadata_buffer, &embeder, batch_size).await {
             Ok(embeddings) => {
                 let files = embeddings
                     .iter()
@@ -621,9 +609,7 @@ pub async fn process_chunks(
     embedding_model: &Arc<Embedder>,
     batch_size: Option<usize>,
 ) -> Result<Arc<Vec<EmbedData>>> {
-    let encodings = embedding_model
-        .embed(chunks, batch_size)
-        .await?;
+    let encodings = embedding_model.embed(chunks, batch_size).await?;
 
     // zip encodings with chunks and metadata
     let embeddings = encodings
@@ -636,5 +622,3 @@ pub async fn process_chunks(
         .collect::<Vec<_>>();
     Ok(Arc::new(embeddings))
 }
-
-
diff --git a/rust/src/models/bert.rs b/rust/src/models/bert.rs
index 83b4d3d..ad4ea94 100644
--- a/rust/src/models/bert.rs
+++ b/rust/src/models/bert.rs
@@ -504,3 +504,100 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor>
     (attention_mask.ones_like()? - &attention_mask)?
         .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
 }
+
+//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
+struct BertPredictionHeadTransform {
+    dense: Linear,
+    activation: HiddenActLayer,
+    layer_norm: LayerNorm,
+}
+
+impl BertPredictionHeadTransform {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
+        let activation = HiddenActLayer::new(config.hidden_act);
+        let layer_norm = layer_norm(
+            config.hidden_size,
+            config.layer_norm_eps,
+            vb.pp("LayerNorm"),
+        )?;
+        Ok(Self {
+            dense,
+            activation,
+            layer_norm,
+        })
+    }
+}
+
+impl Module for BertPredictionHeadTransform {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let hidden_states = self
+            .activation
+            .forward(&self.dense.forward(hidden_states)?)?;
+        self.layer_norm.forward(&hidden_states)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
+pub struct BertLMPredictionHead {
+    transform: BertPredictionHeadTransform,
+    decoder: Linear,
+}
+
+impl BertLMPredictionHead {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let transform = BertPredictionHeadTransform::load(vb.pp("transform"), config)?;
+        let decoder = linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?;
+        Ok(Self { transform, decoder })
+    }
+}
+
+impl Module for BertLMPredictionHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        self.decoder
+            .forward(&self.transform.forward(hidden_states)?)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
+pub struct BertOnlyMLMHead {
+    predictions: BertLMPredictionHead,
+}
+
+impl BertOnlyMLMHead {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let predictions = BertLMPredictionHead::load(vb.pp("predictions"), config)?;
+        Ok(Self { predictions })
+    }
+}
+
+impl Module for BertOnlyMLMHead {
+    fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
+        self.predictions.forward(sequence_output)
+    }
+}
+
+pub struct BertForMaskedLM {
+    bert: BertModel,
+    cls: BertOnlyMLMHead,
+}
+
+impl BertForMaskedLM {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let bert = BertModel::load(vb.pp("bert"), config)?;
+        let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?;
+        Ok(Self { bert, cls })
+    }
+
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        token_type_ids: &Tensor,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let sequence_output = self
+            .bert
+            .forward(input_ids, token_type_ids, attention_mask)?;
+        self.cls.forward(&sequence_output)
+    }
+}
diff --git a/rust/src/text_loader.rs b/rust/src/text_loader.rs
index c3c3702..51e6704 100644
--- a/rust/src/text_loader.rs
+++ b/rust/src/text_loader.rs
@@ -5,11 +5,14 @@ use std::{
     sync::Arc,
 };
 
-use crate::{embeddings::embed::Embedder, file_processor::{markdown_processor::MarkdownProcessor, txt_processor::TxtProcessor}};
 use crate::{
     chunkers::statistical::StatisticalChunker,
     embeddings::{embed::TextEmbedder, local::jina::JinaEmbedder},
 };
+use crate::{
+    embeddings::embed::Embedder,
+    file_processor::{markdown_processor::MarkdownProcessor, txt_processor::TxtProcessor},
+};
 use anyhow::Error;
 use chrono::{DateTime, Local};
 use text_splitter::{Characters, ChunkConfig, TextSplitter};
@@ -89,8 +92,9 @@ impl TextLoader {
                 .map(|chunk| chunk.to_string())
                 .collect(),
             SplittingStrategy::Semantic => {
-                let embeder =
-                    semantic_encoder.unwrap_or(Arc::new(Embedder::Text(TextEmbedder::Jina(JinaEmbedder::default()))));
+                let embeder = semantic_encoder.unwrap_or(Arc::new(Embedder::Text(
+                    TextEmbedder::Jina(JinaEmbedder::default()),
+                )));
                 let chunker = StatisticalChunker {
                     encoder: embeder,
                     ..Default::default()
@@ -173,7 +177,7 @@ mod tests {
     #[test]
     fn test_image_embeder() {
         let file_path = PathBuf::from("test_files/clip/cat1.jpg");
-        let mut embeder = ClipEmbedder::default();
+        let embeder = ClipEmbedder::default();
         let emb_data = embeder.embed_image(file_path, None).unwrap();
         assert_eq!(emb_data.embedding.to_dense().unwrap().len(), 512);
     }
diff --git a/tests/model_tests/test_openai.py b/tests/model_tests/test_openai.py
index 64f5d5b..f0e0077 100644
--- a/tests/model_tests/test_openai.py
+++ b/tests/model_tests/test_openai.py
@@ -8,7 +8,7 @@ def test_openai_model_file(openai_model, test_pdf_file):
     assert len(data[0].embedding) == 1536
 
 @pytest.mark.parametrize(
-    "config", [TextEmbedConfig(batch_size=32, chunk_size=256)]
+    "config", [TextEmbedConfig(batch_size=512, chunk_size=1000, buffer_size=512)]
 )
 def test_openai_model_directory(openai_model, config, test_files_directory):
     data = embed_directory(test_files_directory, openai_model, config=config)
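One practical note on consuming the new output: SparseBertEmbedder returns EmbeddingResult::DenseVector even though most entries are zero, so callers see a full vocabulary-sized vector. To hand these to a vector database's sparse format, converting to (index, weight) pairs is enough. The helper below is a sketch, not part of the patch, and the exact payload shape depends on the target database:

    # Sketch: turn a SPLADE output (mostly-zero dense vector) into index/weight
    # pairs, the shape most sparse-vector stores expect.
    def to_sparse_pairs(vec, threshold=0.0):
        # keep only the dimensions that actually fire
        return {int(i): float(w) for i, w in enumerate(vec) if w > threshold}

    # `embed_vector` is the numpy array built in examples/splade.py above.
    sparse_doc = to_sparse_pairs(embed_vector[0])
    print(len(sparse_doc), "non-zero dims out of", embed_vector.shape[1])

Because the vectors are L2-normalized, scoring these pairs with a plain dot product reproduces the cosine similarities computed in both example programs.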