Skip to content

Commit

Permalink
fix formatting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
bluearrow98 committed Feb 13, 2025
1 parent 76554c6 commit b7c0c1e
Showing 1 changed file with 29 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import uuid
from typing import Any, Iterable, List, Dict, Optional
from typing import Any, Dict, Iterable, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
Expand Down Expand Up @@ -40,12 +40,19 @@ class ElasticSearchBM25Retriever(BaseRetriever):
"""Elasticsearch client."""
index_name: str
"""Name of the index to use in Elasticsearch."""
k: str = 4
k: int = 4
"""Number of documents to return."""

@classmethod
def create(
cls, elasticsearch_url: str, index_name: str, delete_if_exists:bool=False, k1: float = 2.0, b: float = 0.75, analyzer_type: str = "standard", es_params: dict = {}
cls,
elasticsearch_url: str,
index_name: str,
delete_if_exists: bool = False,
k1: float = 2.0,
b: float = 0.75,
analyzer_type: str = "standard",
es_params: dict = {},
) -> ElasticSearchBM25Retriever:
"""
Create a ElasticSearchBM25Retriever from a list of texts.
Expand Down Expand Up @@ -116,15 +123,15 @@ def add_texts(
)
requests = []
ids = []
metadata = metadata or [{}] * len(texts)
metadata = metadata or [{}] * len(list(texts))
for i, text in enumerate(texts):
_id = str(uuid.uuid4())
request = {
"_op_type": "index",
"_index": self.index_name,
"content": text,
"_id": _id,
"metadata":metadata[i]
"metadata": metadata[i],
}
ids.append(_id)
requests.append(request)
Expand All @@ -133,11 +140,11 @@ def add_texts(
if refresh_indices:
self.client.indices.refresh(index=self.index_name)
return ids

def add_documents(
self,
documents: List[Document],
refresh_indices:bool = True,
self,
documents: List[Document],
refresh_indices: bool = True,
) -> List[str]:
"""Add documents to the index.
Expand All @@ -148,13 +155,13 @@ def add_documents(
Returns:
List of ids from adding the texts into the retriever.
"""

texts = [doc.page_content for doc in documents]
metadata = [doc.metadata for doc in documents]

return self.add_texts(texts, metadata)

def build_query_body(self, query:str) -> Dict:
def build_query_body(self, query: str) -> Dict:
"""Build query body for the search API"""

return {"query": {"match": {"content": query}}}
Expand All @@ -163,9 +170,16 @@ def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
query_dict = self.build_query_body(query)
res = self.client.search(index=self.index_name, body=query_dict, source=["content", "metadata"])
res = self.client.search(
index=self.index_name, body=query_dict, source=["content", "metadata"]
)

docs = []
for r in res["hits"]["hits"]:
docs.append(Document(metadata=r["_source"]["metadata"], page_content=r["_source"]["content"]))
return docs[:self.k]
docs.append(
Document(
metadata=r["_source"]["metadata"],
page_content=r["_source"]["content"],
)
)
return docs[: self.k]

0 comments on commit b7c0c1e

Please sign in to comment.