Skip to content

Commit

Permalink
Merge branch 'weaviate-client-v4' of https://github.com/hsm207/haysta…
Browse files Browse the repository at this point in the history
…ck-core-integrations into weaviate-client-v4
  • Loading branch information
hsm207 committed Feb 29, 2024
2 parents 8d72423 + 2c3e446 commit 3f8fa4c
Show file tree
Hide file tree
Showing 29 changed files with 855 additions and 372 deletions.
13 changes: 10 additions & 3 deletions integrations/astra/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
# Astra Store

## Installation

```bash
pip install astra-haystack

```

### Local Development
Install the astra-haystack package locally to run integration tests:

Open in gitpod:
Expand Down Expand Up @@ -46,8 +53,8 @@ This package includes Astra Document Store and Astra Embedding Retriever classes

Import the Document Store:
```
from astra_store.document_store import AstraDocumentStore
from haystack.preview.document_stores import DuplicatePolicy
from haystack_integrations.document_stores.astra import AstraDocumentStore
from haystack.document_stores.types.policy import DuplicatePolicy
```

Load in environment variables:
Expand Down Expand Up @@ -76,7 +83,7 @@ Then you can use the document store functions like count_document below:
Create the Document Store object like above, then import and create the Pipeline:

```
from haystack.preview import Pipeline
from haystack import Pipeline
pipeline = Pipeline()
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,28 @@
class AstraEmbeddingRetriever:
"""
A component for retrieving documents from an AstraDocumentStore.
Usage example:
```python
from haystack_integrations.document_stores.astra import AstraDocumentStore
from haystack_integrations.components.retrievers.astra import AstraEmbeddingRetriever
document_store = AstraDocumentStore(
api_endpoint=api_endpoint,
token=token,
collection_name=collection_name,
duplicates_policy=DuplicatePolicy.SKIP,
embedding_dim=384,
)
retriever = AstraEmbeddingRetriever(document_store=document_store)
```
"""

def __init__(self, document_store: AstraDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10):
"""
Create an AstraEmbeddingRetriever component. Usually you pass some basic configuration
parameters to the constructor.
:param filters: A dictionary with filters to narrow down the search space (default is None).
:param top_k: The maximum number of documents to retrieve (default is 10).
:param filters: a dictionary with filters to narrow down the search space.
:param top_k: the maximum number of documents to retrieve.
"""
self.filters = filters
self.top_k = top_k
Expand All @@ -33,13 +46,13 @@ def __init__(self, document_store: AstraDocumentStore, filters: Optional[Dict[st

@component.output_types(documents=List[Document])
def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
"""Run the retriever on the given list of queries.
"""Retrieve documents from the AstraDocumentStore.
Args:
query_embedding (List[str]): An input list of queries
filters (Optional[Dict[str, Any]], optional): A dictionary with filters to narrow down the search space.
Defaults to None.
top_k (Optional[int], optional): The maximum number of documents to retrieve. Defaults to None.
:param query_embedding: floats representing the query embedding
:param filters: filters to narrow down the search space.
:param top_k: the maximum number of documents to retrieve.
:returns: a dictionary with the following keys:
- documents: A list of documents retrieved from the AstraDocumentStore.
"""

if not top_k:
Expand All @@ -51,6 +64,12 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] =
return {"documents": self.document_store.search(query_embedding, top_k, filters=filters)}

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
filters=self.filters,
Expand All @@ -60,6 +79,14 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AstraEmbeddingRetriever":
    """
    Rebuilds the component from its dictionary representation.

    The nested document store entry is deserialized first and written back
    into ``data["init_parameters"]`` (so ``data`` is modified in place) before
    ``data`` is handed to ``default_from_dict``.

    :param data:
        Dictionary to deserialize from. It must contain
        ``data["init_parameters"]["document_store"]``.
    :returns:
        Deserialized component.
    """
    init_parameters = data["init_parameters"]
    # Swap the serialized store for a live AstraDocumentStore instance.
    init_parameters["document_store"] = AstraDocumentStore.from_dict(init_parameters["document_store"])
    return default_from_dict(cls, data)
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ def __init__(
similarity_function: str,
namespace: Optional[str] = None,
):
"""
The connection to Astra DB is established and managed through the JSON API.
The required credentials (api endpoint and application token) can be generated
through the UI by clicking on the connect tab, and then selecting JSON API and
Generate Configuration.
:param api_endpoint: the Astra DB API endpoint.
:param token: the Astra DB application token.
:param collection_name: the current collection in the keyspace in the current Astra DB.
:param embedding_dimension: dimension of embedding vector.
:param similarity_function: the similarity function to use for the index.
:param namespace: the namespace to use for the collection.
"""
self.api_endpoint = api_endpoint
self.token = token
self.collection_name = collection_name
Expand Down Expand Up @@ -119,23 +132,17 @@ def query(
include_values: Optional[bool] = None,
) -> QueryResponse:
"""
The Query operation searches a namespace, using a query vector.
It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
Args:
vector (List[float]): The query vector. This should be the same length as the dimension of the index
being queried. Each `query()` request can contain only one of the parameters
`queries`, `id` or `vector`... [optional]
top_k (int): The number of results to return for each query. Must be an integer greater than 1.
query_filter (Dict[str, Union[str, float, int, bool, List, dict]):
The filter to apply. You can use vector metadata to limit your search. [optional]
include_metadata (bool): Indicates whether metadata is included in the response as well as the ids.
If omitted the server will use the default value of False [optional]
include_values (bool): Indicates whether values/vector is included in the response as well as the ids.
If omitted the server will use the default value of False [optional]
Returns: object which contains the list of the closest vectors as ScoredVector objects,
and namespace name.
Search the Astra index using a query vector.
:param vector: the query vector. This should be the same length as the dimension of the index being queried.
Each `query()` request can contain only one of the parameters `queries`, `id` or `vector`.
:param query_filter: the filter to apply. You can use vector metadata to limit your search.
:param top_k: the number of results to return for each query. Must be an integer greater than 1.
:param include_metadata: indicates whether metadata is included in the response as well as the ids.
If omitted the server will use the default value of `False`.
:param include_values: indicates whether values/vector is included in the response as well as the ids.
If omitted the server will use the default value of `False`.
:returns: object which contains the list of the closest vectors as ScoredVector objects, and namespace name.
"""
# get vector data and scores
if vector is None:
Expand Down Expand Up @@ -183,6 +190,12 @@ def _query(self, vector, top_k, filters=None):
return result

def find_documents(self, find_query):
"""
Find documents in the Astra index.
:param find_query: a dictionary with the query options
:returns: the documents found in the index
"""
response_dict = self._astra_db_collection.find(
filter=find_query.get("filter"),
sort=find_query.get("sort"),
Expand All @@ -195,6 +208,13 @@ def find_documents(self, find_query):
logger.warning(f"No documents found: {response_dict}")

def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse:
"""
Get documents from the Astra index by their ids.
:param ids: a list of document ids
:param batch_size: the batch size to use when querying the index
:returns: the documents found in the index
"""
document_batch = []

def batch_generator(chunks, batch_size):
Expand All @@ -213,6 +233,12 @@ def batch_generator(chunks, batch_size):
return formatted_docs

def insert(self, documents: List[Dict]):
"""
Insert documents into the Astra index.
:param documents: a list of documents to insert
:returns: the IDs of the inserted documents
"""
response_dict = self._astra_db_collection.insert_many(documents=documents)

inserted_ids = (
Expand All @@ -226,6 +252,13 @@ def insert(self, documents: List[Dict]):
return inserted_ids

def update_document(self, document: Dict, id_key: str):
"""
Update a document in the Astra index.
:param document: the document to update
:param id_key: the key to use as the document id
:returns: whether the document was updated successfully
"""
document_id = document.pop(id_key)

response_dict = self._astra_db_collection.find_one_and_update(
Expand All @@ -251,6 +284,13 @@ def delete(
delete_all: Optional[bool] = None,
filters: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
) -> int:
"""Delete documents from the Astra index.
:param ids: the ids of the documents to delete
:param delete_all: if `True`, delete all documents from the index
:param filters: additional filters to apply when deleting documents
:returns: the number of documents deleted
"""
if delete_all:
query = {"deleteMany": {}} # type: dict
if ids is not None:
Expand All @@ -276,7 +316,8 @@ def delete(

def count_documents(self) -> int:
"""
Returns how many documents are present in the document store.
Count the number of documents in the Astra index.
:returns: the number of documents in the index
"""
documents_count = self._astra_db_collection.count_documents()

Expand Down
Loading

0 comments on commit 3f8fa4c

Please sign in to comment.