Skip to content

Commit

Permalink
Add flake8-boolean-trap ruff rule (#985)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Jul 16, 2024
1 parent 5c4b3a8 commit b99a828
Show file tree
Hide file tree
Showing 28 changed files with 104 additions and 62 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `GriptapeCloudKnowledgeBaseVectorStoreDriver` to query Griptape Cloud Knowledge Bases.

### Changed
- **BREAKING**: `BaseVectorStoreDriver.upsert_text_artifacts` optional arguments are now keyword-only arguments.
- **BREAKING**: `BaseVectorStoreDriver.upsert_text_artifact` optional arguments are now keyword-only arguments.
- **BREAKING**: `BaseVectorStoreDriver.upsert_text` optional arguments are now keyword-only arguments.
- **BREAKING**: `BaseVectorStoreDriver.does_entry_exist` optional arguments are now keyword-only arguments.
- **BREAKING**: `BaseVectorStoreDriver.load_artifacts` optional arguments are now keyword-only arguments.
- **BREAKING**: `BaseVectorStoreDriver.upsert_vector` optional arguments are now keyword-only arguments.
- **BREAKING**: `BaseVectorStoreDriver.query` optional arguments are now keyword-only arguments.
- **BREAKING**: `EventListener.publish_event`'s `flush` argument is now a keyword-only argument.
- **BREAKING**: `BaseEventListenerDriver.publish_event`'s `flush` argument is now a keyword-only argument.

### Fixed
- Parameter `count` for `QdrantVectorStoreDriver.query` now optional as per documentation.
Expand Down
4 changes: 2 additions & 2 deletions griptape/artifacts/boolean_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact:
if value is not None:
if isinstance(value, str):
if value.lower() == "true":
return BooleanArtifact(True)
return BooleanArtifact(True) # noqa: FBT003
elif value.lower() == "false":
return BooleanArtifact(False)
return BooleanArtifact(False) # noqa: FBT003
elif isinstance(value, bool):
return BooleanArtifact(value)
raise ValueError(f"Cannot convert '{value}' to BooleanArtifact")
Expand Down
6 changes: 3 additions & 3 deletions griptape/drivers/event_listener/base_event_listener_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ class BaseEventListenerDriver(ABC):
def batch(self) -> list[dict]:
return self._batch

def publish_event(self, event: BaseEvent | dict, flush: bool = False) -> None:
def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None:
with self.futures_executor_fn() as executor:
executor.submit(self._safe_try_publish_event, event, flush)
executor.submit(self._safe_try_publish_event, event, flush=flush)

@abstractmethod
def try_publish_event_payload(self, event_payload: dict) -> None: ...

@abstractmethod
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: ...

def _safe_try_publish_event(self, event: BaseEvent | dict, flush: bool) -> None:
def _safe_try_publish_event(self, event: BaseEvent | dict, *, flush: bool) -> None:
try:
event_payload = event if isinstance(event, dict) else event.to_dict()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
)

