Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refact!: change import paths #277

Merged
merged 5 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions integrations/astra/examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
from haystack.components.routers import FileTypeRouter
from haystack.components.writers import DocumentWriter
from haystack.document_stores import DuplicatePolicy
from haystack.document_stores.types import DuplicatePolicy

from astra_haystack.document_store import AstraDocumentStore
from astra_haystack.retriever import AstraRetriever
from haystack_integrations.components.retrievers.astra import AstraRetriever
from haystack_integrations.document_stores.astra import AstraDocumentStore

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down
6 changes: 3 additions & 3 deletions integrations/astra/examples/pipeline_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.generators import OpenAIGenerator
from haystack.components.writers import DocumentWriter
from haystack.document_stores import DuplicatePolicy
from haystack.document_stores.types import DuplicatePolicy

from astra_haystack.document_store import AstraDocumentStore
from astra_haystack.retriever import AstraRetriever
from haystack_integrations.components.retrievers.astra import AstraRetriever
from haystack_integrations.document_stores.astra import AstraDocumentStore

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down
15 changes: 9 additions & 6 deletions integrations/astra/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m
Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues"
Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/astra"

[tool.hatch.build.targets.wheel]
packages = ["src/haystack_integrations"]

[tool.hatch.version]
source = "vcs"
tag-pattern = 'integrations\/astra-v(?P<version>.*)'
Expand Down Expand Up @@ -71,7 +74,7 @@ dependencies = [
"ruff>=0.0.243",
]
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive {args:src/astra_haystack tests}"
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
style = [
"ruff {args:.}",
"black --check --diff {args:.}",
Expand Down Expand Up @@ -141,25 +144,25 @@ unfixable = [
exclude = ["example"]

[tool.ruff.isort]
known-first-party = ["astra_haystack"]
known-first-party = ["haystack_integrations"]

[tool.ruff.flake8-tidy-imports]
ban-relative-imports = "all"
ban-relative-imports = "parents"

[tool.ruff.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]

[tool.coverage.run]
source_pkgs = ["astra_haystack", "tests"]
source_pkgs = ["haystack_integrations", "tests"]
branch = true
parallel = true
omit = [
"example"
]

[tool.coverage.paths]
astra_haystack = ["src/astra_haystack", "*/astra-store/src/astra_haystack"]
astra_haystack = ["src"]
tests = ["tests"]

[tool.coverage.report]
Expand All @@ -178,10 +181,10 @@ markers = [

[[tool.mypy.overrides]]
module = [
"astra_haystack.*",
"astra_client.*",
"pydantic.*",
"haystack.*",
"haystack_integrations.*",
"pytest.*"
]
ignore_missing_imports = true
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: 2023-present Anant Corporation <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from .retriever import AstraRetriever

__all__ = ["AstraRetriever"]
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from haystack import Document, component, default_from_dict, default_to_dict

from astra_haystack.document_store import AstraDocumentStore
from haystack_integrations.document_stores.astra import AstraDocumentStore


@component
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: 2023-present Anant Corporation <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from astra_haystack.document_store import AstraDocumentStore
from .document_store import AstraDocumentStore

__all__ = ["AstraDocumentStore"]
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from haystack.document_stores.errors import DuplicateDocumentError, MissingDocumentError
from haystack.document_stores.types import DuplicatePolicy

from astra_haystack.astra_client import AstraClient
from astra_haystack.errors import AstraDocumentStoreFilterError
from astra_haystack.filters import _convert_filters
from .astra_client import AstraClient
from .errors import AstraDocumentStoreFilterError
from .filters import _convert_filters

logger = logging.getLogger(__name__)

Expand All @@ -40,7 +40,7 @@ def __init__(
astra_application_token: str,
astra_keyspace: str,
astra_collection: str,
embedding_dim: Optional[int] = 768,
embedding_dim: int = 768,
duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE,
similarity: str = "cosine",
):
Expand Down Expand Up @@ -104,17 +104,12 @@ def to_dict(self) -> Dict[str, Any]:
def write_documents(
self,
documents: List[Document],
index: Optional[str] = None,
batch_size: int = 20,
policy: DuplicatePolicy = DuplicatePolicy.NONE,
):
"""
Indexes documents for later queries.

:param documents: a list of Haystack Document objects.
:param index: Optional name of index where the documents shall be written to.
If None, the DocumentStore's default index (self.index) will be used.
:param batch_size: Number of documents that are passed to bulk function at a time.
:param policy: Handle duplicate documents based on DuplicatePolicy parameter options.
Parameter options : (SKIP, OVERWRITE, FAIL, NONE)
- `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists,
Expand All @@ -125,26 +120,13 @@ def write_documents(
- `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised.
:return: int
"""

if index is None and self.index is None:
msg = "No Astra client provided"
raise ValueError(msg)

if index is None:
index = self.index

if policy is None or policy == DuplicatePolicy.NONE:
if self.duplicates_policy is not None and self.duplicates_policy != DuplicatePolicy.NONE:
policy = self.duplicates_policy
else:
policy = DuplicatePolicy.SKIP

if batch_size > MAX_BATCH_SIZE:
logger.warning(
f"batch_size set to {batch_size}, "
f"but maximum batch_size for Astra when using the JSON API is 20. batch_size set to 20."
)
batch_size = MAX_BATCH_SIZE
batch_size = MAX_BATCH_SIZE

def _convert_input_document(document: Union[dict, Document]):
if isinstance(document, Document):
Expand Down Expand Up @@ -196,7 +178,7 @@ def _convert_input_document(document: Union[dict, Document]):
if policy == DuplicatePolicy.SKIP:
if len(new_documents) > 0:
for batch in _batches(new_documents, batch_size):
inserted_ids = index.insert(batch) # type: ignore
inserted_ids = self.index.insert(batch) # type: ignore
insertion_counter += len(inserted_ids)
logger.info(f"write_documents inserted documents with id {inserted_ids}")
else:
Expand All @@ -205,7 +187,7 @@ def _convert_input_document(document: Union[dict, Document]):
elif policy == DuplicatePolicy.OVERWRITE:
if len(new_documents) > 0:
for batch in _batches(new_documents, batch_size):
inserted_ids = index.insert(batch) # type: ignore
inserted_ids = self.index.insert(batch) # type: ignore
insertion_counter += len(inserted_ids)
logger.info(f"write_documents inserted documents with id {inserted_ids}")
else:
Expand All @@ -214,7 +196,7 @@ def _convert_input_document(document: Union[dict, Document]):
if len(duplicate_documents) > 0:
updated_ids = []
for duplicate_doc in duplicate_documents:
updated = index.update_document(duplicate_doc, "_id") # type: ignore
updated = self.index.update_document(duplicate_doc, "_id") # type: ignore
if updated:
updated_ids.append(duplicate_doc["_id"])
insertion_counter = insertion_counter + len(updated_ids)
Expand All @@ -225,7 +207,7 @@ def _convert_input_document(document: Union[dict, Document]):
elif policy == DuplicatePolicy.FAIL:
if len(new_documents) > 0:
for batch in _batches(new_documents, batch_size):
inserted_ids = index.insert(batch) # type: ignore
inserted_ids = self.index.insert(batch) # type: ignore
insertion_counter = insertion_counter + len(inserted_ids)
logger.info(f"write_documents inserted documents with id {inserted_ids}")
else:
Expand Down
2 changes: 1 addition & 1 deletion integrations/astra/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from haystack.document_stores.types import DuplicatePolicy

from astra_haystack.document_store import AstraDocumentStore
from haystack_integrations.document_stores.astra import AstraDocumentStore


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion integrations/astra/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from haystack.document_stores.types import DuplicatePolicy
from haystack.testing.document_store import DocumentStoreBaseTests

from astra_haystack.document_store import AstraDocumentStore
from haystack_integrations.document_stores.astra import AstraDocumentStore


@pytest.mark.skipif(
Expand Down
10 changes: 5 additions & 5 deletions integrations/astra/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from astra_haystack.retriever import AstraRetriever
from haystack_integrations.components.retrievers.astra import AstraRetriever


@pytest.mark.skipif(
Expand All @@ -16,7 +16,7 @@
def test_retriever_to_json(document_store):
retriever = AstraRetriever(document_store, filters={"foo": "bar"}, top_k=99)
assert retriever.to_dict() == {
"type": "astra_haystack.retriever.AstraRetriever",
"type": "haystack_integrations.components.retrievers.astra.retriever.AstraRetriever",
"init_parameters": {
"filters": {"foo": "bar"},
"top_k": 99,
Expand All @@ -30,7 +30,7 @@ def test_retriever_to_json(document_store):
"embedding_dim": 768,
"similarity": "cosine",
},
"type": "astra_haystack.document_store.AstraDocumentStore",
"type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore",
},
},
}
Expand All @@ -43,7 +43,7 @@ def test_retriever_to_json(document_store):
@pytest.mark.integration
def test_retriever_from_json():
data = {
"type": "astra_haystack.retriever.AstraRetriever",
"type": "haystack_integrations.components.retrievers.astra.retriever.AstraRetriever",
"init_parameters": {
"filters": {"bar": "baz"},
"top_k": 42,
Expand All @@ -58,7 +58,7 @@ def test_retriever_from_json():
"embedding_dim": 768,
"similarity": "cosine",
},
"type": "astra_haystack.document_store.AstraDocumentStore",
"type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore",
},
},
}
Expand Down