Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ uploads/
myenv/
venv/
*.pyc
env/
9 changes: 9 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
class VectorDBType(Enum):
PGVECTOR = "pgvector"
ATLAS_MONGO = "atlas-mongo"
ELASTIC_SEARCH = "elasticsearch"


class EmbeddingsProvider(Enum):
Expand Down Expand Up @@ -63,6 +64,7 @@ def get_env_variable(
MONGO_VECTOR_COLLECTION = get_env_variable(
"MONGO_VECTOR_COLLECTION", None
) # Deprecated, backwards compatability
ES_URL = get_env_variable("ES_URL", "es")
CHUNK_SIZE = int(get_env_variable("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(get_env_variable("CHUNK_OVERLAP", "100"))

Expand Down Expand Up @@ -298,6 +300,13 @@ def init_embeddings(provider, model):
mode="atlas-mongo",
search_index=ATLAS_SEARCH_INDEX,
)
elif VECTOR_DB_TYPE == VectorDBType.ELASTIC_SEARCH:
vector_store = get_vector_store(
connection_string=ES_URL,
embeddings=embeddings,
collection_name=COLLECTION_NAME,
mode="elasticsearch"
)
else:
raise ValueError(f"Unsupported vector store type: {VECTOR_DB_TYPE}")

Expand Down
13 changes: 13 additions & 0 deletions app/routes/document_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
QueryMultipleBody,
)
from app.services.vector_store.async_pg_vector import AsyncPgVector
from app.services.vector_store.elasticsearch_vector import ExtendedElasticsearchVector
from app.utils.document_loader import get_loader, clean_text, process_documents
from app.utils.health import is_health_ok

Expand Down Expand Up @@ -182,6 +183,12 @@ async def query_embeddings_by_file_id(
k=body.k,
filter={"file_id": body.file_id},
)
elif isinstance(vector_store, ExtendedElasticsearchVector):
documents = vector_store.similarity_search_with_score_by_vector(
query=body.query,
embedding=embedding,
file_ids=[body.file_id],
)
else:
documents = vector_store.similarity_search_with_score_by_vector(
embedding, k=body.k, filter={"file_id": body.file_id}
Expand Down Expand Up @@ -580,6 +587,12 @@ async def query_embeddings_by_file_ids(body: QueryMultipleBody):
k=body.k,
filter={"file_id": {"$in": body.file_ids}},
)
elif isinstance(vector_store, ExtendedElasticsearchVector):
documents = vector_store.similarity_search_with_score_by_vector(
query=body.query,
embedding=embedding,
file_ids=body.file_ids,
)
else:
documents = vector_store.similarity_search_with_score_by_vector(
embedding, k=body.k, filter={"file_id": {"$in": body.file_ids}}
Expand Down
221 changes: 221 additions & 0 deletions app/services/vector_store/elasticsearch_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
from concurrent.futures import ThreadPoolExecutor
import threading
from typing import List, Tuple
from langchain_elasticsearch import ElasticsearchStore
from langchain_core.embeddings import Embeddings
from langchain_core.documents import Document

class ExtendedElasticsearchVector(ElasticsearchStore):
@property
def embedding_function(self) -> Embeddings:
"""
Property to access the embedding function.

Returns:
Embeddings: The embedding function used for generating vector representations.
"""

return self.embedding

def add_documents(self, docs: List[Document], ids: List[str]):
"""
Adds a list of documents to the vector store after embedding them.

Args:
docs (List[Document]): A list of documents to be embedded and stored.
ids (List[str]): Base identifier for the documents. Each doc will be assigned a unique ID.

Returns:
Any: Result of adding texts to the underlying store.
"""

embedded_vectors = [None] * len(docs)
lock = threading.Lock()

# Worker function to embed documents in parallel
def worker(i):
text = docs[i].page_content
emb = self.embedding.embed_query(text)
with lock:
embedded_vectors[i] = emb

# Use a ThreadPoolExecutor to parallelize embedding computation
with ThreadPoolExecutor(max_workers=10) as executor:
executor.map(worker, range(len(docs)))

# Generate unique file IDs using the base ID
file_ids = [f"{ids[0]}_{i}" for i in range(len(embedded_vectors))]

return self._store.add_texts(
ids=file_ids,
texts=[doc.page_content for doc in docs],
metadatas=[doc.metadata for doc in docs],
vectors=embedded_vectors,
create_index_if_not_exists=True,
refresh_indices=False,
bulk_kwargs={"chunk_size": 1000}
)

def similarity_search_with_score_by_vector(
self,
query: str,
embedding: List[float],
file_ids: List[str],
**kwargs
) -> List[Tuple[Document, float]]:
"""
Performs a similarity search using a given embedding and returns results with scores.

Args:
query (str): Text query to include in the search (used in match clause).
embedding (List[float]): Embedding vector to search against.
file_ids (List[str]): List of file IDs to restrict the search.

Returns:
List[Tuple[Document, float]]: A list of (Document, score) tuples matching the query.
"""

query_body = {
"query": {
"min_score": 1.5,
"query": {
"bool": {
"should": [
{
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
"params": {"query_vector": embedding}
}
}
},
{
"match": {
"text": {"query": query}
}
}
],
"must": [
{
"terms": {
"metadata.file_id.keyword": file_ids
}
}
]
}
}
}
}