@stream.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_stream(self, _: Attribute, stream: bool) -> None:
def validate_stream(self, _: Attribute, stream: bool) -> None: # noqa: FBT001
if stream:
raise ValueError("streaming is not supported")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver):
def upsert_vector(
self,
vector: list[float],
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class AzureMongoDbVectorStoreDriver(MongoDbAtlasVectorStoreDriver):
def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
Expand Down
30 changes: 20 additions & 10 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,24 @@ def to_artifact(self) -> BaseArtifact:
def upsert_text_artifacts(
self,
artifacts: list[TextArtifact] | dict[str, list[TextArtifact]],
*,
meta: Optional[dict] = None,
**kwargs,
) -> None:
with self.futures_executor_fn() as executor:
if isinstance(artifacts, list):
utils.execute_futures_list(
[executor.submit(self.upsert_text_artifact, a, None, meta, **kwargs) for a in artifacts],
[
executor.submit(self.upsert_text_artifact, a, namespace=None, meta=meta, **kwargs)
for a in artifacts
],
)
else:
utils.execute_futures_dict(
{
namespace: executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs)
namespace: executor.submit(
self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs
)
for namespace, artifact_list in artifacts.items()
for a in artifact_list
},
Expand All @@ -64,6 +70,7 @@ def upsert_text_artifacts(
def upsert_text_artifact(
self,
artifact: TextArtifact,
*,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
vector_id: Optional[str] = None,
Expand All @@ -75,7 +82,7 @@ def upsert_text_artifact(
value = artifact.to_text() if artifact.reference is None else artifact.to_text() + str(artifact.reference)
vector_id = self._get_default_vector_id(value)

if self.does_entry_exist(vector_id, namespace):
if self.does_entry_exist(vector_id, namespace=namespace):
return vector_id
else:
meta["artifact"] = artifact.to_json()
Expand All @@ -90,14 +97,15 @@ def upsert_text_artifact(
def upsert_text(
self,
string: str,
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
**kwargs,
) -> str:
vector_id = self._get_default_vector_id(string) if vector_id is None else vector_id

if self.does_entry_exist(vector_id, namespace):
if self.does_entry_exist(vector_id, namespace=namespace):
return vector_id
else:
return self.upsert_vector(
Expand All @@ -108,14 +116,14 @@ def upsert_text(
**kwargs,
)

def does_entry_exist(self, vector_id: str, namespace: Optional[str] = None) -> bool:
def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool:
try:
return self.load_entry(vector_id, namespace) is not None
return self.load_entry(vector_id, namespace=namespace) is not None
except Exception:
return False

def load_artifacts(self, namespace: Optional[str] = None) -> ListArtifact:
result = self.load_entries(namespace)
def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
result = self.load_entries(namespace=namespace)
artifacts = [r.to_artifact() for r in result]

return ListArtifact([a for a in artifacts if isinstance(a, TextArtifact)])
Expand All @@ -127,22 +135,24 @@ def delete_vector(self, vector_id: str) -> None: ...
def upsert_vector(
self,
vector: list[float],
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
**kwargs,
) -> str: ...

@abstractmethod
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[Entry]: ...
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[Entry]: ...

@abstractmethod
def load_entries(self, namespace: Optional[str] = None) -> list[Entry]: ...
def load_entries(self, *, namespace: Optional[str] = None) -> list[Entry]: ...

@abstractmethod
def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
Expand Down
5 changes: 3 additions & 2 deletions griptape/drivers/vector/dummy_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ def upsert_vector(
) -> str:
raise DummyException(__class__.__name__, "upsert_vector")

def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
raise DummyException(__class__.__name__, "load_entry")

def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
raise DummyException(__class__.__name__, "load_entries")

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,19 @@ def upsert_text(
) -> str:
raise NotImplementedError(f"{self.__class__.__name__} does not support text upsert.")

def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")

def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")

def load_artifacts(self, namespace: Optional[str] = None) -> ListArtifact:
def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.")

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: Optional[bool] = None,
Expand Down
12 changes: 7 additions & 5 deletions griptape/drivers/vector/local_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def load_entries_from_file(self, json_file: TextIO) -> dict[str, BaseVectorStore
def upsert_vector(
self,
vector: list[float],
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
Expand All @@ -60,7 +61,7 @@ def upsert_vector(
vector_id = vector_id if vector_id else utils.str_to_hash(str(vector))

with self.thread_lock:
self.entries[self._namespaced_vector_id(vector_id, namespace)] = self.Entry(
self.entries[self._namespaced_vector_id(vector_id, namespace=namespace)] = self.Entry(
id=vector_id,
vector=vector,
meta=meta,
Expand All @@ -75,15 +76,16 @@ def upsert_vector(

return vector_id

def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
return self.entries.get(self._namespaced_vector_id(vector_id, namespace), None)
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
return self.entries.get(self._namespaced_vector_id(vector_id, namespace=namespace), None)

def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
Expand Down Expand Up @@ -117,5 +119,5 @@ def query(
def delete_vector(self, vector_id: str) -> NoReturn:
raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

def _namespaced_vector_id(self, vector_id: str, namespace: Optional[str]) -> str:
def _namespaced_vector_id(self, vector_id: str, *, namespace: Optional[str]) -> str:
return vector_id if namespace is None else f"{namespace}-{vector_id}"
8 changes: 6 additions & 2 deletions griptape/drivers/vector/marqo_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class MarqoVectorStoreDriver(BaseVectorStoreDriver):
def upsert_text(
self,
string: str,
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
Expand Down Expand Up @@ -73,6 +74,7 @@ def upsert_text(
def upsert_text_artifact(
self,
artifact: TextArtifact,
*,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
vector_id: Optional[str] = None,
Expand Down Expand Up @@ -106,7 +108,7 @@ def upsert_text_artifact(
else:
raise ValueError(f"Failed to upsert text: {response}")

def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
"""Load a document entry from the Marqo index.
Args:
Expand All @@ -127,7 +129,7 @@ def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optiona
else:
return None

def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
"""Load all document entries from the Marqo index.
Args:
Expand Down Expand Up @@ -167,6 +169,7 @@ def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreD
def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
Expand Down Expand Up @@ -228,6 +231,7 @@ def get_indexes(self) -> list[str]:
def upsert_vector(
self,
vector: list[float],
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
Expand Down
6 changes: 4 additions & 2 deletions griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_collection(self) -> Collection:
def upsert_vector(
self,
vector: list[float],
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
Expand All @@ -73,7 +74,7 @@ def upsert_vector(
)
return vector_id

def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
"""Loads a document entry from the MongoDB collection based on the vector ID.
Returns:
Expand All @@ -95,7 +96,7 @@ def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optiona
meta=doc["meta"],
)

def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
"""Loads all document entries from the MongoDB collection.
Entries can optionally be filtered by namespace.
Expand All @@ -116,6 +117,7 @@ def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreD
def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
Expand Down
6 changes: 4 additions & 2 deletions griptape/drivers/vector/opensearch_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class OpenSearchVectorStoreDriver(BaseVectorStoreDriver):
def upsert_vector(
self,
vector: list[float],
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
Expand All @@ -66,7 +67,7 @@ def upsert_vector(

return response["_id"]

def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
"""Retrieves a specific vector entry from OpenSearch based on its identifier and optional namespace.
Returns:
Expand Down Expand Up @@ -95,7 +96,7 @@ def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optiona
logging.error(f"Error while loading entry: {e}")
return None

def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
"""Retrieves all vector entries from OpenSearch that match the optional namespace.
Returns:
Expand All @@ -122,6 +123,7 @@ def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreD
def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
Expand Down
Loading

0 comments on commit b99a828

Please sign in to comment.