result = self.client.search(index=self._store.index, body=query_body)
documents = []

for hit in result["hits"]["hits"]:
documents.append(
(
Document(
page_content=hit["_source"]["text"],
metadata=hit["_source"]["metadata"]
),
hit["_score"]
)
)

return documents

def get_all_ids(self) -> List[str]:
"""
Retrieves all document IDs from the vector store.

Returns:
List[str]: A list of document IDs.
"""

query = {
"query": {
"match_all": {}
},
"_source": False,
"stored_fields": []
}

ids: List[str] = []

result = self.client.search(index=self._store.index, body=query)
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])

return ids

def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
"""
Retrieves documents from the vector store by their unique IDs.

Args:
ids (List[str]): List of document IDs to retrieve.

Returns:
List[Document]: A list of Document objects corresponding to the provided IDs.
"""

if not ids:
return []

query = {
"query": {
"terms": {
"_id": ids
}
}
}

documents: List[Document] = []
result = self.client.search(index=self._store.index, body=query)

for hit in result["hits"]["hits"]:
source = hit["_source"]
doc = Document(
page_content=source.get("text", ""),
metadata=source.get("metadata", {})
)
documents.append(doc)

return documents

def get_filtered_ids(self, ids: List[str]) -> List[str]:
"""
Returns file IDs filtered by the provided list of file IDs

Args:
ids (List[str]): List of file IDs to filter by.

Returns:
List[str]: A list of file IDs that exist in the document.
"""

query = {
"query": {
"terms": {
"metadata.file_id.keyword": ids
}
},
"size": 0,
"aggs": {
"unique_file_ids": {
"terms": {
"field": "metadata.file_id.keyword",
"size": len(ids)
}
}
}
}

result = self.client.search(index=self._store.index, body=query)
filtered_ids = []

if "aggregations" in result and "unique_file_ids" in result["aggregations"]:
for bucket in result["aggregations"]["unique_file_ids"]["buckets"]:
filtered_ids.append(bucket["key"])

return filtered_ids
7 changes: 7 additions & 0 deletions app/services/vector_store/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .async_pg_vector import AsyncPgVector
from .atlas_mongo_vector import AtlasMongoVector
from .extended_pg_vector import ExtendedPgVector
from .elasticsearch_vector import ExtendedElasticsearchVector


def get_vector_store(
Expand Down Expand Up @@ -32,5 +33,11 @@ def get_vector_store(
return AtlasMongoVector(
collection=mong_collection, embedding=embeddings, index_name=search_index
)
elif mode == "elasticsearch":
return ExtendedElasticsearchVector(
es_url=connection_string,
index_name=collection_name,
embedding=embeddings,
)
else:
raise ValueError("Invalid mode specified. Choose 'sync', 'async', or 'atlas-mongo'.")
1 change: 1 addition & 0 deletions requirements.lite.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ langchain-community==0.3.12
langchain-openai==0.2.11
langchain-core==0.3.27
langchain-google-vertexai==2.0.11
langchain_elasticsearch==0.3.2
sqlalchemy==2.0.28
python-dotenv==1.0.1
fastapi==0.115.12
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ langchain-core==0.3.27
langchain-aws==0.2.1
langchain-google-vertexai==2.0.0
langchain_text_splitters==0.3.3
langchain_elasticsearch==0.3.2
boto3==1.34.144
sqlalchemy==2.0.28
python-dotenv==1.0.1
Expand Down