diff --git a/Makefile b/Makefile index 13410094cb..d14c7184dc 100644 --- a/Makefile +++ b/Makefile @@ -33,22 +33,22 @@ test/integration: .PHONY: lint lint: ## Lint project. - @poetry run ruff check --fix griptape/ + @poetry run ruff check --fix .PHONY: format format: ## Format project. - @poetry run ruff format . + @poetry run ruff format .PHONY: check check: check/format check/lint check/types check/spell ## Run all checks. .PHONY: check/format check/format: - @poetry run ruff format --check griptape/ + @poetry run ruff format --check .PHONY: check/lint check/lint: - @poetry run ruff check griptape/ + @poetry run ruff check .PHONY: check/types check/types: diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py index 62fe85b0c3..85a3f31271 100644 --- a/docs/gen_ref_pages.py +++ b/docs/gen_ref_pages.py @@ -1,11 +1,12 @@ """Generate the code reference pages and navigation.""" -from textwrap import dedent from pathlib import Path +from textwrap import dedent + import mkdocs_gen_files -def build_reference_docs(): +def build_reference_docs() -> None: nav = mkdocs_gen_files.Nav() for path in sorted(Path("griptape").rglob("*.py")): @@ -37,8 +38,8 @@ def build_reference_docs(): index_file.write( dedent( """ - # Overview - This section of the documentation is dedicated to a reference API of Griptape. + # Overview + This section of the documentation is dedicated to a reference API of Griptape. Here you will find every class, function, and method that is available to you when using the library. """ ) diff --git a/docs/plugins/swagger_ui_plugin.py b/docs/plugins/swagger_ui_plugin.py index 6d5fb52da3..499d74cf55 100644 --- a/docs/plugins/swagger_ui_plugin.py +++ b/docs/plugins/swagger_ui_plugin.py @@ -1,4 +1,5 @@ import os +from typing import Any import markdown from jinja2 import Environment, FileSystemLoader, select_autoescape @@ -11,7 +12,7 @@ } -def generate_page_contents(page): +def generate_page_contents(page: Any) -> str: spec_url = config_scheme["spec_url"] tmpl_url = config_scheme["template"] env = Environment(loader=FileSystemLoader("docs/plugins/tmpl"), autoescape=select_autoescape(["html", "xml"])) @@ -23,11 +24,11 @@ def generate_page_contents(page): return tmpl_out -def on_config(config): - print("INFO - swagger-ui plugin ENABLED") +def on_config(config: Any) -> None: + pass -def on_page_read_source(page, config): +def on_page_read_source(page: Any, config: Any) -> Any: index_path = os.path.join(config["docs_dir"], config_scheme["outfile"]) page_path = os.path.join(config["docs_dir"], page.file.src_path) if index_path == page_path: diff --git a/pyproject.toml b/pyproject.toml index 2e094141a6..a627418e3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -243,6 +243,11 @@ convention = "google" [tool.ruff.lint.per-file-ignores] "__init__.py" = ["I"] +"tests/*" = [ + "ANN001", # missing-type-function-argument + "ANN201", # missing-return-type-undocumented-public-function + "ANN202", # missing-return-type-private-function +] [tool.ruff.lint.flake8-tidy-imports.banned-api] "attr".msg = "The attr module is deprecated, use attrs instead." diff --git a/tests/integration/drivers/vector/test_pgvector_vector_store_driver.py b/tests/integration/drivers/vector/test_pgvector_vector_store_driver.py index e0345cecb7..c3ce6a09d9 100644 --- a/tests/integration/drivers/vector/test_pgvector_vector_store_driver.py +++ b/tests/integration/drivers/vector/test_pgvector_vector_store_driver.py @@ -1,9 +1,11 @@ import uuid + import pytest +from sqlalchemy import create_engine + from griptape.drivers import PgVectorVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.utils.postgres import can_connect_to_postgres -from sqlalchemy import create_engine @pytest.mark.skipif(not can_connect_to_postgres(), reason="Postgres is not present") @@ -34,7 +36,7 @@ def test_initialize_requires_engine_or_connection_string(self, embedding_driver) def test_initialize_accepts_engine(self, embedding_driver): engine = create_engine(self.connection_string) - driver = PgVectorVectorStoreDriver(embedding_driver=embedding_driver, engine=engine, table_name=self.table_name) + driver = PgVectorVectorStoreDriver(embedding_driver=embedding_driver, engine=engine, table_name=self.table_name) # pyright: ignore[reportArgumentType] driver.setup() @@ -86,11 +88,9 @@ def test_can_insert_and_load_entry_with_namespace(self, vector_store_driver): assert result.vector == pytest.approx(self.vec1) def test_can_load_entries(self, vector_store_driver): - """ - Depending on when this test is executed relative to the others, - we don't know exactly how many vectors will be returned. We can - ensure that at least two exist and confirm that those are found. - """ + # Depending on when this test is executed relative to the others, + # we don't know exactly how many vectors will be returned. We can + # ensure that at least two exist and confirm that those are found. vec1_id = vector_store_driver.upsert_vector(self.vec1) vec2_id = vector_store_driver.upsert_vector(self.vec2) @@ -173,8 +173,9 @@ def test_query_returns_vectors_when_requested(self, vector_store_driver): assert results[0].vector == pytest.approx(embedding) def test_can_use_custom_table_name(self, embedding_driver, vector_store_driver): - """This test ensures at least one row exists in the default table before specifying - a custom table name. After inserting another row, we should be able to query only one + """This test ensures at least one row exists in the default table before specifying a custom table name. + + After inserting another row, we should be able to query only one vector from the table, and it should be the vector added to the table with the new name. """ vector_store_driver.upsert_vector(self.vec1) diff --git a/tests/integration/rules/test_rule.py b/tests/integration/rules/test_rule.py index f04996040f..a62263c576 100644 --- a/tests/integration/rules/test_rule.py +++ b/tests/integration/rules/test_rule.py @@ -1,14 +1,15 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestRule: @pytest.fixture( autouse=True, params=StructureTester.RULE_CAPABLE_PROMPT_DRIVERS, ids=StructureTester.prompt_driver_id_fn ) def structure_tester(self, request): - from griptape.structures import Agent from griptape.rules import Rule + from griptape.structures import Agent agent = Agent(prompt_driver=request.param, rules=[Rule("Your name is Tony.")]) diff --git a/tests/integration/tasks/test_csv_extraction_task.py b/tests/integration/tasks/test_csv_extraction_task.py index 4624431ca4..db58b96158 100644 --- a/tests/integration/tasks/test_csv_extraction_task.py +++ b/tests/integration/tasks/test_csv_extraction_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestCsvExtractionTask: @pytest.fixture( @@ -9,9 +10,9 @@ class TestCsvExtractionTask: ids=StructureTester.prompt_driver_id_fn, ) def structure_tester(self, request): - from griptape.tasks import ExtractionTask - from griptape.structures import Agent from griptape.engines import CsvExtractionEngine + from griptape.structures import Agent + from griptape.tasks import ExtractionTask columns = ["Name", "Age", "Address"] diff --git a/tests/integration/tasks/test_json_extraction_task.py b/tests/integration/tasks/test_json_extraction_task.py index fdd7140f3a..115f805dab 100644 --- a/tests/integration/tasks/test_json_extraction_task.py +++ b/tests/integration/tasks/test_json_extraction_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestJsonExtractionTask: @pytest.fixture( @@ -9,11 +10,12 @@ class TestJsonExtractionTask: ids=StructureTester.prompt_driver_id_fn, ) def structure_tester(self, request): - from griptape.tasks import ExtractionTask - from griptape.structures import Agent - from griptape.engines import JsonExtractionEngine from schema import Schema + from griptape.engines import JsonExtractionEngine + from griptape.structures import Agent + from griptape.tasks import ExtractionTask + # Define some JSON data user_schema = Schema({"users": [{"name": str, "age": int, "location": str}]}).json_schema("UserSchema") diff --git a/tests/integration/tasks/test_prompt_task.py b/tests/integration/tasks/test_prompt_task.py index 6734df6784..1d223b4ca8 100644 --- a/tests/integration/tasks/test_prompt_task.py +++ b/tests/integration/tasks/test_prompt_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestPromptTask: @pytest.fixture( diff --git a/tests/integration/tasks/test_rag_task.py b/tests/integration/tasks/test_rag_task.py index c0383002cc..ce3a9140de 100644 --- a/tests/integration/tasks/test_rag_task.py +++ b/tests/integration/tasks/test_rag_task.py @@ -1,7 +1,8 @@ +import pytest + from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils.defaults import rag_engine from tests.utils.structure_tester import StructureTester -import pytest class TestRagTask: @@ -11,10 +12,10 @@ class TestRagTask: ids=StructureTester.prompt_driver_id_fn, ) def structure_tester(self, request): + from griptape.artifacts import TextArtifact + from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.structures import Agent from griptape.tasks import RagTask - from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver - from griptape.artifacts import TextArtifact vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) artifact = TextArtifact("John Doe works as as software engineer at Griptape.") diff --git a/tests/integration/tasks/test_text_summary_task.py b/tests/integration/tasks/test_text_summary_task.py index 9cbf1d9058..ff6597ba00 100644 --- a/tests/integration/tasks/test_text_summary_task.py +++ b/tests/integration/tasks/test_text_summary_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestTextSummaryTask: @pytest.fixture( @@ -10,8 +11,8 @@ class TestTextSummaryTask: ) def structure_tester(self, request): from griptape.engines.summary.prompt_summary_engine import PromptSummaryEngine - from griptape.tasks import TextSummaryTask from griptape.structures import Agent + from griptape.tasks import TextSummaryTask agent = Agent(conversation_memory=None, prompt_driver=request.param) agent.add_task(TextSummaryTask(summary_engine=PromptSummaryEngine(prompt_driver=request.param))) @@ -21,17 +22,17 @@ def structure_tester(self, request): def test_summary_task(self, structure_tester): structure_tester.run( """ - Meeting transcriot: - Miguel: Hi Brant, I want to discuss the workstream for our new product launch - Brant: Sure Miguel, is there anything in particular you want to discuss? - Miguel: Yes, I want to talk about how users enter into the product. - Brant: Ok, in that case let me add in Namita. - Namita: Hey everyone - Brant: Hi Namita, Miguel wants to discuss how users enter into the product. - Miguel: its too complicated and we should remove friction. for example, why do I need to fill out additional forms? I also find it difficult to find where to access the product when I first land on the landing page. - Brant: I would also add that I think there are too many steps. - Namita: Ok, I can work on the landing page to make the product more discoverable but brant can you work on the additional forms? - Brant: Yes but I would need to work with James from another team as he needs to unblock the sign up workflow. Miguel can you document any other concerns so that I can discuss with James only once? - Miguel: Sure. + Meeting transcriot: + Miguel: Hi Brant, I want to discuss the workstream for our new product launch + Brant: Sure Miguel, is there anything in particular you want to discuss? + Miguel: Yes, I want to talk about how users enter into the product. + Brant: Ok, in that case let me add in Namita. + Namita: Hey everyone + Brant: Hi Namita, Miguel wants to discuss how users enter into the product. + Miguel: its too complicated and we should remove friction. for example, why do I need to fill out additional forms? I also find it difficult to find where to access the product when I first land on the landing page. + Brant: I would also add that I think there are too many steps. + Namita: Ok, I can work on the landing page to make the product more discoverable but brant can you work on the additional forms? + Brant: Yes but I would need to work with James from another team as he needs to unblock the sign up workflow. Miguel can you document any other concerns so that I can discuss with James only once? + Miguel: Sure. """ ) diff --git a/tests/integration/tasks/test_tool_task.py b/tests/integration/tasks/test_tool_task.py index 1fa48e98d9..aee0af1107 100644 --- a/tests/integration/tasks/test_tool_task.py +++ b/tests/integration/tasks/test_tool_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestToolTask: @pytest.fixture( diff --git a/tests/integration/tasks/test_toolkit_task.py b/tests/integration/tasks/test_toolkit_task.py index a4d8c30c32..8dfcfdc734 100644 --- a/tests/integration/tasks/test_toolkit_task.py +++ b/tests/integration/tasks/test_toolkit_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestToolkitTask: @pytest.fixture( @@ -10,9 +11,10 @@ class TestToolkitTask: ) def structure_tester(self, request): import os - from griptape.structures import Agent - from griptape.tools import WebScraper, WebSearch, TaskMemoryClient + from griptape.drivers import GoogleWebSearchDriver + from griptape.structures import Agent + from griptape.tools import TaskMemoryClient, WebScraper, WebSearch return StructureTester( Agent( diff --git a/tests/integration/test_code_blocks.py b/tests/integration/test_code_blocks.py index 3a267a5290..2da683a2af 100644 --- a/tests/integration/test_code_blocks.py +++ b/tests/integration/test_code_blocks.py @@ -2,8 +2,8 @@ import os import pytest -from tests.utils.code_blocks import get_all_code_blocks, check_py_string +from tests.utils.code_blocks import check_py_string, get_all_code_blocks if "DOCS_ALL_CHANGED_FILES" in os.environ and os.environ["DOCS_ALL_CHANGED_FILES"] != "": docs_all_changed_files = os.environ["DOCS_ALL_CHANGED_FILES"].split() diff --git a/tests/integration/tools/test_calculator.py b/tests/integration/tools/test_calculator.py index 9015c7158e..2547b947d8 100644 --- a/tests/integration/tools/test_calculator.py +++ b/tests/integration/tools/test_calculator.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestCalculator: @pytest.fixture( diff --git a/tests/integration/tools/test_file_manager.py b/tests/integration/tools/test_file_manager.py index 462e664700..8a283c6e85 100644 --- a/tests/integration/tools/test_file_manager.py +++ b/tests/integration/tools/test_file_manager.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestFileManager: @pytest.fixture( diff --git a/tests/integration/tools/test_google_docs_client.py b/tests/integration/tools/test_google_docs_client.py index dfb1eb95b6..4d70aac17b 100644 --- a/tests/integration/tools/test_google_docs_client.py +++ b/tests/integration/tools/test_google_docs_client.py @@ -1,5 +1,7 @@ -import pytest import os + +import pytest + from tests.utils.structure_tester import StructureTester diff --git a/tests/integration/tools/test_google_drive_client.py b/tests/integration/tools/test_google_drive_client.py index 9bbbacfb58..23ebb1b328 100644 --- a/tests/integration/tools/test_google_drive_client.py +++ b/tests/integration/tools/test_google_drive_client.py @@ -1,5 +1,7 @@ -import pytest import os + +import pytest + from tests.utils.structure_tester import StructureTester diff --git a/tests/mocks/docker/fake_api.py b/tests/mocks/docker/fake_api.py index 3d5a411e5c..dcf0b45212 100644 --- a/tests/mocks/docker/fake_api.py +++ b/tests/mocks/docker/fake_api.py @@ -531,7 +531,6 @@ def post_fake_secret(): f"{prefix}/{CURRENT_VERSION}/containers/{FAKE_CONTAINER_ID}/unpause": post_fake_unpause_container, f"{prefix}/{CURRENT_VERSION}/containers/{FAKE_CONTAINER_ID}/restart": post_fake_restart_container, f"{prefix}/{CURRENT_VERSION}/containers/{FAKE_CONTAINER_ID}": delete_fake_remove_container, - f"{prefix}/{CURRENT_VERSION}/images/create": post_fake_image_create, f"{prefix}/{CURRENT_VERSION}/images/{FAKE_IMAGE_ID}": delete_fake_remove_image, f"{prefix}/{CURRENT_VERSION}/images/{FAKE_IMAGE_ID}/get": get_fake_get_image, f"{prefix}/{CURRENT_VERSION}/images/load": post_fake_load_image, @@ -544,20 +543,20 @@ def post_fake_secret(): f"{prefix}/{CURRENT_VERSION}/events": get_fake_events, (f"{prefix}/{CURRENT_VERSION}/volumes", "GET"): get_fake_volume_list, (f"{prefix}/{CURRENT_VERSION}/volumes/create", "POST"): get_fake_volume, - ("{1}/{0}/volumes/{2}".format(CURRENT_VERSION, prefix, FAKE_VOLUME_NAME), "GET"): get_fake_volume, - ("{1}/{0}/volumes/{2}".format(CURRENT_VERSION, prefix, FAKE_VOLUME_NAME), "DELETE"): fake_remove_volume, - ("{1}/{0}/nodes/{2}/update?version=1".format(CURRENT_VERSION, prefix, FAKE_NODE_ID), "POST"): post_fake_update_node, + (f"{prefix}/{CURRENT_VERSION}/volumes/{FAKE_VOLUME_NAME}", "GET"): get_fake_volume, + (f"{prefix}/{CURRENT_VERSION}/volumes/{FAKE_VOLUME_NAME}", "DELETE"): fake_remove_volume, + (f"{prefix}/{CURRENT_VERSION}/nodes/{FAKE_NODE_ID}/update?version=1", "POST"): post_fake_update_node, (f"{prefix}/{CURRENT_VERSION}/swarm/join", "POST"): post_fake_join_swarm, (f"{prefix}/{CURRENT_VERSION}/networks", "GET"): get_fake_network_list, (f"{prefix}/{CURRENT_VERSION}/networks/create", "POST"): post_fake_network, - ("{1}/{0}/networks/{2}".format(CURRENT_VERSION, prefix, FAKE_NETWORK_ID), "GET"): get_fake_network, - ("{1}/{0}/networks/{2}".format(CURRENT_VERSION, prefix, FAKE_NETWORK_ID), "DELETE"): delete_fake_network, + (f"{prefix}/{CURRENT_VERSION}/networks/{FAKE_NETWORK_ID}", "GET"): get_fake_network, + (f"{prefix}/{CURRENT_VERSION}/networks/{FAKE_NETWORK_ID}", "DELETE"): delete_fake_network, ( - "{1}/{0}/networks/{2}/connect".format(CURRENT_VERSION, prefix, FAKE_NETWORK_ID), + f"{prefix}/{CURRENT_VERSION}/networks/{FAKE_NETWORK_ID}/connect", "POST", ): post_fake_network_connect, ( - "{1}/{0}/networks/{2}/disconnect".format(CURRENT_VERSION, prefix, FAKE_NETWORK_ID), + f"{prefix}/{CURRENT_VERSION}/networks/{FAKE_NETWORK_ID}/disconnect", "POST", ): post_fake_network_disconnect, f"{prefix}/{CURRENT_VERSION}/secrets/create": post_fake_secret, diff --git a/tests/mocks/docker/fake_api_client.py b/tests/mocks/docker/fake_api_client.py index 05b06216af..25df7ab83f 100644 --- a/tests/mocks/docker/fake_api_client.py +++ b/tests/mocks/docker/fake_api_client.py @@ -1,15 +1,14 @@ import copy +from unittest import mock import docker from docker.constants import DEFAULT_DOCKER_API_VERSION -from unittest import mock + from . import fake_api class CopyReturnMagicMock(mock.MagicMock): - """ - A MagicMock which deep copies every return value. - """ + """A MagicMock which deep copies every return value.""" def _mock_call(self, *args, **kwargs): ret = super()._mock_call(*args, **kwargs) @@ -19,13 +18,11 @@ def _mock_call(self, *args, **kwargs): def make_fake_api_client(overrides=None): - """ - Returns non-complete fake APIClient. + """Returns non-complete fake APIClient. This returns most of the default cases correctly, but most arguments that change behaviour will not work. """ - if overrides is None: overrides = {} api_client = docker.APIClient(version=DEFAULT_DOCKER_API_VERSION) @@ -57,9 +54,7 @@ def make_fake_api_client(overrides=None): def make_fake_client(overrides=None): - """ - Returns a Client with a fake APIClient. - """ + """Returns a Client with a fake APIClient.""" client = docker.DockerClient(version=DEFAULT_DOCKER_API_VERSION) client.api = make_fake_api_client(overrides) return client diff --git a/tests/mocks/invalid_mock_tool/tool.py b/tests/mocks/invalid_mock_tool/tool.py index 91b2f78f74..fc761cae57 100644 --- a/tests/mocks/invalid_mock_tool/tool.py +++ b/tests/mocks/invalid_mock_tool/tool.py @@ -1,5 +1,6 @@ from attrs import define, field -from schema import Schema, Literal +from schema import Literal, Schema + from griptape.tools import BaseTool from griptape.utils.decorators import activity diff --git a/tests/mocks/mock_audio_input_task.py b/tests/mocks/mock_audio_input_task.py index d6a27d9689..95b8c88d08 100644 --- a/tests/mocks/mock_audio_input_task.py +++ b/tests/mocks/mock_audio_input_task.py @@ -1,4 +1,5 @@ from attrs import define + from griptape.artifacts import TextArtifact from griptape.tasks.base_audio_input_task import BaseAudioInputTask diff --git a/tests/mocks/mock_embedding_driver.py b/tests/mocks/mock_embedding_driver.py index e21c56308d..46d9bf5157 100644 --- a/tests/mocks/mock_embedding_driver.py +++ b/tests/mocks/mock_embedding_driver.py @@ -1,4 +1,7 @@ -from attrs import field, define +from __future__ import annotations + +from attrs import define, field + from griptape.drivers import BaseEmbeddingDriver from tests.mocks.mock_tokenizer import MockTokenizer diff --git a/tests/mocks/mock_event_listener_driver.py b/tests/mocks/mock_event_listener_driver.py index 560fb87338..5833dd1c07 100644 --- a/tests/mocks/mock_event_listener_driver.py +++ b/tests/mocks/mock_event_listener_driver.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from attrs import define from griptape.drivers import BaseEventListenerDriver diff --git a/tests/mocks/mock_failing_prompt_driver.py b/tests/mocks/mock_failing_prompt_driver.py index 0dbeb8fdac..18895fdc9a 100644 --- a/tests/mocks/mock_failing_prompt_driver.py +++ b/tests/mocks/mock_failing_prompt_driver.py @@ -1,12 +1,17 @@ from __future__ import annotations -from collections.abc import Iterator + +from typing import TYPE_CHECKING + from attrs import define from griptape.artifacts import TextArtifact -from griptape.common import PromptStack, Message, TextMessageContent, DeltaMessage, TextDeltaMessageContent +from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, TextMessageContent from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer +if TYPE_CHECKING: + from collections.abc import Iterator + @define class MockFailingPromptDriver(BasePromptDriver): diff --git a/tests/mocks/mock_image_generation_driver.py b/tests/mocks/mock_image_generation_driver.py index de94771e2b..10de110711 100644 --- a/tests/mocks/mock_image_generation_driver.py +++ b/tests/mocks/mock_image_generation_driver.py @@ -1,5 +1,9 @@ +from __future__ import annotations + from typing import Optional + from attrs import define + from griptape.artifacts import ImageArtifact from griptape.drivers.image_generation.base_image_generation_driver import BaseImageGenerationDriver diff --git a/tests/mocks/mock_image_generation_task.py b/tests/mocks/mock_image_generation_task.py index 1c79b42a9e..b55c5c9953 100644 --- a/tests/mocks/mock_image_generation_task.py +++ b/tests/mocks/mock_image_generation_task.py @@ -13,7 +13,7 @@ def input(self) -> TextArtifact: return self._input @input.setter - def input(self, value: str): + def input(self, value: str) -> None: self._input = TextArtifact(value) def run(self) -> ImageArtifact: diff --git a/tests/mocks/mock_image_query_driver.py b/tests/mocks/mock_image_query_driver.py index d3bec164f3..8f8cc888cb 100644 --- a/tests/mocks/mock_image_query_driver.py +++ b/tests/mocks/mock_image_query_driver.py @@ -1,8 +1,11 @@ +from __future__ import annotations + from typing import Optional + from attrs import define + from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers import BaseImageQueryDriver -from griptape.drivers.image_generation.base_image_generation_driver import BaseImageGenerationDriver @define diff --git a/tests/mocks/mock_multi_text_input_task.py b/tests/mocks/mock_multi_text_input_task.py index 7ab5aedf9e..be00bbf650 100644 --- a/tests/mocks/mock_multi_text_input_task.py +++ b/tests/mocks/mock_multi_text_input_task.py @@ -1,4 +1,5 @@ from attrs import define + from griptape.artifacts import TextArtifact from griptape.tasks import BaseMultiTextInputTask diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 4786b78a6f..5a23dd8a29 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -1,17 +1,19 @@ from __future__ import annotations -from collections.abc import Iterator -from typing import Callable +from typing import TYPE_CHECKING, Callable from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.common import PromptStack, Message, DeltaMessage, TextMessageContent, TextDeltaMessageContent +from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, TextMessageContent from griptape.drivers import BasePromptDriver -from griptape.tokenizers import BaseTokenizer - from tests.mocks.mock_tokenizer import MockTokenizer +if TYPE_CHECKING: + from collections.abc import Iterator + + from griptape.tokenizers import BaseTokenizer + @define class MockPromptDriver(BasePromptDriver): diff --git a/tests/mocks/mock_serializable.py b/tests/mocks/mock_serializable.py index b02c071aaa..8491fbb1cc 100644 --- a/tests/mocks/mock_serializable.py +++ b/tests/mocks/mock_serializable.py @@ -1,5 +1,9 @@ -from attrs import define, field +from __future__ import annotations + from typing import Optional + +from attrs import define, field + from griptape.mixins import SerializableMixin diff --git a/tests/mocks/mock_structure_config.py b/tests/mocks/mock_structure_config.py index 8309f541b5..3f95288f49 100644 --- a/tests/mocks/mock_structure_config.py +++ b/tests/mocks/mock_structure_config.py @@ -1,9 +1,10 @@ -from attrs import define, field, Factory +from attrs import Factory, define, field + from griptape.config import StructureConfig +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from tests.mocks.mock_image_query_driver import MockImageQueryDriver from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @define diff --git a/tests/mocks/mock_task.py b/tests/mocks/mock_task.py index 42595f6eba..81aa037137 100644 --- a/tests/mocks/mock_task.py +++ b/tests/mocks/mock_task.py @@ -1,5 +1,6 @@ from attrs import define, field -from griptape.artifacts import TextArtifact, BaseArtifact + +from griptape.artifacts import BaseArtifact, TextArtifact from griptape.tasks import BaseTask diff --git a/tests/mocks/mock_text_input_task.py b/tests/mocks/mock_text_input_task.py index 930c77e74f..f1439bd428 100644 --- a/tests/mocks/mock_text_input_task.py +++ b/tests/mocks/mock_text_input_task.py @@ -1,4 +1,5 @@ from attrs import define + from griptape.artifacts import TextArtifact from griptape.tasks import BaseTextInputTask diff --git a/tests/mocks/mock_tokenizer.py b/tests/mocks/mock_tokenizer.py index eff103e991..b16332ce0b 100644 --- a/tests/mocks/mock_tokenizer.py +++ b/tests/mocks/mock_tokenizer.py @@ -1,5 +1,7 @@ from __future__ import annotations + from attrs import define + from griptape.tokenizers import BaseTokenizer diff --git a/tests/mocks/mock_tool/tool.py b/tests/mocks/mock_tool/tool.py index e79023a6e5..7d09f391e1 100644 --- a/tests/mocks/mock_tool/tool.py +++ b/tests/mocks/mock_tool/tool.py @@ -1,6 +1,7 @@ from attrs import define, field -from schema import Schema, Literal -from griptape.artifacts import TextArtifact, ErrorArtifact, BaseArtifact, ListArtifact +from schema import Literal, Schema + +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -49,7 +50,7 @@ def test_str_output(self, value: dict) -> str: @activity(config={"description": "test description"}) def test_no_schema(self, value: dict) -> str: - return f"no schema" + return "no schema" @activity(config={"description": "test description"}) def test_list_output(self, value: dict) -> ListArtifact: diff --git a/tests/unit/artifacts/test_action_artifact.py b/tests/unit/artifacts/test_action_artifact.py index e415bbdaf3..2530ed8c36 100644 --- a/tests/unit/artifacts/test_action_artifact.py +++ b/tests/unit/artifacts/test_action_artifact.py @@ -1,7 +1,9 @@ import json + import pytest -from griptape.common import ToolAction + from griptape.artifacts import ActionArtifact, BaseArtifact +from griptape.common import ToolAction class TestActionArtifact: @@ -11,7 +13,7 @@ def action(self) -> ToolAction: def test___add__(self, action): with pytest.raises(NotImplementedError): - result = ActionArtifact(action) + ActionArtifact(action) + ActionArtifact(action) + ActionArtifact(action) def test_to_text(self, action): assert ActionArtifact(action).to_text() == json.dumps(action.to_dict()) diff --git a/tests/unit/artifacts/test_audio_artifact.py b/tests/unit/artifacts/test_audio_artifact.py index 93ea816e45..e4c01a192d 100644 --- a/tests/unit/artifacts/test_audio_artifact.py +++ b/tests/unit/artifacts/test_audio_artifact.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import AudioArtifact, BaseArtifact diff --git a/tests/unit/artifacts/test_base_artifact.py b/tests/unit/artifacts/test_base_artifact.py index a7d7acaaf3..6cf8f4466f 100644 --- a/tests/unit/artifacts/test_base_artifact.py +++ b/tests/unit/artifacts/test_base_artifact.py @@ -1,12 +1,13 @@ import pytest + from griptape.artifacts import ( BaseArtifact, - TextArtifact, + BlobArtifact, ErrorArtifact, + ImageArtifact, InfoArtifact, ListArtifact, - BlobArtifact, - ImageArtifact, + TextArtifact, ) diff --git a/tests/unit/artifacts/test_base_media_artifact.py b/tests/unit/artifacts/test_base_media_artifact.py index 2829a1e2f7..b07e5db799 100644 --- a/tests/unit/artifacts/test_base_media_artifact.py +++ b/tests/unit/artifacts/test_base_media_artifact.py @@ -1,5 +1,4 @@ import pytest - from attrs import define from griptape.artifacts import MediaArtifact diff --git a/tests/unit/artifacts/test_blob_artifact.py b/tests/unit/artifacts/test_blob_artifact.py index 08844d2415..a50a673f49 100644 --- a/tests/unit/artifacts/test_blob_artifact.py +++ b/tests/unit/artifacts/test_blob_artifact.py @@ -1,6 +1,8 @@ import base64 + import pytest -from griptape.artifacts import BlobArtifact, BaseArtifact + +from griptape.artifacts import BaseArtifact, BlobArtifact class TestBlobArtifact: diff --git a/tests/unit/artifacts/test_boolean_artifact.py b/tests/unit/artifacts/test_boolean_artifact.py index bcad676730..699711929d 100644 --- a/tests/unit/artifacts/test_boolean_artifact.py +++ b/tests/unit/artifacts/test_boolean_artifact.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import BooleanArtifact @@ -31,5 +32,5 @@ def test_value_type_conversion(self): assert BooleanArtifact("false").value is True assert BooleanArtifact([1]).value is True assert BooleanArtifact([]).value is False - assert BooleanArtifact(False) == False - assert BooleanArtifact(True) == True + assert BooleanArtifact(False) is False + assert BooleanArtifact(True) is True diff --git a/tests/unit/artifacts/test_image_artifact.py b/tests/unit/artifacts/test_image_artifact.py index 687397260c..885f262341 100644 --- a/tests/unit/artifacts/test_image_artifact.py +++ b/tests/unit/artifacts/test_image_artifact.py @@ -1,5 +1,6 @@ import pytest -from griptape.artifacts import ImageArtifact, BaseArtifact + +from griptape.artifacts import BaseArtifact, ImageArtifact class TestImageArtifact: diff --git a/tests/unit/artifacts/test_list_artifact.py b/tests/unit/artifacts/test_list_artifact.py index 044ca8ed52..06d2346458 100644 --- a/tests/unit/artifacts/test_list_artifact.py +++ b/tests/unit/artifacts/test_list_artifact.py @@ -1,5 +1,6 @@ import pytest -from griptape.artifacts import ListArtifact, TextArtifact, BlobArtifact, CsvRowArtifact + +from griptape.artifacts import BlobArtifact, CsvRowArtifact, ListArtifact, TextArtifact class TestListArtifact: diff --git a/tests/unit/artifacts/test_text_artifact.py b/tests/unit/artifacts/test_text_artifact.py index 6ea2c66976..067da0912e 100644 --- a/tests/unit/artifacts/test_text_artifact.py +++ b/tests/unit/artifacts/test_text_artifact.py @@ -1,6 +1,8 @@ import json + import pytest -from griptape.artifacts import TextArtifact, BaseArtifact + +from griptape.artifacts import BaseArtifact, TextArtifact from griptape.tokenizers import OpenAiTokenizer from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/chunkers/test_markdown_chunker.py b/tests/unit/chunkers/test_markdown_chunker.py index 08709c0924..1ee0f71c7e 100644 --- a/tests/unit/chunkers/test_markdown_chunker.py +++ b/tests/unit/chunkers/test_markdown_chunker.py @@ -1,4 +1,5 @@ import pytest + from griptape.chunkers import MarkdownChunker from tests.unit.chunkers.test_text_chunker import gen_paragraph diff --git a/tests/unit/chunkers/test_pdf_chunker.py b/tests/unit/chunkers/test_pdf_chunker.py index 605c2f6e66..3578c864d0 100644 --- a/tests/unit/chunkers/test_pdf_chunker.py +++ b/tests/unit/chunkers/test_pdf_chunker.py @@ -1,6 +1,8 @@ import os + import pytest from pypdf import PdfReader + from griptape.chunkers import PdfChunker MAX_TOKENS = 500 diff --git a/tests/unit/chunkers/test_text_chunker.py b/tests/unit/chunkers/test_text_chunker.py index 243b287e1c..67e07ea120 100644 --- a/tests/unit/chunkers/test_text_chunker.py +++ b/tests/unit/chunkers/test_text_chunker.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.chunkers import TextChunker from tests.unit.chunkers.utils import gen_paragraph diff --git a/tests/unit/chunkers/utils.py b/tests/unit/chunkers/utils.py index 80335e9788..b9e7b85390 100644 --- a/tests/unit/chunkers/utils.py +++ b/tests/unit/chunkers/utils.py @@ -5,7 +5,9 @@ def gen_paragraph(max_tokens: int, tokenizer: BaseTokenizer, sentence_separator: all_text = "" word = "foo" index = 0 - add_word = lambda base, w, i: sentence_separator.join([base, f"{w}-{i}"]) + + def add_word(base, w, i): + return sentence_separator.join([base, f"{w}-{i}"]) while max_tokens >= tokenizer.count_tokens(add_word(all_text, word, index)): all_text = f"{word}-{index}" if all_text == "" else add_word(all_text, word, index) diff --git a/tests/unit/common/contents/test_action_call_message_content.py b/tests/unit/common/contents/test_action_call_message_content.py index 2e2e69d27a..d6c3f438f8 100644 --- a/tests/unit/common/contents/test_action_call_message_content.py +++ b/tests/unit/common/contents/test_action_call_message_content.py @@ -1,6 +1,7 @@ import pytest + from griptape.artifacts.action_artifact import ActionArtifact -from griptape.common import ActionCallMessageContent, ActionCallDeltaMessageContent, ToolAction +from griptape.common import ActionCallDeltaMessageContent, ActionCallMessageContent, ToolAction class TestActionCallMessageContent: diff --git a/tests/unit/common/contents/test_action_result_message_content.py b/tests/unit/common/contents/test_action_result_message_content.py index b1bcc356d3..c5eed60d94 100644 --- a/tests/unit/common/contents/test_action_result_message_content.py +++ b/tests/unit/common/contents/test_action_result_message_content.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts.text_artifact import TextArtifact from griptape.common import ActionResultMessageContent, ToolAction diff --git a/tests/unit/common/contents/test_image_message_content.py b/tests/unit/common/contents/test_image_message_content.py index b6c1b4c4f1..ff8dbe59d0 100644 --- a/tests/unit/common/contents/test_image_message_content.py +++ b/tests/unit/common/contents/test_image_message_content.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts.image_artifact import ImageArtifact from griptape.common import ImageMessageContent diff --git a/tests/unit/common/contents/test_text_message_content.py b/tests/unit/common/contents/test_text_message_content.py index 01a3c0fd46..eab9eb718e 100644 --- a/tests/unit/common/contents/test_text_message_content.py +++ b/tests/unit/common/contents/test_text_message_content.py @@ -1,5 +1,5 @@ from griptape.artifacts.text_artifact import TextArtifact -from griptape.common import TextMessageContent, TextDeltaMessageContent +from griptape.common import TextDeltaMessageContent, TextMessageContent class TestTextMessageContent: diff --git a/tests/unit/common/test_action.py b/tests/unit/common/test_action.py index db5284839a..8fcf09a57d 100644 --- a/tests/unit/common/test_action.py +++ b/tests/unit/common/test_action.py @@ -1,5 +1,7 @@ -import pytest import json + +import pytest + from griptape.common import ToolAction diff --git a/tests/unit/common/test_prompt_stack.py b/tests/unit/common/test_prompt_stack.py index 83a16e140b..ee7cd21ce4 100644 --- a/tests/unit/common/test_prompt_stack.py +++ b/tests/unit/common/test_prompt_stack.py @@ -1,9 +1,14 @@ import pytest -from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact, ActionArtifact -from griptape.common import ImageMessageContent, PromptStack, TextMessageContent -from griptape.common import ActionCallMessageContent -from griptape.common import ActionResultMessageContent, ToolAction +from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.common import ( + ActionCallMessageContent, + ActionResultMessageContent, + ImageMessageContent, + PromptStack, + TextMessageContent, + ToolAction, +) class TestPromptStack: diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index 824e6ce114..44206c6889 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -1,5 +1,6 @@ import boto3 from pytest import fixture + from griptape.config import AmazonBedrockStructureConfig from tests.utils.aws import mock_aws_credentials diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index b41309a838..dbf402bb87 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -1,4 +1,5 @@ from pytest import fixture + from griptape.config import AnthropicStructureConfig diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_structure_config.py index 58d557fb91..ab5fdf30dd 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_structure_config.py @@ -1,4 +1,5 @@ from pytest import fixture + from griptape.config import AzureOpenAiStructureConfig diff --git a/tests/unit/config/test_cohere_structure_config.py b/tests/unit/config/test_cohere_structure_config.py index 44ed3e4d84..9074e6af1a 100644 --- a/tests/unit/config/test_cohere_structure_config.py +++ b/tests/unit/config/test_cohere_structure_config.py @@ -1,4 +1,5 @@ from pytest import fixture + from griptape.config import CohereStructureConfig diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_structure_config.py index 469493e2c7..a3dc18556f 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_structure_config.py @@ -1,4 +1,5 @@ from pytest import fixture + from griptape.config import GoogleStructureConfig diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index 19321006f4..16fd74aa4f 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -1,4 +1,5 @@ from pytest import fixture + from griptape.config import OpenAiStructureConfig diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index 27aaf81c42..27afa8e6e8 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -1,4 +1,5 @@ from pytest import fixture + from griptape.config import StructureConfig diff --git a/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py b/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py index ba8edad902..974d8b1434 100644 --- a/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py @@ -1,5 +1,7 @@ -import pytest from unittest import mock + +import pytest + from griptape.drivers import AmazonBedrockCohereEmbeddingDriver diff --git a/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py b/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py index df4455c248..e3005879a8 100644 --- a/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py @@ -1,5 +1,7 @@ -import pytest from unittest import mock + +import pytest + from griptape.drivers import AmazonBedrockTitanEmbeddingDriver diff --git a/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py b/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py index d2c20c0439..700434de9c 100644 --- a/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py @@ -1,5 +1,7 @@ from unittest.mock import Mock + import pytest + from griptape.drivers import AzureOpenAiEmbeddingDriver diff --git a/tests/unit/drivers/embedding/test_base_embedding_driver.py b/tests/unit/drivers/embedding/test_base_embedding_driver.py index 24b07778d9..e65c0946e4 100644 --- a/tests/unit/drivers/embedding/test_base_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_base_embedding_driver.py @@ -1,7 +1,9 @@ +from unittest.mock import patch + import pytest + from griptape.artifacts import TextArtifact from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from unittest.mock import patch class TestBaseEmbeddingDriver: diff --git a/tests/unit/drivers/embedding/test_cohere_embedding_driver.py b/tests/unit/drivers/embedding/test_cohere_embedding_driver.py index af6a5576d6..024e0e74cd 100644 --- a/tests/unit/drivers/embedding/test_cohere_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_cohere_embedding_driver.py @@ -1,5 +1,7 @@ from unittest.mock import Mock + import pytest + from griptape.drivers import CohereEmbeddingDriver diff --git a/tests/unit/drivers/embedding/test_dummy_embedding_driver.py b/tests/unit/drivers/embedding/test_dummy_embedding_driver.py index 35f81bf77d..335a14af7c 100644 --- a/tests/unit/drivers/embedding/test_dummy_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_dummy_embedding_driver.py @@ -1,6 +1,6 @@ -from griptape.drivers import DummyEmbeddingDriver import pytest +from griptape.drivers import DummyEmbeddingDriver from griptape.exceptions import DummyException diff --git a/tests/unit/drivers/embedding/test_google_embedding_driver.py b/tests/unit/drivers/embedding/test_google_embedding_driver.py index 324b95ddb5..9e756491ef 100644 --- a/tests/unit/drivers/embedding/test_google_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_google_embedding_driver.py @@ -1,5 +1,7 @@ from unittest.mock import MagicMock + import pytest + from griptape.drivers import GoogleEmbeddingDriver diff --git a/tests/unit/drivers/embedding/test_ollama_embedding_driver.py b/tests/unit/drivers/embedding/test_ollama_embedding_driver.py index 3886ab874a..6dda239301 100644 --- a/tests/unit/drivers/embedding/test_ollama_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_ollama_embedding_driver.py @@ -1,4 +1,5 @@ import pytest + from griptape.drivers import OllamaEmbeddingDriver diff --git a/tests/unit/drivers/embedding/test_openai_embedding_driver.py b/tests/unit/drivers/embedding/test_openai_embedding_driver.py index fd30dd30f1..78879345ac 100644 --- a/tests/unit/drivers/embedding/test_openai_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_openai_embedding_driver.py @@ -1,5 +1,7 @@ -from unittest.mock import Mock, MagicMock +from unittest.mock import Mock + import pytest + from griptape.drivers import OpenAiEmbeddingDriver from griptape.tokenizers import OpenAiTokenizer diff --git a/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py b/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py index 268b47c54c..2315dcb0ba 100644 --- a/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py @@ -1,5 +1,7 @@ -import pytest from unittest import mock + +import pytest + from griptape.drivers import AmazonSageMakerJumpstartEmbeddingDriver from griptape.tokenizers.openai_tokenizer import OpenAiTokenizer diff --git a/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py b/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py index 69db0213cb..5371f8db0a 100644 --- a/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import Mock + +import pytest + from griptape.drivers import VoyageAiEmbeddingDriver diff --git a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py index 706831d670..0513dc1ed0 100644 --- a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py @@ -1,8 +1,9 @@ -from pytest import fixture -from moto import mock_sqs import boto3 -from tests.mocks.mock_event import MockEvent +from moto import mock_sqs +from pytest import fixture + from griptape.drivers.event_listener.amazon_sqs_event_listener_driver import AmazonSqsEventListenerDriver +from tests.mocks.mock_event import MockEvent from tests.utils.aws import mock_aws_credentials diff --git a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py index 9a5fe9ec05..a9c778dff6 100644 --- a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py @@ -1,8 +1,9 @@ -from pytest import fixture -from moto import mock_iotdata import boto3 -from tests.mocks.mock_event import MockEvent +from moto import mock_iotdata +from pytest import fixture + from griptape.drivers.event_listener.aws_iot_core_event_listener_driver import AwsIotCoreEventListenerDriver +from tests.mocks.mock_event import MockEvent from tests.utils.aws import mock_aws_credentials diff --git a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py index 383c0be898..04cfef34b0 100644 --- a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py @@ -1,4 +1,5 @@ from unittest.mock import MagicMock + from tests.mocks.mock_event import MockEvent from tests.mocks.mock_event_listener_driver import MockEventListenerDriver diff --git a/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py b/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py index 6f0636b5cc..b14a56324d 100644 --- a/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py @@ -1,7 +1,9 @@ +from unittest.mock import Mock + from pytest import fixture -from tests.mocks.mock_event import MockEvent + from griptape.drivers import PusherEventListenerDriver -from unittest.mock import Mock +from tests.mocks.mock_event import MockEvent class TestPusherEventListenerDriver: diff --git a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py index 50021cbe34..36ace5d43c 100644 --- a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py @@ -1,7 +1,9 @@ from unittest.mock import Mock + from pytest import fixture -from tests.mocks.mock_event import MockEvent + from griptape.drivers.event_listener.webhook_event_listener_driver import WebhookEventListenerDriver +from tests.mocks.mock_event import MockEvent class TestWebhookEventListenerDriver: diff --git a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py index 8d1693adef..f53ff10a31 100644 --- a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py @@ -1,9 +1,11 @@ import os import tempfile + import boto3 import pytest from moto import mock_s3 -from griptape.artifacts import ErrorArtifact, ListArtifact, InfoArtifact, TextArtifact + +from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.drivers import AmazonS3FileManagerDriver from griptape.loaders import TextLoader from tests.utils.aws import mock_aws_credentials @@ -30,16 +32,16 @@ def bucket(self, s3_client): bucket = "test-bucket" s3_client.create_bucket(Bucket=bucket) - def write_file(path: str, content: bytes): + def write_file(path: str, content: bytes) -> None: s3_client.put_object(Bucket=bucket, Key=path, Body=content) - def mkdir(path: str): + def mkdir(path: str) -> None: # S3-style empty directories, such as is created via the `Create Folder` button # in the AWS S3 console (essentially, an empty file with a trailing slash). s3_dir_key = path.rstrip("/") + "/" s3_client.put_object(Bucket=bucket, Key=s3_dir_key) - def copy_test_resource(resource_path: str): + def copy_test_resource(resource_path: str) -> None: file_dir = os.path.dirname(__file__) full_path = os.path.join(file_dir, "../../../resources", resource_path) full_path = os.path.normpath(full_path) @@ -250,8 +252,8 @@ def test_save_file_failure(self, workdir, path, expected, temp_dir, driver, s3_c # loop over the files in the bucket and print them response = s3_client.list_objects_v2(Bucket=bucket) - for obj in response.get("Contents", []): - print(obj.get("Key")) + for _obj in response.get("Contents", []): + pass assert isinstance(artifact, ErrorArtifact) assert artifact.value == expected diff --git a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py index b3f4ec5611..234bebdf99 100644 --- a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py @@ -1,8 +1,10 @@ import os -from pathlib import Path import tempfile +from pathlib import Path + import pytest -from griptape.artifacts import ErrorArtifact, ListArtifact, InfoArtifact, TextArtifact + +from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.drivers import LocalFileManagerDriver from griptape.loaders.text_loader import TextLoader @@ -12,17 +14,17 @@ class TestLocalFileManagerDriver: def temp_dir(self): with tempfile.TemporaryDirectory() as temp_dir: - def write_file(path: str, content: bytes): + def write_file(path: str, content: bytes) -> None: full_path = os.path.join(temp_dir, path) os.makedirs(os.path.dirname(full_path), exist_ok=True) with open(full_path, "wb") as f: f.write(content) - def mkdir(path: str): + def mkdir(path: str) -> None: full_path = os.path.join(temp_dir, path) os.makedirs(full_path, exist_ok=True) - def copy_test_resources(resource_path: str): + def copy_test_resources(resource_path: str) -> None: file_dir = os.path.dirname(__file__) full_path = os.path.join(file_dir, "../../../resources", resource_path) full_path = os.path.normpath(full_path) diff --git a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py index a2c51f58be..e9e393c343 100644 --- a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py @@ -37,7 +37,7 @@ def test_init(self, driver): def test_init_requires_image_generation_model_driver(self, session): with pytest.raises(TypeError): - AmazonBedrockImageGenerationDriver(session=session, model="stability.stable-diffusion-xl-v1") # pyright: ignore + AmazonBedrockImageGenerationDriver(session=session, model="stability.stable-diffusion-xl-v1") # pyright: ignore[reportCallIssue] def test_try_text_to_image(self, driver): driver.bedrock_client.invoke_model.return_value = { diff --git a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py index 2166bc28a0..bfc05d0bfb 100644 --- a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import Mock + +import pytest + from griptape.drivers import AzureOpenAiImageGenerationDriver @@ -27,7 +29,7 @@ def test_init_requires_endpoint(self): with pytest.raises(TypeError): AzureOpenAiImageGenerationDriver( model="dall-e-3", client=Mock(), azure_deployment="dalle-deployment", image_size="512x512" - ) # pyright: ignore + ) # pyright: ignore[reportCallIssues] def test_try_text_to_image(self, driver): driver.client.images.generate.return_value = Mock(data=[Mock(b64_json=b"aW1hZ2UgZGF0YQ==")]) diff --git a/tests/unit/drivers/image_generation/test_dummy_image_generation_driver.py b/tests/unit/drivers/image_generation/test_dummy_image_generation_driver.py index 971c39b894..cb9ac7d08e 100644 --- a/tests/unit/drivers/image_generation/test_dummy_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_dummy_image_generation_driver.py @@ -1,7 +1,7 @@ -from griptape.drivers import DummyImageGenerationDriver -from griptape.artifacts import ImageArtifact import pytest +from griptape.artifacts import ImageArtifact +from griptape.drivers import DummyImageGenerationDriver from griptape.exceptions import DummyException diff --git a/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py b/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py index 564d3616a0..212543d184 100644 --- a/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py @@ -1,6 +1,8 @@ import uuid -from unittest.mock import Mock, PropertyMock, MagicMock +from unittest.mock import Mock + import pytest + from griptape.drivers import LeonardoImageGenerationDriver diff --git a/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py index 8ca488eb14..466d2bed63 100644 --- a/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import Mock + +import pytest + from griptape.drivers import OpenAiImageGenerationDriver diff --git a/tests/unit/drivers/image_generation_model/test_bedrock_stable_diffusion_image_model_driver.py b/tests/unit/drivers/image_generation_model/test_bedrock_stable_diffusion_image_model_driver.py index cdd4e95b7e..c43570b88d 100644 --- a/tests/unit/drivers/image_generation_model/test_bedrock_stable_diffusion_image_model_driver.py +++ b/tests/unit/drivers/image_generation_model/test_bedrock_stable_diffusion_image_model_driver.py @@ -118,5 +118,5 @@ def test_get_generated_image_failed(self, model_driver): response = {"artifacts": [{"finishReason": "ERROR", "base64": base64.b64encode(image_bytes).decode("utf-8")}]} - with pytest.raises(Exception): + with pytest.raises(Exception, match="Image generation failed:"): model_driver.get_generated_image(response) diff --git a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py index 57336e8ea6..c4a4ea60d3 100644 --- a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py @@ -1,8 +1,10 @@ -import pytest import io from unittest.mock import Mock -from griptape.drivers import AmazonBedrockImageQueryDriver + +import pytest + from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.drivers import AmazonBedrockImageQueryDriver class TestAmazonBedrockImageQueryDriver: diff --git a/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py b/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py index 24958d58fa..65cf1e3b49 100644 --- a/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py @@ -1,8 +1,10 @@ -import pytest import base64 from unittest.mock import Mock -from griptape.drivers import AnthropicImageQueryDriver + +import pytest + from griptape.artifacts import ImageArtifact +from griptape.drivers import AnthropicImageQueryDriver class TestAnthropicImageQueryDriver: @@ -55,7 +57,7 @@ def test_try_query_max_tokens_value(self, mock_client): assert text_artifact.value == "Content" def test_try_query_max_tokens_none(self, mock_client): - driver = AnthropicImageQueryDriver(model="test-model", max_tokens=None) # pyright: ignore + driver = AnthropicImageQueryDriver(model="test-model", max_tokens=None) # pyright: ignore[reportArgumentType] test_prompt_string = "Prompt String" test_binary_data = b"test-data" with pytest.raises(TypeError): diff --git a/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py index a443198614..e18e0f1b82 100644 --- a/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py @@ -1,7 +1,9 @@ -import pytest from unittest.mock import Mock -from griptape.drivers import AzureOpenAiImageQueryDriver + +import pytest + from griptape.artifacts import ImageArtifact +from griptape.drivers import AzureOpenAiImageQueryDriver class TestAzureOpenAiVisionImageQueryDriver: @@ -52,7 +54,7 @@ def test_try_query_multiple_choices(self, mock_completion_create): azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" ) - with pytest.raises(Exception): + with pytest.raises(Exception, match="Image query responses with more than one choice are not supported yet."): driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")]) def _expected_messages(self, expected_prompt_string, expected_binary_data): diff --git a/tests/unit/drivers/image_query/test_dummy_image_query_driver.py b/tests/unit/drivers/image_query/test_dummy_image_query_driver.py index 8efcfa7493..dedd8cb393 100644 --- a/tests/unit/drivers/image_query/test_dummy_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_dummy_image_query_driver.py @@ -1,7 +1,7 @@ -from griptape.drivers import DummyImageQueryDriver -from griptape.artifacts import ImageArtifact import pytest +from griptape.artifacts import ImageArtifact +from griptape.drivers import DummyImageQueryDriver from griptape.exceptions import DummyException diff --git a/tests/unit/drivers/image_query/test_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_openai_image_query_driver.py index 08f0c70c9e..36177dc00a 100644 --- a/tests/unit/drivers/image_query/test_openai_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_openai_image_query_driver.py @@ -1,7 +1,9 @@ -import pytest from unittest.mock import Mock -from griptape.drivers import OpenAiImageQueryDriver + +import pytest + from griptape.artifacts import ImageArtifact +from griptape.drivers import OpenAiImageQueryDriver class TestOpenAiVisionImageQueryDriver: @@ -43,7 +45,7 @@ def test_try_query_multiple_choices(self, mock_completion_create): mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2"))) driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview") - with pytest.raises(Exception): + with pytest.raises(Exception, match="Image query responses with more than one choice are not supported yet."): driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")]) def _expected_messages(self, expected_prompt_string, expected_binary_data): diff --git a/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py b/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py index 14fa8ff28d..c274f71dd5 100644 --- a/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py +++ b/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py @@ -1,6 +1,7 @@ import pytest -from griptape.drivers import BedrockClaudeImageQueryModelDriver + from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.drivers import BedrockClaudeImageQueryModelDriver class TestBedrockClaudeImageQueryModelDriver: diff --git a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py index ba79a4def7..1b4f2c303f 100644 --- a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py @@ -1,12 +1,13 @@ +import boto3 import pytest from moto import mock_dynamodb -import boto3 -from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.utils.aws import mock_aws_credentials + +from griptape.drivers import AmazonDynamoDbConversationMemoryDriver from griptape.memory.structure import ConversationMemory -from griptape.tasks import PromptTask from griptape.structures import Pipeline -from griptape.drivers import AmazonDynamoDbConversationMemoryDriver +from griptape.tasks import PromptTask +from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.utils.aws import mock_aws_credentials class TestDynamoDbConversationMemoryDriver: diff --git a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py index c794afd0e7..59d3f00b43 100644 --- a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py @@ -1,10 +1,13 @@ +import contextlib import os + import pytest -from tests.mocks.mock_prompt_driver import MockPromptDriver + from griptape.drivers import LocalConversationMemoryDriver from griptape.memory.structure import ConversationMemory -from griptape.tasks import PromptTask from griptape.structures import Pipeline +from griptape.tasks import PromptTask +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestLocalConversationMemoryDriver: @@ -28,7 +31,7 @@ def test_store(self): try: with open(self.MEMORY_FILE_PATH): - assert False + raise AssertionError() except FileNotFoundError: assert True @@ -74,8 +77,6 @@ def test_autoload(self): assert autoloaded_memory.runs[0].input.value == "test" assert autoloaded_memory.runs[0].output.value == "mock output" - def __delete_file(self, file_path): - try: + def __delete_file(self, file_path) -> None: + with contextlib.suppress(FileNotFoundError): os.remove(file_path) - except FileNotFoundError: - pass diff --git a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py index 1af9d74dc1..6cc255a3fc 100644 --- a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py @@ -1,7 +1,8 @@ import pytest import redis -from griptape.memory.structure.base_conversation_memory import BaseConversationMemory + from griptape.drivers.memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver +from griptape.memory.structure.base_conversation_memory import BaseConversationMemory TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}' CONVERSATION_ID = "117151897f344ff684b553d0655d8f39" @@ -31,7 +32,7 @@ def driver(self): def test_store(self, driver): memory = BaseConversationMemory.from_json(TEST_CONVERSATION) - assert driver.store(memory) == None + assert driver.store(memory) is None def test_load(self, driver): memory = driver.load() diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index e31b6a4483..fa3c250f81 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,10 +1,8 @@ import pytest -from griptape.artifacts import ImageArtifact, TextArtifact, ListArtifact, ErrorArtifact, ActionArtifact -from griptape.common import PromptStack -from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent, ToolAction +from griptape.artifacts import ActionArtifact, ErrorArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import AmazonBedrockPromptDriver - from tests.mocks.mock_tool.tool import MockTool diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index a75fc6ed0c..e74797e42d 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -1,11 +1,13 @@ +import json +from io import BytesIO from typing import Any + +import pytest from botocore.response import StreamingBody -from griptape.tokenizers import HuggingFaceTokenizer -from griptape.drivers.prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver + from griptape.common import PromptStack -from io import BytesIO -import json -import pytest +from griptape.drivers.prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver +from griptape.tokenizers import HuggingFaceTokenizer def to_streaming_body(data: Any) -> StreamingBody: diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index e8f6d337fa..d5fb0f710a 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,10 +1,11 @@ -from griptape.artifacts.error_artifact import ErrorArtifact -from griptape.drivers import AnthropicPromptDriver -from griptape.common import PromptStack, TextDeltaMessageContent, ActionCallDeltaMessageContent, ToolAction -from griptape.artifacts import TextArtifact, ActionArtifact, ImageArtifact, ListArtifact from unittest.mock import Mock + import pytest +from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.artifacts.error_artifact import ErrorArtifact +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction +from griptape.drivers import AnthropicPromptDriver from tests.mocks.mock_tool.tool import MockTool diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 9e56b39bd7..aa4e991132 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -1,8 +1,10 @@ -import pytest -from griptape.artifacts import TextArtifact, ActionArtifact from unittest.mock import Mock + +import pytest + +from griptape.artifacts import ActionArtifact, TextArtifact +from griptape.common import ActionCallDeltaMessageContent, TextDeltaMessageContent from griptape.drivers import AzureOpenAiChatPromptDriver -from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent from tests.unit.drivers.prompt.test_openai_chat_prompt_driver import TestOpenAiChatPromptDriverFixtureMixin diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 6eb000e1f0..3c2bb333e4 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,11 +1,11 @@ +from griptape.artifacts import ErrorArtifact, TextArtifact +from griptape.common import PromptStack from griptape.common.prompt_stack.messages.message import Message from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.common import PromptStack -from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver -from griptape.artifacts import ErrorArtifact, TextArtifact -from griptape.tasks import PromptTask from griptape.structures import Pipeline +from griptape.tasks import PromptTask +from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestBasePromptDriver: diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 167c08b347..197e35e840 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -5,9 +5,8 @@ from griptape.artifacts.action_artifact import ActionArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact -from griptape.common import PromptStack, ToolAction +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import CoherePromptDriver -from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent from tests.mocks.mock_tool.tool import MockTool diff --git a/tests/unit/drivers/prompt/test_dummy_prompt_driver.py b/tests/unit/drivers/prompt/test_dummy_prompt_driver.py index d569b55af4..8b2a966c0f 100644 --- a/tests/unit/drivers/prompt/test_dummy_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_dummy_prompt_driver.py @@ -1,6 +1,6 @@ -from griptape.drivers import DummyPromptDriver import pytest +from griptape.drivers import DummyPromptDriver from griptape.exceptions import DummyException diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index b1a72d10d4..5770f153f7 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -1,15 +1,15 @@ -from google.generativeai.types import ContentDict, GenerationConfig +from unittest.mock import Mock + +import pytest from google.generativeai.protos import FunctionCall, FunctionResponse, Part -from griptape.artifacts import TextArtifact, ImageArtifact, ActionArtifact +from google.generativeai.types import ContentDict, GenerationConfig +from google.protobuf.json_format import MessageToDict + +from griptape.artifacts import ActionArtifact, ImageArtifact, TextArtifact from griptape.artifacts.list_artifact import ListArtifact -from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent, ToolAction +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import GooglePromptDriver -from griptape.common import PromptStack -from unittest.mock import Mock from tests.mocks.mock_tool.tool import MockTool -from google.protobuf.json_format import MessageToDict - -import pytest class TestGooglePromptDriver: diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index ec7ea73f8f..96e9cf1a11 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,7 +1,8 @@ -from griptape.drivers import HuggingFaceHubPromptDriver -from griptape.common import PromptStack, TextDeltaMessageContent import pytest +from griptape.common import PromptStack, TextDeltaMessageContent +from griptape.drivers import HuggingFaceHubPromptDriver + class TestHuggingFaceHubPromptDriver: @pytest.fixture diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index a63d697fb5..90e3e459db 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -1,7 +1,8 @@ -from griptape.drivers import HuggingFacePipelinePromptDriver -from griptape.common import PromptStack import pytest +from griptape.common import PromptStack +from griptape.drivers import HuggingFacePipelinePromptDriver + class TestHuggingFacePipelinePromptDriver: @pytest.fixture(autouse=True) diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index a247a77ab7..7c02c860f9 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,8 +1,9 @@ +import pytest + +from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact +from griptape.common import PromptStack from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent from griptape.drivers import OllamaPromptDriver -from griptape.common import PromptStack -from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact -import pytest class TestOllamaPromptDriver: diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 59772ff23b..7bc778a329 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,12 +1,13 @@ -from griptape.artifacts import ImageArtifact, ListArtifact -from griptape.artifacts import TextArtifact, ActionArtifact +from unittest.mock import Mock + +import pytest + +from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import OpenAiChatPromptDriver -from griptape.common import PromptStack, TextDeltaMessageContent, ActionCallDeltaMessageContent, ToolAction from griptape.tokenizers import OpenAiTokenizer -from unittest.mock import Mock from tests.mocks.mock_tokenizer import MockTokenizer from tests.mocks.mock_tool.tool import MockTool -import pytest class TestOpenAiChatPromptDriverFixtureMixin: @@ -283,7 +284,7 @@ def __init__( remaining_tokens=234, limit_requests=345, limit_tokens=456, - ): + ) -> None: self.reset_requests_in = reset_requests_in self.reset_requests_in_unit = reset_requests_in_unit self.reset_tokens_in = reset_tokens_in diff --git a/tests/unit/drivers/rerank/test_cohere_rerank_driver.py b/tests/unit/drivers/rerank/test_cohere_rerank_driver.py index 952546e5a9..c5a6eb4c21 100644 --- a/tests/unit/drivers/rerank/test_cohere_rerank_driver.py +++ b/tests/unit/drivers/rerank/test_cohere_rerank_driver.py @@ -1,5 +1,6 @@ import pytest -from cohere import RerankResponseResultsItemDocument, RerankResponseResultsItem +from cohere import RerankResponseResultsItem, RerankResponseResultsItemDocument + from griptape.artifacts import TextArtifact from griptape.drivers import CohereRerankDriver diff --git a/tests/unit/drivers/sql/test_amazon_redshift_sql_driver.py b/tests/unit/drivers/sql/test_amazon_redshift_sql_driver.py index d67e1c5572..24cf8fd3c4 100644 --- a/tests/unit/drivers/sql/test_amazon_redshift_sql_driver.py +++ b/tests/unit/drivers/sql/test_amazon_redshift_sql_driver.py @@ -1,7 +1,8 @@ -import pytest import boto3 +import pytest from botocore.stub import Stubber -from griptape.drivers import BaseSqlDriver, AmazonRedshiftSqlDriver + +from griptape.drivers import AmazonRedshiftSqlDriver, BaseSqlDriver class TestAmazonRedshiftSqlDriver: diff --git a/tests/unit/drivers/sql/test_snowflake_sql_driver.py b/tests/unit/drivers/sql/test_snowflake_sql_driver.py index 91403a4671..818eed7ec0 100644 --- a/tests/unit/drivers/sql/test_snowflake_sql_driver.py +++ b/tests/unit/drivers/sql/test_snowflake_sql_driver.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from unittest import mock + import pytest -from sqlalchemy import create_engine from snowflake.connector import SnowflakeConnection +from sqlalchemy import create_engine + from griptape.drivers import BaseSqlDriver, SnowflakeSqlDriver diff --git a/tests/unit/drivers/sql/test_sql_driver.py b/tests/unit/drivers/sql/test_sql_driver.py index d4caf6f509..46e07752c8 100644 --- a/tests/unit/drivers/sql/test_sql_driver.py +++ b/tests/unit/drivers/sql/test_sql_driver.py @@ -1,4 +1,5 @@ import pytest + from griptape.drivers import SqlDriver diff --git a/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py index b056241ec6..ffc6374a3c 100644 --- a/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py @@ -1,5 +1,6 @@ import pytest -from griptape.artifacts import TextArtifact, InfoArtifact + +from griptape.artifacts import InfoArtifact, TextArtifact class TestGriptapeCloudStructureRunDriver: diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index cb7b3058e5..339aa5d89b 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -1,11 +1,9 @@ import os -import pytest -from griptape.artifacts.text_artifact import TextArtifact + +from griptape.drivers import LocalStructureRunDriver +from griptape.structures import Agent, Pipeline from griptape.tasks import StructureRunTask -from griptape.structures import Agent from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.drivers import LocalStructureRunDriver -from griptape.structures import Pipeline class TestLocalStructureRunDriver: @@ -22,7 +20,7 @@ def test_run(self): def test_run_with_env(self): pipeline = Pipeline() - agent = Agent(prompt_driver=MockPromptDriver(mock_output=lambda _: os.environ["key"])) + agent = Agent(prompt_driver=MockPromptDriver(mock_output=lambda _: os.environ["KEY"])) driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"key": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py b/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py index 2d90bc2f59..68786bd2e3 100644 --- a/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py +++ b/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import Mock + +import pytest + from griptape.drivers import ElevenLabsTextToSpeechDriver diff --git a/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py b/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py index 57c5a5e2e0..90e7d95240 100644 --- a/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py +++ b/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import Mock +import pytest + from griptape.artifacts import AudioArtifact from griptape.drivers import OpenAiAudioTranscriptionDriver diff --git a/tests/unit/drivers/vector/test_amazon_opensearch_vector_store_driver.py b/tests/unit/drivers/vector/test_amazon_opensearch_vector_store_driver.py index b66cc057e3..e0ae3329fe 100644 --- a/tests/unit/drivers/vector/test_amazon_opensearch_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_amazon_opensearch_vector_store_driver.py @@ -1,8 +1,10 @@ +from unittest.mock import Mock, create_autospec, patch + +import boto3 +import numpy as np import pytest -from unittest.mock import patch, Mock, create_autospec + from griptape.drivers import AmazonOpenSearchVectorStoreDriver -import numpy as np -import boto3 class TestAmazonOpenSearchVectorStoreDriver: diff --git a/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py b/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py index b684869142..03f9a89bf4 100644 --- a/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py @@ -1,7 +1,6 @@ -import pytest import mongomock -from unittest.mock import patch -from pymongo.errors import OperationFailure +import pytest + from griptape.artifacts import TextArtifact from griptape.drivers import AzureMongoDbVectorStoreDriver, BaseVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py b/tests/unit/drivers/vector/test_base_local_vector_store_driver.py index 6cd1763f64..674c3ffb9a 100644 --- a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_base_local_vector_store_driver.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod -import pytest from unittest.mock import patch + +import pytest + from griptape.artifacts import TextArtifact from griptape.artifacts.csv_row_artifact import CsvRowArtifact diff --git a/tests/unit/drivers/vector/test_dummy_vector_store_driver.py b/tests/unit/drivers/vector/test_dummy_vector_store_driver.py index abb7f3b38d..fcc144f8d1 100644 --- a/tests/unit/drivers/vector/test_dummy_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_dummy_vector_store_driver.py @@ -1,4 +1,5 @@ import pytest + from griptape.drivers import DummyVectorStoreDriver from griptape.exceptions import DummyException diff --git a/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py b/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py index 957edebb86..19d0f5cdca 100644 --- a/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py @@ -1,5 +1,7 @@ -import pytest import uuid + +import pytest + from griptape.drivers import GriptapeCloudKnowledgeBaseVectorStoreDriver diff --git a/tests/unit/drivers/vector/test_local_vector_store_driver.py b/tests/unit/drivers/vector/test_local_vector_store_driver.py index 937f14ece3..3612364e4d 100644 --- a/tests/unit/drivers/vector/test_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_local_vector_store_driver.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py index f42906035d..a0e0dc072a 100644 --- a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py @@ -1,7 +1,9 @@ from collections import namedtuple + import pytest -from griptape.drivers import MarqoVectorStoreDriver + from griptape.artifacts import TextArtifact +from griptape.drivers import MarqoVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py b/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py index 5b9aeed06e..39e59f1e2c 100644 --- a/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py @@ -1,9 +1,8 @@ -import pytest import mongomock -from unittest.mock import patch -from pymongo.errors import OperationFailure +import pytest + from griptape.artifacts import TextArtifact -from griptape.drivers import MongoDbAtlasVectorStoreDriver, BaseVectorStoreDriver +from griptape.drivers import BaseVectorStoreDriver, MongoDbAtlasVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py b/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py index d2c967cafc..a010ae72cf 100644 --- a/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py @@ -1,7 +1,9 @@ +from unittest.mock import Mock, create_autospec, patch + +import numpy as np import pytest -from unittest.mock import patch, Mock, create_autospec + from griptape.drivers import OpenSearchVectorStoreDriver -import numpy as np class TestOpenSearchVectorStoreDriver: diff --git a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py index 8f6773fc15..0c0f592e8f 100644 --- a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py @@ -1,6 +1,8 @@ import os import tempfile + import pytest + from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py b/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py index 3854ea4f1d..be0b954e2b 100644 --- a/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py @@ -1,10 +1,12 @@ -from typing import Any import uuid -import pytest +from typing import Any from unittest.mock import MagicMock, Mock + +import pytest +from sqlalchemy import create_engine + from griptape.drivers import PgVectorVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from sqlalchemy import create_engine class TestPgVectorVectorStoreDriver: @@ -30,14 +32,14 @@ def mock_session(self, mocker): def test_initialize_requires_engine_or_connection_string(self, embedding_driver): with pytest.raises(ValueError): - driver = PgVectorVectorStoreDriver(embedding_driver=embedding_driver, table_name=self.table_name) + PgVectorVectorStoreDriver(embedding_driver=embedding_driver, table_name=self.table_name) def test_initialize_accepts_engine(self, embedding_driver): engine: Any = create_engine(self.connection_string) - driver = PgVectorVectorStoreDriver(embedding_driver=embedding_driver, engine=engine, table_name=self.table_name) + PgVectorVectorStoreDriver(embedding_driver=embedding_driver, engine=engine, table_name=self.table_name) def test_initialize_accepts_connection_string(self, embedding_driver): - driver = PgVectorVectorStoreDriver( + PgVectorVectorStoreDriver( embedding_driver=embedding_driver, connection_string=self.connection_string, table_name=self.table_name ) diff --git a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py index 7aea4d4116..58bfde062e 100644 --- a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py +++ b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py @@ -1,17 +1,11 @@ import pytest -from griptape import utils from griptape.artifacts import TextArtifact from griptape.drivers import PineconeVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestPineconeVectorStorageDriver: - """ - This should really be under `unit` but the Pinecone client results - in tests hanging on GitHub. - """ - @pytest.fixture(autouse=True) def mock_pinecone(self, mocker): # Create a fake response diff --git a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py index c999af8b93..8ba6dc2307 100644 --- a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py @@ -1,9 +1,11 @@ -import pytest +import uuid from unittest.mock import MagicMock, patch + +import pytest + from griptape.drivers import QdrantVectorStoreDriver -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from griptape.utils import import_optional_dependency -import uuid +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestQdrantVectorStoreDriver: diff --git a/tests/unit/drivers/vector/test_redis_vector_store_driver.py b/tests/unit/drivers/vector/test_redis_vector_store_driver.py index 18759a2d73..9db5f116cb 100644 --- a/tests/unit/drivers/vector/test_redis_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_redis_vector_store_driver.py @@ -1,7 +1,9 @@ from unittest.mock import MagicMock + import pytest -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver + from griptape.drivers import RedisVectorStoreDriver +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestRedisVectorStorageDriver: diff --git a/tests/unit/drivers/web_scraper/test_markdownify_web_scraper_driver.py b/tests/unit/drivers/web_scraper/test_markdownify_web_scraper_driver.py index 33500839fb..c9dbae4a8f 100644 --- a/tests/unit/drivers/web_scraper/test_markdownify_web_scraper_driver.py +++ b/tests/unit/drivers/web_scraper/test_markdownify_web_scraper_driver.py @@ -1,4 +1,5 @@ from textwrap import dedent + import pytest from griptape.drivers.web_scraper.markdownify_web_scraper_driver import MarkdownifyWebScraperDriver @@ -21,7 +22,7 @@ def web_scraper(self): def test_scrape_url(self, web_scraper): artifact = web_scraper.scrape_url("https://example.com/") - assert "[foobar](foobar.com)" == artifact.value + assert artifact.value == "[foobar](foobar.com)" def test_scrape_url_whitespace(self, web_scraper, mock_content): mock_content.return_value = dedent( @@ -46,35 +47,35 @@ def test_scrape_url_whitespace(self, web_scraper, mock_content): """ ) artifact = web_scraper.scrape_url("https://example.com/") - assert "foo\n---\n\n* bar:\n + baz\n + baz\n\n + baz" == artifact.value + assert artifact.value == "foo\n---\n\n* bar:\n + baz\n + baz\n\n + baz" def test_scrape_url_no_excludes(self): web_scraper = MarkdownifyWebScraperDriver(exclude_tags=[], exclude_classes=[], exclude_ids=[]) artifact = web_scraper.scrape_url("https://example.com/") - assert "[foobar](foobar.com)" == artifact.value + assert artifact.value == "[foobar](foobar.com)" def test_scrape_url_exclude_links(self): web_scraper = MarkdownifyWebScraperDriver(include_links=False) artifact = web_scraper.scrape_url("https://example.com/") - assert "foobar" == artifact.value + assert artifact.value == "foobar" def test_scrape_url_exclude_tags(self, mock_content): mock_content.return_value = "powwow" web_scraper = MarkdownifyWebScraperDriver(exclude_tags=["wow"], exclude_classes=[], exclude_ids=[]) artifact = web_scraper.scrape_url("https://example.com/") - assert "pow" == artifact.value + assert artifact.value == "pow" def test_scrape_url_exclude_classes(self, mock_content): mock_content.return_value = 'powwow' web_scraper = MarkdownifyWebScraperDriver(exclude_tags=[], exclude_classes=["now"], exclude_ids=[]) artifact = web_scraper.scrape_url("https://example.com/") - assert "pow" == artifact.value + assert artifact.value == "pow" def test_scrape_url_exclude_ids(self, mock_content): mock_content.return_value = 'powwow' web_scraper = MarkdownifyWebScraperDriver(exclude_tags=[], exclude_classes=[], exclude_ids=["cow"]) artifact = web_scraper.scrape_url("https://example.com/") - assert "pow" == artifact.value + assert artifact.value == "pow" def test_scrape_url_raises_on_empty_string_from_playwright(self, web_scraper, mock_content): mock_content.return_value = "" diff --git a/tests/unit/drivers/web_scraper/test_proxy_web_scraper_driver.py b/tests/unit/drivers/web_scraper/test_proxy_web_scraper_driver.py index 569800c6a0..0cc61d39db 100644 --- a/tests/unit/drivers/web_scraper/test_proxy_web_scraper_driver.py +++ b/tests/unit/drivers/web_scraper/test_proxy_web_scraper_driver.py @@ -1,7 +1,7 @@ import pytest -from griptape.drivers import ProxyWebScraperDriver from griptape.artifacts import TextArtifact +from griptape.drivers import ProxyWebScraperDriver class TestProxyWebScraperDriver: @@ -26,7 +26,7 @@ def test_scrape_url(self, web_scraper, mock_client): output = web_scraper.scrape_url("https://example.com/") mock_client.assert_called_with("https://example.com/", proxies=web_scraper.proxies, test_param="test_param") assert isinstance(output, TextArtifact) - assert "test_scrape" == output.value + assert output.value == "test_scrape" def test_scrape_url_error(self, web_scraper, mock_client_error): with pytest.raises(Exception, match="test_error"): diff --git a/tests/unit/drivers/web_search/test_duck_duck_go_web_search_driver.py b/tests/unit/drivers/web_search/test_duck_duck_go_web_search_driver.py index fcacc274c8..f3e835af1a 100644 --- a/tests/unit/drivers/web_search/test_duck_duck_go_web_search_driver.py +++ b/tests/unit/drivers/web_search/test_duck_duck_go_web_search_driver.py @@ -1,7 +1,9 @@ -import pytest import json -from griptape.drivers import DuckDuckGoWebSearchDriver + +import pytest + from griptape.artifacts import ListArtifact +from griptape.drivers import DuckDuckGoWebSearchDriver class TestDuckDuckGoWebSearchDriver: diff --git a/tests/unit/drivers/web_search/test_google_web_search_driver.py b/tests/unit/drivers/web_search/test_google_web_search_driver.py index 9ecb92f468..275d246ec1 100644 --- a/tests/unit/drivers/web_search/test_google_web_search_driver.py +++ b/tests/unit/drivers/web_search/test_google_web_search_driver.py @@ -1,10 +1,11 @@ -from pytest import fixture -import pytest -from griptape.drivers import GoogleWebSearchDriver -from griptape.artifacts import ErrorArtifact import json + +import pytest +from pytest import fixture from pytest_mock import MockerFixture +from griptape.drivers import GoogleWebSearchDriver + class TestGoogleWebSearchDriver: @fixture diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index ded595d59d..01125d1b1a 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -1,4 +1,5 @@ import pytest + from griptape.engines import CsvExtractionEngine from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/unit/engines/extraction/test_json_extraction_engine.py b/tests/unit/engines/extraction/test_json_extraction_engine.py index 797c5de7ab..bdf84d708b 100644 --- a/tests/unit/engines/extraction/test_json_extraction_engine.py +++ b/tests/unit/engines/extraction/test_json_extraction_engine.py @@ -1,5 +1,6 @@ import pytest from schema import Schema + from griptape.artifacts import ErrorArtifact from griptape.engines import JsonExtractionEngine from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py index e5ba50a5be..ab4c0ba6a8 100644 --- a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.common import Reference from griptape.engines.rag import RagContext diff --git a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py index f262d6d062..31d095e614 100644 --- a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import PromptResponseRagModule diff --git a/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py index 2750257f4c..bc85cf2664 100644 --- a/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py @@ -1,6 +1,6 @@ from griptape.engines.rag import RagContext from griptape.engines.rag.modules import RulesetsBeforeResponseRagModule -from griptape.rules import Ruleset, Rule +from griptape.rules import Rule, Ruleset class TestRulesetsBeforeResponseRagModule: diff --git a/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py index 6488d650e6..6ad4853a2c 100644 --- a/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import TextChunksResponseRagModule diff --git a/tests/unit/engines/rag/modules/retrieval/test_text_chunks_rerank_rag_module.py b/tests/unit/engines/rag/modules/retrieval/test_text_chunks_rerank_rag_module.py index fa3bfecb24..cfa493d5a9 100644 --- a/tests/unit/engines/rag/modules/retrieval/test_text_chunks_rerank_rag_module.py +++ b/tests/unit/engines/rag/modules/retrieval/test_text_chunks_rerank_rag_module.py @@ -1,5 +1,6 @@ import pytest from cohere import RerankResponseResultsItem, RerankResponseResultsItemDocument + from griptape.artifacts import TextArtifact from griptape.drivers import CohereRerankDriver from griptape.engines.rag import RagContext diff --git a/tests/unit/engines/rag/test_rag_engine.py b/tests/unit/engines/rag/test_rag_engine.py index a39c0c2f11..34be3ebed5 100644 --- a/tests/unit/engines/rag/test_rag_engine.py +++ b/tests/unit/engines/rag/test_rag_engine.py @@ -1,8 +1,9 @@ import pytest + from griptape.drivers import LocalVectorStoreDriver -from griptape.engines.rag import RagEngine, RagContext -from griptape.engines.rag.modules import VectorStoreRetrievalRagModule, PromptResponseRagModule -from griptape.engines.rag.stages import RetrievalRagStage, ResponseRagStage +from griptape.engines.rag import RagContext, RagEngine +from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule +from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/unit/engines/summary/test_prompt_summary_engine.py b/tests/unit/engines/summary/test_prompt_summary_engine.py index 34c6e3563d..f58319fdcc 100644 --- a/tests/unit/engines/summary/test_prompt_summary_engine.py +++ b/tests/unit/engines/summary/test_prompt_summary_engine.py @@ -1,9 +1,11 @@ +import os + import pytest -from griptape.artifacts import TextArtifact, ListArtifact -from griptape.engines import PromptSummaryEngine + +from griptape.artifacts import ListArtifact, TextArtifact from griptape.common import PromptStack +from griptape.engines import PromptSummaryEngine from tests.mocks.mock_prompt_driver import MockPromptDriver -import os class TestPromptSummaryEngine: diff --git a/tests/unit/events/test_base_event.py b/tests/unit/events/test_base_event.py index 595c90f1fb..778f7c096a 100644 --- a/tests/unit/events/test_base_event.py +++ b/tests/unit/events/test_base_event.py @@ -1,17 +1,19 @@ import time + import pytest + from griptape.artifacts.base_artifact import BaseArtifact from griptape.events import ( - StartPromptEvent, + BaseEvent, + CompletionChunkEvent, + FinishActionsSubtaskEvent, FinishPromptEvent, - StartTaskEvent, + FinishStructureRunEvent, FinishTaskEvent, StartActionsSubtaskEvent, - FinishActionsSubtaskEvent, - CompletionChunkEvent, + StartPromptEvent, StartStructureRunEvent, - FinishStructureRunEvent, - BaseEvent, + StartTaskEvent, ) from tests.mocks.mock_event import MockEvent diff --git a/tests/unit/events/test_completion_chunk_event.py b/tests/unit/events/test_completion_chunk_event.py index aa9618a53c..71b94919ce 100644 --- a/tests/unit/events/test_completion_chunk_event.py +++ b/tests/unit/events/test_completion_chunk_event.py @@ -1,4 +1,5 @@ import pytest + from griptape.events import CompletionChunkEvent diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 2f32837e08..b71d88ca6d 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -1,23 +1,25 @@ from unittest.mock import Mock + import pytest -from griptape.events.base_event import BaseEvent -from griptape.structures import Pipeline -from griptape.tasks import ToolkitTask, ActionsSubtask + from griptape.events import ( - StartTaskEvent, + CompletionChunkEvent, + EventListener, + FinishActionsSubtaskEvent, + FinishPromptEvent, + FinishStructureRunEvent, FinishTaskEvent, StartActionsSubtaskEvent, - FinishActionsSubtaskEvent, StartPromptEvent, - FinishPromptEvent, StartStructureRunEvent, - FinishStructureRunEvent, - CompletionChunkEvent, - EventListener, + StartTaskEvent, ) +from griptape.events.base_event import BaseEvent +from griptape.structures import Pipeline +from griptape.tasks import ActionsSubtask, ToolkitTask +from tests.mocks.mock_event import MockEvent from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool -from tests.mocks.mock_event import MockEvent class TestEventListener: @@ -107,7 +109,7 @@ def test_publish_event(self): mock_event_listener_driver = Mock() mock_event_listener_driver.try_publish_event_payload.return_value = None - def event_handler(_: BaseEvent): + def event_handler(_: BaseEvent) -> None: return None mock_event = MockEvent() diff --git a/tests/unit/events/test_finish_actions_subtask_event.py b/tests/unit/events/test_finish_actions_subtask_event.py index 14d7cbdde0..e38ac2d5d5 100644 --- a/tests/unit/events/test_finish_actions_subtask_event.py +++ b/tests/unit/events/test_finish_actions_subtask_event.py @@ -1,4 +1,5 @@ import pytest + from griptape.events import FinishActionsSubtaskEvent from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask diff --git a/tests/unit/events/test_finish_prompt_event.py b/tests/unit/events/test_finish_prompt_event.py index 7443fce0c5..2161c0df7d 100644 --- a/tests/unit/events/test_finish_prompt_event.py +++ b/tests/unit/events/test_finish_prompt_event.py @@ -1,4 +1,5 @@ import pytest + from griptape.events import FinishPromptEvent diff --git a/tests/unit/events/test_finish_task_event.py b/tests/unit/events/test_finish_task_event.py index 40e71c9ea3..d1f4e25c90 100644 --- a/tests/unit/events/test_finish_task_event.py +++ b/tests/unit/events/test_finish_task_event.py @@ -1,6 +1,7 @@ import pytest -from griptape.structures import Agent + from griptape.events import FinishTaskEvent +from griptape.structures import Agent from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/unit/events/test_start_actions_subtask_event.py b/tests/unit/events/test_start_actions_subtask_event.py index d8b63de224..c468ae1161 100644 --- a/tests/unit/events/test_start_actions_subtask_event.py +++ b/tests/unit/events/test_start_actions_subtask_event.py @@ -1,4 +1,5 @@ import pytest + from griptape.events import StartActionsSubtaskEvent from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask diff --git a/tests/unit/events/test_start_prompt_event.py b/tests/unit/events/test_start_prompt_event.py index 4ef08ec5cc..038d6646bb 100644 --- a/tests/unit/events/test_start_prompt_event.py +++ b/tests/unit/events/test_start_prompt_event.py @@ -1,6 +1,7 @@ import pytest -from griptape.events import StartPromptEvent + from griptape.common import PromptStack +from griptape.events import StartPromptEvent class TestStartPromptEvent: diff --git a/tests/unit/events/test_start_structure_run_event.py b/tests/unit/events/test_start_structure_run_event.py index c2f1b923db..3b67b0069e 100644 --- a/tests/unit/events/test_start_structure_run_event.py +++ b/tests/unit/events/test_start_structure_run_event.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts.text_artifact import TextArtifact from griptape.events import StartStructureRunEvent diff --git a/tests/unit/events/test_start_task_event.py b/tests/unit/events/test_start_task_event.py index f4d243421b..be1d3d8f09 100644 --- a/tests/unit/events/test_start_task_event.py +++ b/tests/unit/events/test_start_task_event.py @@ -1,4 +1,5 @@ import pytest + from griptape.events import StartTaskEvent from griptape.structures import Agent from griptape.tasks import PromptTask diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py index b7946da03c..605e8d1b37 100644 --- a/tests/unit/loaders/test_audio_loader.py +++ b/tests/unit/loaders/test_audio_loader.py @@ -32,8 +32,8 @@ def test_load_collection(self, create_source, loader): assert len(collection) == len(resource_paths) - keys = {loader.to_key(source) for source in sources} - for key in collection.keys(): + {loader.to_key(source) for source in sources} + for key in collection: artifact = collection[key] assert isinstance(artifact, AudioArtifact) assert artifact.name.endswith(".wav") diff --git a/tests/unit/loaders/test_blob_loader.py b/tests/unit/loaders/test_blob_loader.py index f2b4627268..4812e669c8 100644 --- a/tests/unit/loaders/test_blob_loader.py +++ b/tests/unit/loaders/test_blob_loader.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import BlobArtifact from griptape.loaders import BlobLoader diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index 579146ba24..89721077e3 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -1,4 +1,5 @@ import pytest + from griptape.loaders.csv_loader import CsvLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/loaders/test_dataframe_loader.py b/tests/unit/loaders/test_dataframe_loader.py index 5365555589..51d878ec9b 100644 --- a/tests/unit/loaders/test_dataframe_loader.py +++ b/tests/unit/loaders/test_dataframe_loader.py @@ -1,7 +1,8 @@ import os + import pandas as pd import pytest -from griptape import utils + from griptape.loaders.dataframe_loader import DataFrameLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/loaders/test_email_loader.py b/tests/unit/loaders/test_email_loader.py index ef69b83486..ce99075bf6 100644 --- a/tests/unit/loaders/test_email_loader.py +++ b/tests/unit/loaders/test_email_loader.py @@ -1,12 +1,14 @@ from __future__ import annotations +import email from email import message -from griptape.artifacts import ErrorArtifact, ListArtifact -from griptape.loaders import EmailLoader from typing import Optional -import email + import pytest +from griptape.artifacts import ErrorArtifact, ListArtifact +from griptape.loaders import EmailLoader + class TestEmailLoader: @pytest.fixture(autouse=True) diff --git a/tests/unit/loaders/test_pdf_loader.py b/tests/unit/loaders/test_pdf_loader.py index 0ab78b8b63..337a010b2d 100644 --- a/tests/unit/loaders/test_pdf_loader.py +++ b/tests/unit/loaders/test_pdf_loader.py @@ -1,8 +1,5 @@ -import os -from pathlib import Path -from typing import IO import pytest -from griptape import utils + from griptape.loaders import PdfLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/loaders/test_sql_loader.py b/tests/unit/loaders/test_sql_loader.py index 8541e4fb8d..dcbf84a534 100644 --- a/tests/unit/loaders/test_sql_loader.py +++ b/tests/unit/loaders/test_sql_loader.py @@ -1,5 +1,6 @@ import pytest from sqlalchemy.pool import StaticPool + from griptape.drivers import SqlDriver from griptape.loaders import SqlLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/loaders/test_text_loader.py b/tests/unit/loaders/test_text_loader.py index 0c59df12fe..07527f9e62 100644 --- a/tests/unit/loaders/test_text_loader.py +++ b/tests/unit/loaders/test_text_loader.py @@ -1,4 +1,5 @@ import pytest + from griptape.loaders.text_loader import TextLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/loaders/test_web_loader.py b/tests/unit/loaders/test_web_loader.py index e265735397..57c68bc275 100644 --- a/tests/unit/loaders/test_web_loader.py +++ b/tests/unit/loaders/test_web_loader.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts.error_artifact import ErrorArtifact from griptape.loaders import WebLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/memory/meta/test_action_subtask_meta_entry.py b/tests/unit/memory/meta/test_action_subtask_meta_entry.py index f5da6ee017..c406677133 100644 --- a/tests/unit/memory/meta/test_action_subtask_meta_entry.py +++ b/tests/unit/memory/meta/test_action_subtask_meta_entry.py @@ -1,4 +1,5 @@ import pytest + from griptape.memory.meta import ActionSubtaskMetaEntry diff --git a/tests/unit/memory/meta/test_meta_memory.py b/tests/unit/memory/meta/test_meta_memory.py index bbdacf5b4a..73f552d0d6 100644 --- a/tests/unit/memory/meta/test_meta_memory.py +++ b/tests/unit/memory/meta/test_meta_memory.py @@ -1,5 +1,6 @@ import pytest -from griptape.memory.meta import MetaMemory, ActionSubtaskMetaEntry + +from griptape.memory.meta import ActionSubtaskMetaEntry, MetaMemory class TestMetaMemory: diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 613d4b1fea..2ffd7b8cbd 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -1,12 +1,12 @@ import json -from griptape.structures import Agent + +from griptape.artifacts import TextArtifact from griptape.common import PromptStack -from griptape.memory.structure import ConversationMemory, Run, BaseConversationMemory -from griptape.structures import Pipeline +from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run +from griptape.structures import Agent, Pipeline +from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tokenizer import MockTokenizer -from griptape.tasks import PromptTask -from griptape.artifacts import TextArtifact class TestConversationMemory: diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index e625ac6c6b..4396c7b23d 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -1,9 +1,8 @@ import json - +from griptape.artifacts import TextArtifact from griptape.memory.structure import Run, SummaryConversationMemory from griptape.structures import Pipeline -from griptape.artifacts import TextArtifact from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_structure_config import MockStructureConfig diff --git a/tests/unit/memory/tool/storage/test_blob_artifact_storage.py b/tests/unit/memory/tool/storage/test_blob_artifact_storage.py index dd42b6bc2d..0b28870e10 100644 --- a/tests/unit/memory/tool/storage/test_blob_artifact_storage.py +++ b/tests/unit/memory/tool/storage/test_blob_artifact_storage.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import BlobArtifact, TextArtifact from griptape.memory.task.storage import BlobArtifactStorage diff --git a/tests/unit/memory/tool/storage/test_text_artifact_storage.py b/tests/unit/memory/tool/storage/test_text_artifact_storage.py index 706a80f0cc..a3814145c1 100644 --- a/tests/unit/memory/tool/storage/test_text_artifact_storage.py +++ b/tests/unit/memory/tool/storage/test_text_artifact_storage.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import BlobArtifact, TextArtifact from tests.utils import defaults diff --git a/tests/unit/memory/tool/test_task_memory.py b/tests/unit/memory/tool/test_task_memory.py index fc1cf75c79..279c3ca3f7 100644 --- a/tests/unit/memory/tool/test_task_memory.py +++ b/tests/unit/memory/tool/test_task_memory.py @@ -1,6 +1,6 @@ import pytest -from griptape.artifacts import CsvRowArtifact, BlobArtifact, ErrorArtifact, InfoArtifact -from griptape.artifacts import TextArtifact, ListArtifact + +from griptape.artifacts import BlobArtifact, CsvRowArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.memory import TaskMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage from griptape.structures import Agent diff --git a/tests/unit/mixins/test_activity_mixin.py b/tests/unit/mixins/test_activity_mixin.py index 45db7cc7f6..5dff5aa0d8 100644 --- a/tests/unit/mixins/test_activity_mixin.py +++ b/tests/unit/mixins/test_activity_mixin.py @@ -1,5 +1,6 @@ import pytest -from schema import Schema, Literal, Optional +from schema import Literal, Optional, Schema + from tests.mocks.mock_tool.tool import MockTool diff --git a/tests/unit/mixins/test_image_artifact_file_output_mixin.py b/tests/unit/mixins/test_image_artifact_file_output_mixin.py index 69a2f1d71b..03c44e081b 100644 --- a/tests/unit/mixins/test_image_artifact_file_output_mixin.py +++ b/tests/unit/mixins/test_image_artifact_file_output_mixin.py @@ -19,7 +19,7 @@ def test_output_file(self): artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png") class Test(BlobArtifactFileOutputMixin): - def run(self): + def run(self) -> None: self._write_to_file(artifact) outfile = os.path.join(tempfile.gettempdir(), artifact.name) @@ -34,7 +34,7 @@ def test_output_dir(self): artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png") class Test(BlobArtifactFileOutputMixin): - def run(self): + def run(self) -> None: self._write_to_file(artifact) outdir = tempfile.gettempdir() diff --git a/tests/unit/mixins/test_seriliazable_mixin.py b/tests/unit/mixins/test_seriliazable_mixin.py index 1704000e33..afb3d1eb45 100644 --- a/tests/unit/mixins/test_seriliazable_mixin.py +++ b/tests/unit/mixins/test_seriliazable_mixin.py @@ -1,11 +1,13 @@ import json + import pytest + +from griptape.artifacts import BaseArtifact, TextArtifact from griptape.drivers import OpenAiChatPromptDriver -from griptape.memory.structure import ConversationMemory from griptape.memory import TaskMemory -from tests.mocks.mock_serializable import MockSerializable +from griptape.memory.structure import ConversationMemory from griptape.schemas import BaseSchema -from griptape.artifacts import BaseArtifact, TextArtifact +from tests.mocks.mock_serializable import MockSerializable class TestSerializableMixin: diff --git a/tests/unit/schemas/test_base_schema.py b/tests/unit/schemas/test_base_schema.py index fcbd08c7f4..f3a3f0c1f9 100644 --- a/tests/unit/schemas/test_base_schema.py +++ b/tests/unit/schemas/test_base_schema.py @@ -1,13 +1,16 @@ from __future__ import annotations + from datetime import datetime +from typing import Literal, Optional, Union + import pytest -from typing import Union, Optional, Literal from marshmallow import fields + from griptape.artifacts import BaseArtifact, TextArtifact +from griptape.loaders import TextLoader from griptape.schemas import PolymorphicSchema -from griptape.schemas.bytes_field import Bytes from griptape.schemas.base_schema import BaseSchema -from griptape.loaders import TextLoader +from griptape.schemas.bytes_field import Bytes from tests.mocks.mock_serializable import MockSerializable @@ -62,8 +65,8 @@ def test_get_field_type_info(self): assert BaseSchema._get_field_type_info(list) == (list, (), False) - assert BaseSchema._get_field_type_info(Literal["foo"]) == (str, (), False) # pyright: ignore - assert BaseSchema._get_field_type_info(Literal[5]) == (int, (), False) # pyright: ignore + assert BaseSchema._get_field_type_info(Literal["foo"]) == (str, (), False) # pyright: ignore[reportArgumentType] + assert BaseSchema._get_field_type_info(Literal[5]) == (int, (), False) # pyright: ignore[reportArgumentType] def test_is_list_sequence(self): assert BaseSchema.is_list_sequence(list) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index baceac8253..414861e29e 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -1,15 +1,15 @@ import pytest -from griptape.memory.structure import ConversationMemory + +from griptape.engines import PromptSummaryEngine from griptape.memory import TaskMemory +from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Agent -from griptape.tasks import PromptTask, BaseTask, ToolkitTask -from griptape.engines import PromptSummaryEngine - +from griptape.tasks import BaseTask, PromptTask, ToolkitTask +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestAgent: @@ -149,13 +149,13 @@ def test_add_tasks(self): try: agent.add_tasks(first_task, second_task) - assert False + raise AssertionError() except ValueError: assert True try: agent + [first_task, second_task] - assert False + raise AssertionError() except ValueError: assert True diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index e63937a629..a9028cf083 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -1,14 +1,15 @@ -import pytest import time -from griptape.artifacts import TextArtifact, ErrorArtifact +import pytest + +from griptape.artifacts import ErrorArtifact, TextArtifact +from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset +from griptape.structures import Pipeline +from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask from griptape.tokenizers import OpenAiTokenizer -from griptape.tasks import PromptTask, BaseTask, ToolkitTask, CodeExecutionTask -from griptape.memory.structure import ConversationMemory from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.structures import Pipeline from tests.mocks.mock_tool.tool import MockTool from tests.unit.structures.test_agent import MockEmbeddingDriver diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index bf55e852fd..44b27f1bad 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -1,16 +1,17 @@ import time -import pytest +import pytest from pytest import fixture + +from griptape.artifacts import ErrorArtifact, TextArtifact +from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import TextArtifactStorage -from tests.mocks.mock_prompt_driver import MockPromptDriver from griptape.rules import Rule, Ruleset -from griptape.tasks import PromptTask, BaseTask, ToolkitTask, CodeExecutionTask from griptape.structures import Workflow -from griptape.artifacts import ErrorArtifact, TextArtifact -from griptape.memory.structure import ConversationMemory -from tests.mocks.mock_tool.tool import MockTool +from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask from tests.mocks.mock_embedding_driver import MockEmbeddingDriver +from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_tool.tool import MockTool class TestWorkflow: @@ -777,7 +778,7 @@ def test_run_with_error_artifact_no_fail_fast(self, error_artifact_task, waiting assert workflow.output is not None @staticmethod - def _validate_topology_1(workflow): + def _validate_topology_1(workflow) -> None: assert len(workflow.tasks) == 4 assert workflow.input_task.id == "task1" assert workflow.output_task.id == "task4" @@ -805,8 +806,8 @@ def _validate_topology_1(workflow): assert task4.child_ids == [] @staticmethod - def _validate_topology_2(workflow): - """Adapted from https://en.wikipedia.org/wiki/Directed_acyclic_graph#/media/File:Tred-G.svg""" + def _validate_topology_2(workflow) -> None: + """Adapted from https://en.wikipedia.org/wiki/Directed_acyclic_graph#/media/File:Tred-G.svg.""" assert len(workflow.tasks) == 5 assert workflow.input_task.id == "taska" assert workflow.output_task.id == "taske" @@ -839,7 +840,7 @@ def _validate_topology_2(workflow): assert taske.child_ids == [] @staticmethod - def _validate_topology_3(workflow): + def _validate_topology_3(workflow) -> None: assert len(workflow.tasks) == 4 assert workflow.input_task.id == "task1" assert workflow.output_task.id == "task3" @@ -867,7 +868,7 @@ def _validate_topology_3(workflow): assert task4.child_ids == ["task2"] @staticmethod - def _validate_topology_4(workflow): + def _validate_topology_4(workflow) -> None: assert len(workflow.tasks) == 9 assert workflow.input_task.id == "collect_movie_info" assert workflow.output_task.id == "summarize_to_slack" diff --git a/tests/unit/tasks/test_actions_subtask.py b/tests/unit/tasks/test_actions_subtask.py index c6e5ca0384..e25a42120a 100644 --- a/tests/unit/tasks/test_actions_subtask.py +++ b/tests/unit/tasks/test_actions_subtask.py @@ -1,10 +1,11 @@ import json -from griptape.artifacts import ListArtifact, TextArtifact, ActionArtifact + +from griptape.artifacts import ActionArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact -from tests.mocks.mock_tool.tool import MockTool -from griptape.tasks import ToolkitTask, ActionsSubtask -from griptape.structures import Agent from griptape.common import ToolAction +from griptape.structures import Agent +from griptape.tasks import ActionsSubtask, ToolkitTask +from tests.mocks.mock_tool.tool import MockTool class TestActionsSubtask: diff --git a/tests/unit/tasks/test_audio_transcription_task.py b/tests/unit/tasks/test_audio_transcription_task.py index 3a53fd49dc..66d26f08eb 100644 --- a/tests/unit/tasks/test_audio_transcription_task.py +++ b/tests/unit/tasks/test_audio_transcription_task.py @@ -5,7 +5,7 @@ from griptape.artifacts import AudioArtifact, TextArtifact from griptape.engines import AudioTranscriptionEngine from griptape.structures import Agent, Pipeline -from griptape.tasks import BaseTask, AudioTranscriptionTask +from griptape.tasks import AudioTranscriptionTask, BaseTask from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_structure_config import MockStructureConfig @@ -40,14 +40,9 @@ def test_config_audio_transcription_engine(self, audio_artifact): def test_run(self, audio_artifact, audio_transcription_engine): audio_transcription_engine.run.return_value = TextArtifact("mock transcription") - logger = Mock() task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) - pipeline = Pipeline(prompt_driver=MockPromptDriver(), logger=logger) + pipeline = Pipeline(prompt_driver=MockPromptDriver()) pipeline.add_task(task) assert pipeline.run().output.to_text() == "mock transcription" - - def test_before_run(self, audio_artifact, audio_transcription_engine): - task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) - task diff --git a/tests/unit/tasks/test_base_audio_input_task.py b/tests/unit/tasks/test_base_audio_input_task.py index e110748802..6ace6f45b3 100644 --- a/tests/unit/tasks/test_base_audio_input_task.py +++ b/tests/unit/tasks/test_base_audio_input_task.py @@ -1,7 +1,7 @@ import pytest -from tests.mocks.mock_audio_input_task import MockAudioInputTask from griptape.artifacts import AudioArtifact, TextArtifact +from tests.mocks.mock_audio_input_task import MockAudioInputTask from tests.mocks.mock_text_input_task import MockTextInputTask diff --git a/tests/unit/tasks/test_base_multi_text_input_task.py b/tests/unit/tasks/test_base_multi_text_input_task.py index ad4776aeeb..3d8d67a55e 100644 --- a/tests/unit/tasks/test_base_multi_text_input_task.py +++ b/tests/unit/tasks/test_base_multi_text_input_task.py @@ -1,7 +1,7 @@ -from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.structures import Pipeline from griptape.artifacts import TextArtifact +from griptape.structures import Pipeline from tests.mocks.mock_multi_text_input_task import MockMultiTextInputTask +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestBaseMultiTextInputTask: diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 7fe2810f51..bb60bae61f 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -1,9 +1,8 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.structures import Agent +from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask -from griptape.structures import Workflow from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_task import MockTask diff --git a/tests/unit/tasks/test_base_text_input_task.py b/tests/unit/tasks/test_base_text_input_task.py index 14c3c3f2ef..86dc98805a 100644 --- a/tests/unit/tasks/test_base_text_input_task.py +++ b/tests/unit/tasks/test_base_text_input_task.py @@ -1,7 +1,7 @@ -from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.structures import Pipeline from griptape.artifacts import TextArtifact -from griptape.rules import Ruleset, Rule +from griptape.rules import Rule, Ruleset +from griptape.structures import Pipeline +from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_text_input_task import MockTextInputTask diff --git a/tests/unit/tasks/test_code_execution_task.py b/tests/unit/tasks/test_code_execution_task.py index b947149122..3178e29dbe 100644 --- a/tests/unit/tasks/test_code_execution_task.py +++ b/tests/unit/tasks/test_code_execution_task.py @@ -1,4 +1,4 @@ -from griptape.artifacts import BaseArtifact, TextArtifact, ErrorArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact from griptape.structures import Pipeline from griptape.tasks import CodeExecutionTask from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/unit/tasks/test_csv_extraction_task.py b/tests/unit/tasks/test_csv_extraction_task.py index 9b8fb15bbd..14c426fe76 100644 --- a/tests/unit/tasks/test_csv_extraction_task.py +++ b/tests/unit/tasks/test_csv_extraction_task.py @@ -30,4 +30,4 @@ def test_config_extraction_engine(self, task): def test_missing_extraction_engine(self, task): with pytest.raises(ValueError): - task.extraction_engine + task.extraction_engine # noqa: B018 diff --git a/tests/unit/tasks/test_extraction_task.py b/tests/unit/tasks/test_extraction_task.py index 5e5ec09f6a..ac6a5c34e4 100644 --- a/tests/unit/tasks/test_extraction_task.py +++ b/tests/unit/tasks/test_extraction_task.py @@ -1,4 +1,5 @@ import pytest + from griptape.engines import CsvExtractionEngine from griptape.structures import Agent from griptape.tasks import ExtractionTask diff --git a/tests/unit/tasks/test_image_query_task.py b/tests/unit/tasks/test_image_query_task.py index dd49402133..95406a95d4 100644 --- a/tests/unit/tasks/test_image_query_task.py +++ b/tests/unit/tasks/test_image_query_task.py @@ -70,7 +70,7 @@ def test_missing_image_generation_engine(self, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) with pytest.raises(ValueError, match="Image Query Engine"): - task.image_query_engine + task.image_query_engine # noqa: B018 def test_run(self, image_query_engine, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact]), image_query_engine=image_query_engine) diff --git a/tests/unit/tasks/test_inpainting_image_generation_task.py b/tests/unit/tasks/test_inpainting_image_generation_task.py index 9dc6aff54e..afade8d394 100644 --- a/tests/unit/tasks/test_inpainting_image_generation_task.py +++ b/tests/unit/tasks/test_inpainting_image_generation_task.py @@ -1,11 +1,12 @@ -from griptape.artifacts.list_artifact import ListArtifact -from griptape.engines import InpaintingImageGenerationEngine from unittest.mock import Mock import pytest -from griptape.tasks import BaseTask, InpaintingImageGenerationTask -from griptape.artifacts import TextArtifact, ImageArtifact + +from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts.list_artifact import ListArtifact +from griptape.engines import InpaintingImageGenerationEngine from griptape.structures import Agent +from griptape.tasks import BaseTask, InpaintingImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from tests.mocks.mock_structure_config import MockStructureConfig @@ -59,4 +60,4 @@ def test_missing_image_generation_engine(self, text_artifact, image_artifact): task = InpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) with pytest.raises(ValueError): - task.image_generation_engine + task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index 0366652b06..f42448663a 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -1,10 +1,11 @@ -from griptape.engines import JsonExtractionEngine import pytest from schema import Schema + +from griptape.engines import JsonExtractionEngine from griptape.structures import Agent from griptape.tasks import JsonExtractionTask -from tests.mocks.mock_structure_config import MockStructureConfig from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_structure_config import MockStructureConfig class TestJsonExtractionTask: @@ -34,4 +35,4 @@ def test_config_extraction_engine(self, task): def test_missing_extraction_engine(self, task): with pytest.raises(ValueError): - task.extraction_engine + task.extraction_engine # noqa: B018 diff --git a/tests/unit/tasks/test_outpainting_image_generation_task.py b/tests/unit/tasks/test_outpainting_image_generation_task.py index 148ea133df..1e59711645 100644 --- a/tests/unit/tasks/test_outpainting_image_generation_task.py +++ b/tests/unit/tasks/test_outpainting_image_generation_task.py @@ -1,10 +1,10 @@ -from griptape.artifacts.list_artifact import ListArtifact -from griptape.engines import OutpaintingImageGenerationEngine from unittest.mock import Mock import pytest from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts.list_artifact import ListArtifact +from griptape.engines import OutpaintingImageGenerationEngine from griptape.structures import Agent from griptape.tasks import BaseTask, OutpaintingImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -60,4 +60,4 @@ def test_missing_image_generation_engine(self, text_artifact, image_artifact): task = OutpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) with pytest.raises(ValueError): - task.image_generation_engine + task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_prompt_image_generation_task.py b/tests/unit/tasks/test_prompt_image_generation_task.py index 4f6117c070..c3add57206 100644 --- a/tests/unit/tasks/test_prompt_image_generation_task.py +++ b/tests/unit/tasks/test_prompt_image_generation_task.py @@ -1,4 +1,3 @@ -from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from unittest.mock import Mock import pytest @@ -7,6 +6,7 @@ from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent from griptape.tasks import BaseTask, PromptImageGenerationTask +from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from tests.mocks.mock_structure_config import MockStructureConfig @@ -37,4 +37,4 @@ def test_missing_summary_engine(self): task = PromptImageGenerationTask("foo bar") with pytest.raises(ValueError): - task.image_generation_engine + task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index c76f0284be..083ea6da53 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,14 +1,15 @@ import pytest + from griptape.artifacts.image_artifact import ImageArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact from griptape.memory.structure import ConversationMemory from griptape.memory.structure.run import Run -from tests.mocks.mock_structure_config import MockStructureConfig -from griptape.tasks import PromptTask from griptape.rules import Rule -from tests.mocks.mock_prompt_driver import MockPromptDriver from griptape.structures import Pipeline +from griptape.tasks import PromptTask +from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_structure_config import MockStructureConfig class TestPromptTask: @@ -37,7 +38,7 @@ def test_missing_prompt_driver(self): task = PromptTask("test") with pytest.raises(ValueError): - task.prompt_driver + task.prompt_driver # noqa: B018 def test_input(self): # Str diff --git a/tests/unit/tasks/test_rag_task.py b/tests/unit/tasks/test_rag_task.py index c9b82f2083..c9a6a844d4 100644 --- a/tests/unit/tasks/test_rag_task.py +++ b/tests/unit/tasks/test_rag_task.py @@ -1,4 +1,5 @@ import pytest + from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule from griptape.engines.rag.stages import ResponseRagStage diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index d89e98c917..1053ade9e1 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -1,8 +1,7 @@ +from griptape.drivers import LocalStructureRunDriver +from griptape.structures import Agent, Pipeline from griptape.tasks import StructureRunTask -from griptape.structures import Agent from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.drivers import LocalStructureRunDriver -from griptape.structures import Pipeline class TestStructureRunTask: diff --git a/tests/unit/tasks/test_text_summary_task.py b/tests/unit/tasks/test_text_summary_task.py index d7a474373b..bb08f9d312 100644 --- a/tests/unit/tasks/test_text_summary_task.py +++ b/tests/unit/tasks/test_text_summary_task.py @@ -1,9 +1,10 @@ import pytest -from tests.mocks.mock_structure_config import MockStructureConfig + from griptape.engines import PromptSummaryEngine +from griptape.structures import Agent from griptape.tasks import TextSummaryTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.structures import Agent +from tests.mocks.mock_structure_config import MockStructureConfig class TestTextSummaryTask: @@ -34,4 +35,4 @@ def test_missing_summary_engine(self): task = TextSummaryTask("test") with pytest.raises(ValueError): - task.summary_engine + task.summary_engine # noqa: B018 diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py index 7a8e49364c..86bc1d2cee 100644 --- a/tests/unit/tasks/test_text_to_speech_task.py +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from griptape.artifacts import TextArtifact, AudioArtifact +from griptape.artifacts import AudioArtifact, TextArtifact from griptape.engines import TextToSpeechEngine from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, TextToSpeechTask diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 2af8b73c3a..c30a08218e 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -1,8 +1,10 @@ import json + import pytest + from griptape.artifacts import TextArtifact from griptape.structures import Agent -from griptape.tasks import ToolTask, ActionsSubtask +from griptape.tasks import ActionsSubtask, ToolTask from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 2217ba70c8..cd5dd21f8a 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -1,9 +1,9 @@ from griptape.artifacts import ErrorArtifact, TextArtifact -from griptape.structures import Agent -from griptape.tasks import ToolkitTask, ActionsSubtask, PromptTask from griptape.common import ToolAction -from tests.mocks.mock_tool.tool import MockTool +from griptape.structures import Agent +from griptape.tasks import ActionsSubtask, PromptTask, ToolkitTask from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -166,7 +166,7 @@ def test_init(self): try: ToolkitTask("test", tools=[MockTool(), MockTool()]) - assert False + raise AssertionError() except ValueError: assert True @@ -231,7 +231,7 @@ def test_init_from_prompt_1(self): assert subtask.output is None def test_init_from_prompt_2(self): - valid_input = """Thought: need to test\nObservation: test + valid_input = """Thought: need to test\nObservation: test observation\nAnswer: test output""" task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) diff --git a/tests/unit/tasks/test_variation_image_generation_task.py b/tests/unit/tasks/test_variation_image_generation_task.py index 6a9533da36..1189e7fc62 100644 --- a/tests/unit/tasks/test_variation_image_generation_task.py +++ b/tests/unit/tasks/test_variation_image_generation_task.py @@ -1,13 +1,14 @@ -from griptape.artifacts.list_artifact import ListArtifact -from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig from unittest.mock import Mock import pytest -from griptape.tasks import BaseTask, VariationImageGenerationTask -from griptape.artifacts import TextArtifact, ImageArtifact + +from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts.list_artifact import ListArtifact from griptape.engines import VariationImageGenerationEngine from griptape.structures import Agent +from griptape.tasks import BaseTask, VariationImageGenerationTask +from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver +from tests.mocks.mock_structure_config import MockStructureConfig class TestVariationImageGenerationTask: @@ -56,4 +57,4 @@ def test_missing_summary_engine(self, text_artifact, image_artifact): task = VariationImageGenerationTask((text_artifact, image_artifact)) with pytest.raises(ValueError): - task.image_generation_engine + task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py b/tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py index 2b77ba3dc1..a5a4fbba54 100644 --- a/tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py +++ b/tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py @@ -1,6 +1,7 @@ -from griptape.tokenizers import AmazonBedrockTokenizer import pytest +from griptape.tokenizers import AmazonBedrockTokenizer + class TestAmazonBedrockTokenizer: @pytest.fixture diff --git a/tests/unit/tokenizers/test_anthropic_tokenizer.py b/tests/unit/tokenizers/test_anthropic_tokenizer.py index ee165d2705..70c6fec966 100644 --- a/tests/unit/tokenizers/test_anthropic_tokenizer.py +++ b/tests/unit/tokenizers/test_anthropic_tokenizer.py @@ -1,4 +1,5 @@ import pytest + from griptape.tokenizers import AnthropicTokenizer diff --git a/tests/unit/tokenizers/test_base_tokenizer.py b/tests/unit/tokenizers/test_base_tokenizer.py index eed15b9b2b..08fd42c726 100644 --- a/tests/unit/tokenizers/test_base_tokenizer.py +++ b/tests/unit/tokenizers/test_base_tokenizer.py @@ -1,4 +1,5 @@ import logging + from tests.mocks.mock_tokenizer import MockTokenizer diff --git a/tests/unit/tokenizers/test_cohere_tokenizer.py b/tests/unit/tokenizers/test_cohere_tokenizer.py index 9ca23f4f0c..049f94d247 100644 --- a/tests/unit/tokenizers/test_cohere_tokenizer.py +++ b/tests/unit/tokenizers/test_cohere_tokenizer.py @@ -1,5 +1,6 @@ import cohere import pytest + from griptape.tokenizers import CohereTokenizer diff --git a/tests/unit/tokenizers/test_dummy_tokenizer.py b/tests/unit/tokenizers/test_dummy_tokenizer.py index 855fb6eee4..90269eb2a7 100644 --- a/tests/unit/tokenizers/test_dummy_tokenizer.py +++ b/tests/unit/tokenizers/test_dummy_tokenizer.py @@ -1,4 +1,5 @@ import pytest + from griptape.exceptions import DummyException from griptape.tokenizers import DummyTokenizer diff --git a/tests/unit/tokenizers/test_google_tokenizer.py b/tests/unit/tokenizers/test_google_tokenizer.py index 34510cdacf..60c14f889f 100644 --- a/tests/unit/tokenizers/test_google_tokenizer.py +++ b/tests/unit/tokenizers/test_google_tokenizer.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import Mock + +import pytest + from griptape.common import PromptStack from griptape.common.prompt_stack.messages.message import Message from griptape.tokenizers import GoogleTokenizer diff --git a/tests/unit/tokenizers/test_hugging_face_tokenizer.py b/tests/unit/tokenizers/test_hugging_face_tokenizer.py index dcb309a845..2c98a33f47 100644 --- a/tests/unit/tokenizers/test_hugging_face_tokenizer.py +++ b/tests/unit/tokenizers/test_hugging_face_tokenizer.py @@ -3,6 +3,7 @@ environ["TRANSFORMERS_VERBOSITY"] = "error" import pytest # noqa: E402 + from griptape.tokenizers import HuggingFaceTokenizer # noqa: E402 diff --git a/tests/unit/tokenizers/test_openai_tokenizer.py b/tests/unit/tokenizers/test_openai_tokenizer.py index 4aa42a87a2..026143ab72 100644 --- a/tests/unit/tokenizers/test_openai_tokenizer.py +++ b/tests/unit/tokenizers/test_openai_tokenizer.py @@ -1,4 +1,5 @@ import pytest + from griptape.tokenizers import OpenAiTokenizer diff --git a/tests/unit/tokenizers/test_simple_tokenizer.py b/tests/unit/tokenizers/test_simple_tokenizer.py index a34c0d4818..7402f5e0e9 100644 --- a/tests/unit/tokenizers/test_simple_tokenizer.py +++ b/tests/unit/tokenizers/test_simple_tokenizer.py @@ -1,4 +1,5 @@ import pytest + from griptape.tokenizers import SimpleTokenizer diff --git a/tests/unit/tokenizers/test_voyageai_tokenizer.py b/tests/unit/tokenizers/test_voyageai_tokenizer.py index 46c9490f7e..d4529b3fa2 100644 --- a/tests/unit/tokenizers/test_voyageai_tokenizer.py +++ b/tests/unit/tokenizers/test_voyageai_tokenizer.py @@ -1,4 +1,5 @@ import pytest + from griptape.tokenizers import VoyageAiTokenizer diff --git a/tests/unit/tools/test_aws_iam.py b/tests/unit/tools/test_aws_iam.py index 2a256425e4..a399fcc544 100644 --- a/tests/unit/tools/test_aws_iam.py +++ b/tests/unit/tools/test_aws_iam.py @@ -1,7 +1,8 @@ +import boto3 from pytest import fixture + from griptape.tools import AwsIamClient from tests.utils.aws import mock_aws_credentials -import boto3 class TestAwsIamClient: diff --git a/tests/unit/tools/test_aws_s3.py b/tests/unit/tools/test_aws_s3.py index 6fe62b71c7..d75d5c7347 100644 --- a/tests/unit/tools/test_aws_s3.py +++ b/tests/unit/tools/test_aws_s3.py @@ -1,7 +1,8 @@ +import boto3 from pytest import fixture + from griptape.tools import AwsS3Client from tests.utils.aws import mock_aws_credentials -import boto3 class TestAwsS3Client: diff --git a/tests/unit/tools/test_base_tool.py b/tests/unit/tools/test_base_tool.py index 75154b5096..2c114cbd28 100644 --- a/tests/unit/tools/test_base_tool.py +++ b/tests/unit/tools/test_base_tool.py @@ -1,10 +1,12 @@ import inspect import os + import pytest import yaml -from schema import SchemaMissingKeyError, Schema, Or -from griptape.tasks import ActionsSubtask, ToolkitTask +from schema import Or, Schema, SchemaMissingKeyError + from griptape.common import ToolAction +from griptape.tasks import ActionsSubtask, ToolkitTask from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -195,9 +197,9 @@ def test_validate(self, tool): def test_invalid_config(self): try: - from tests.mocks.invalid_mock_tool.tool import InvalidMockTool # noqa + from tests.mocks.invalid_mock_tool.tool import InvalidMockTool # noqa: F401 - assert False + raise AssertionError() except SchemaMissingKeyError: assert True diff --git a/tests/unit/tools/test_computer.py b/tests/unit/tools/test_computer.py index b11fd080ec..96827b32c8 100644 --- a/tests/unit/tools/test_computer.py +++ b/tests/unit/tools/test_computer.py @@ -1,6 +1,7 @@ import pytest -from tests.mocks.docker.fake_api_client import make_fake_client + from griptape.tools import Computer +from tests.mocks.docker.fake_api_client import make_fake_client class TestComputer: diff --git a/tests/unit/tools/test_date_time.py b/tests/unit/tools/test_date_time.py index daa511f046..c534ae69bf 100644 --- a/tests/unit/tools/test_date_time.py +++ b/tests/unit/tools/test_date_time.py @@ -1,6 +1,7 @@ -from griptape.tools import DateTime from datetime import datetime +from griptape.tools import DateTime + class TestDateTime: def test_get_current_datetime(self): diff --git a/tests/unit/tools/test_email_client.py b/tests/unit/tools/test_email_client.py index 183d7b2658..081b62ca4a 100644 --- a/tests/unit/tools/test_email_client.py +++ b/tests/unit/tools/test_email_client.py @@ -1,8 +1,8 @@ -from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact +import pytest + +from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.loaders.email_loader import EmailLoader -from griptape.artifacts import TextArtifact from griptape.tools import EmailClient -import pytest class TestEmailClient: diff --git a/tests/unit/tools/test_file_manager.py b/tests/unit/tools/test_file_manager.py index 06c49f32bb..ac1b3e3fe9 100644 --- a/tests/unit/tools/test_file_manager.py +++ b/tests/unit/tools/test_file_manager.py @@ -1,9 +1,11 @@ -import os.path import os +import os.path import tempfile from pathlib import Path + import pytest -from griptape.artifacts import TextArtifact, ListArtifact + +from griptape.artifacts import ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers.file_manager.local_file_manager_driver import LocalFileManagerDriver from griptape.loaders.text_loader import TextLoader diff --git a/tests/unit/tools/test_google_drive_client.py b/tests/unit/tools/test_google_drive_client.py index 5d2f62df7e..55f3c168ff 100644 --- a/tests/unit/tools/test_google_drive_client.py +++ b/tests/unit/tools/test_google_drive_client.py @@ -1,5 +1,5 @@ -from griptape.tools import GoogleDriveClient from griptape.artifacts import ErrorArtifact +from griptape.tools import GoogleDriveClient class TestGoogleDriveClient: diff --git a/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py b/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py index 9feba9cbfa..5bfd2ab1a5 100644 --- a/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py +++ b/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py @@ -1,6 +1,7 @@ import pytest from requests import exceptions -from griptape.artifacts import TextArtifact, ErrorArtifact + +from griptape.artifacts import ErrorArtifact, TextArtifact class TestGriptapeCloudKnowledgeBaseClient: @@ -75,10 +76,10 @@ def test_get_knowledge_base_description(self, client): def test_get_knowledge_base_description_error(self, client_no_description): exception_match_text = f"No description found for Knowledge Base {client_no_description.knowledge_base_id}. Please set a description, or manually set the `GriptapeCloudKnowledgeBaseClient.description` attribute." - with pytest.raises(ValueError, match=exception_match_text) as e: + with pytest.raises(ValueError, match=exception_match_text): client_no_description._get_knowledge_base_description() def test_get_knowledge_base_kb_error(self, client_kb_not_found): exception_match_text = f"Error accessing Knowledge Base {client_kb_not_found.knowledge_base_id}." - with pytest.raises(ValueError, match=exception_match_text) as e: + with pytest.raises(ValueError, match=exception_match_text): client_kb_not_found._get_knowledge_base_description() diff --git a/tests/unit/tools/test_inpainting_image_generation_client.py b/tests/unit/tools/test_inpainting_image_generation_client.py index 9e1d017bbe..1c48ce0fe4 100644 --- a/tests/unit/tools/test_inpainting_image_generation_client.py +++ b/tests/unit/tools/test_inpainting_image_generation_client.py @@ -53,7 +53,7 @@ def test_image_inpainting_with_outfile(self, image_generation_engine, image_load engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore + image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) diff --git a/tests/unit/tools/test_openweather_client.py b/tests/unit/tools/test_openweather_client.py index 319a7ec2a8..cead389584 100644 --- a/tests/unit/tools/test_openweather_client.py +++ b/tests/unit/tools/test_openweather_client.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import patch + +import pytest + from griptape.artifacts import ErrorArtifact from griptape.tools import OpenWeatherClient @@ -10,7 +12,7 @@ def client(): class MockResponse: - def __init__(self, json_data, status_code): + def __init__(self, json_data, status_code) -> None: self.json_data = json_data self.status_code = status_code diff --git a/tests/unit/tools/test_outpainting_image_variation_client.py b/tests/unit/tools/test_outpainting_image_variation_client.py index 1a84018a49..45a51c7be1 100644 --- a/tests/unit/tools/test_outpainting_image_variation_client.py +++ b/tests/unit/tools/test_outpainting_image_variation_client.py @@ -53,7 +53,7 @@ def test_image_outpainting_with_outfile(self, image_generation_engine, image_loa engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore + image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) diff --git a/tests/unit/tools/test_prompt_image_generation_client.py b/tests/unit/tools/test_prompt_image_generation_client.py index dffbb4239d..201d4086f8 100644 --- a/tests/unit/tools/test_prompt_image_generation_client.py +++ b/tests/unit/tools/test_prompt_image_generation_client.py @@ -36,7 +36,7 @@ def test_generate_image_with_outfile(self, image_generation_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" image_generator = PromptImageGenerationClient(engine=image_generation_engine, output_file=outfile) - image_generator.engine.run.return_value = Mock( # pyright: ignore + image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) diff --git a/tests/unit/tools/test_rest_api_client.py b/tests/unit/tools/test_rest_api_client.py index b937b9f235..83dc303504 100644 --- a/tests/unit/tools/test_rest_api_client.py +++ b/tests/unit/tools/test_rest_api_client.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import BaseArtifact diff --git a/tests/unit/tools/test_sql_client.py b/tests/unit/tools/test_sql_client.py index 6584fa752c..0262be2d80 100644 --- a/tests/unit/tools/test_sql_client.py +++ b/tests/unit/tools/test_sql_client.py @@ -1,8 +1,10 @@ +import sqlite3 + import pytest + from griptape.drivers import SqlDriver from griptape.loaders import SqlLoader from griptape.tools import SqlClient -import sqlite3 class TestSqlClient: diff --git a/tests/unit/tools/test_structure_run_client.py b/tests/unit/tools/test_structure_run_client.py index b57bfb28f0..746115b05b 100644 --- a/tests/unit/tools/test_structure_run_client.py +++ b/tests/unit/tools/test_structure_run_client.py @@ -1,7 +1,8 @@ import pytest + from griptape.drivers.structure_run.local_structure_run_driver import LocalStructureRunDriver -from griptape.tools import StructureRunClient from griptape.structures import Agent +from griptape.tools import StructureRunClient from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/unit/tools/test_task_memory_client.py b/tests/unit/tools/test_task_memory_client.py index 3956ae415a..9a55b49744 100644 --- a/tests/unit/tools/test_task_memory_client.py +++ b/tests/unit/tools/test_task_memory_client.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.tools import TaskMemoryClient from tests.utils import defaults diff --git a/tests/unit/tools/test_text_to_speech_client.py b/tests/unit/tools/test_text_to_speech_client.py index 881b1234de..8d1ba78e68 100644 --- a/tests/unit/tools/test_text_to_speech_client.py +++ b/tests/unit/tools/test_text_to_speech_client.py @@ -32,7 +32,7 @@ def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3" text_to_speech_client = TextToSpeechClient(engine=text_to_speech_engine, output_file=outfile) - text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") # pyright: ignore + text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) diff --git a/tests/unit/tools/test_transcription_client.py b/tests/unit/tools/test_transcription_client.py index ea6bd3453a..94e00e333c 100644 --- a/tests/unit/tools/test_transcription_client.py +++ b/tests/unit/tools/test_transcription_client.py @@ -24,7 +24,7 @@ def test_init_transcription_client(self, transcription_engine, audio_loader) -> @patch("builtins.open", mock_open(read_data=b"audio data")) def test_transcribe_audio_from_disk(self, transcription_engine, audio_loader) -> None: client = AudioTranscriptionClient(engine=transcription_engine, audio_loader=audio_loader) - client.engine.run.return_value = Mock(value="transcription") # pyright: ignore + client.engine.run.return_value = Mock(value="transcription") # pyright: ignore[reportFunctionMemberAccess] text_artifact = client.transcribe_audio_from_disk(params={"values": {"path": "audio.wav"}}) @@ -37,7 +37,7 @@ def test_transcribe_audio_from_memory(self, transcription_engine, audio_loader) memory.load_artifacts = Mock(return_value=[AudioArtifact(value=b"audio data", format="wav", name="name")]) client.find_input_memory = Mock(return_value=memory) - client.engine.run.return_value = Mock(value="transcription") # pyright: ignore + client.engine.run.return_value = Mock(value="transcription") # pyright: ignore[reportFunctionMemberAccess] text_artifact = client.transcribe_audio_from_memory( params={"values": {"memory_name": "memory", "artifact_namespace": "namespace", "artifact_name": "name"}} diff --git a/tests/unit/tools/test_variation_image_generation_client.py b/tests/unit/tools/test_variation_image_generation_client.py index b29f4fecf7..bd824a95c3 100644 --- a/tests/unit/tools/test_variation_image_generation_client.py +++ b/tests/unit/tools/test_variation_image_generation_client.py @@ -54,7 +54,7 @@ def test_image_variation_with_outfile(self, image_generation_engine, image_loade engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore + image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) diff --git a/tests/unit/tools/test_vector_store_client.py b/tests/unit/tools/test_vector_store_client.py index 45018b847c..2e9c4037b1 100644 --- a/tests/unit/tools/test_vector_store_client.py +++ b/tests/unit/tools/test_vector_store_client.py @@ -1,5 +1,6 @@ import pytest -from griptape.artifacts import TextArtifact, ListArtifact + +from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers import LocalVectorStoreDriver from griptape.tools import VectorStoreClient from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -16,7 +17,7 @@ def test_search(self): driver.upsert_text_artifacts({"test": [TextArtifact("foo"), TextArtifact("bar")]}) - assert set([a.value for a in tool.search({"values": {"query": "test"}})]) == {"foo", "bar"} + assert {a.value for a in tool.search({"values": {"query": "test"}})} == {"foo", "bar"} def test_search_with_namespace(self): driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) diff --git a/tests/unit/tools/test_web_scraper.py b/tests/unit/tools/test_web_scraper.py index f46004e8fb..9ba278586c 100644 --- a/tests/unit/tools/test_web_scraper.py +++ b/tests/unit/tools/test_web_scraper.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import ListArtifact diff --git a/tests/unit/tools/test_web_search.py b/tests/unit/tools/test_web_search.py index 0abc880c8b..a2b488ff9e 100644 --- a/tests/unit/tools/test_web_search.py +++ b/tests/unit/tools/test_web_search.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact from griptape.tools import WebSearch diff --git a/tests/unit/utils/test_base_tokenizer.py b/tests/unit/utils/test_base_tokenizer.py index eed15b9b2b..08fd42c726 100644 --- a/tests/unit/utils/test_base_tokenizer.py +++ b/tests/unit/utils/test_base_tokenizer.py @@ -1,4 +1,5 @@ import logging + from tests.mocks.mock_tokenizer import MockTokenizer diff --git a/tests/unit/utils/test_command_runner.py b/tests/unit/utils/test_command_runner.py index 4ca3afebc0..25b7fd8c3e 100644 --- a/tests/unit/utils/test_command_runner.py +++ b/tests/unit/utils/test_command_runner.py @@ -1,4 +1,3 @@ -from griptape.artifacts import TextArtifact from griptape.utils import CommandRunner diff --git a/tests/unit/utils/test_conversation.py b/tests/unit/utils/test_conversation.py index cce067f730..28ee72409a 100644 --- a/tests/unit/utils/test_conversation.py +++ b/tests/unit/utils/test_conversation.py @@ -1,8 +1,8 @@ -from tests.mocks.mock_prompt_driver import MockPromptDriver from griptape.memory.structure import ConversationMemory, SummaryConversationMemory -from griptape.tasks import PromptTask from griptape.structures import Pipeline +from griptape.tasks import PromptTask from griptape.utils import Conversation +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestConversation: diff --git a/tests/unit/utils/test_deprecate.py b/tests/unit/utils/test_deprecate.py index 0c8064f8d6..868dbd60fc 100644 --- a/tests/unit/utils/test_deprecate.py +++ b/tests/unit/utils/test_deprecate.py @@ -1,4 +1,5 @@ import pytest + from griptape.utils.deprecation import deprecation_warn diff --git a/tests/unit/utils/test_dict_utils.py b/tests/unit/utils/test_dict_utils.py index 4b4e4ca087..94e870e1a8 100644 --- a/tests/unit/utils/test_dict_utils.py +++ b/tests/unit/utils/test_dict_utils.py @@ -1,6 +1,7 @@ -from griptape.utils import remove_null_values_in_dict_recursively, dict_merge, remove_key_in_dict_recursively import pytest +from griptape.utils import dict_merge, remove_key_in_dict_recursively, remove_null_values_in_dict_recursively + class TestDictUtils: def test_remove_null_values_in_dict_recursively(self): diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py index dbcf1044b9..de1882ef55 100644 --- a/tests/unit/utils/test_file_utils.py +++ b/tests/unit/utils/test_file_utils.py @@ -1,7 +1,8 @@ import os -from griptape.loaders import TextLoader -from griptape import utils from concurrent import futures + +from griptape import utils +from griptape.loaders import TextLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 diff --git a/tests/unit/utils/test_futures.py b/tests/unit/utils/test_futures.py index 5e30148a91..04ddb98778 100644 --- a/tests/unit/utils/test_futures.py +++ b/tests/unit/utils/test_futures.py @@ -1,4 +1,5 @@ from concurrent import futures + from griptape import utils diff --git a/tests/unit/utils/test_import_utils.py b/tests/unit/utils/test_import_utils.py index bcfb06c875..f6b2429d97 100644 --- a/tests/unit/utils/test_import_utils.py +++ b/tests/unit/utils/test_import_utils.py @@ -1,4 +1,5 @@ import pytest + from griptape.utils import import_optional_dependency, is_dependency_installed diff --git a/tests/unit/utils/test_load_artifact_from_memory.py b/tests/unit/utils/test_load_artifact_from_memory.py index db4f7d573c..c303f0954d 100644 --- a/tests/unit/utils/test_load_artifact_from_memory.py +++ b/tests/unit/utils/test_load_artifact_from_memory.py @@ -2,7 +2,7 @@ import pytest -from griptape.artifacts import TextArtifact, ErrorArtifact, ImageArtifact +from griptape.artifacts import ImageArtifact, TextArtifact from griptape.utils import load_artifact_from_memory @@ -21,7 +21,7 @@ def image_artifact(self): def test_no_memory(self): with pytest.raises(ValueError): - load_artifact_from_memory(None, "", "", TextArtifact) # pyright: ignore + load_artifact_from_memory(None, "", "", TextArtifact) # pyright: ignore[reportArgumentType] def test_no_artifacts_in_memory(self, memory): memory.load_artifacts.return_value = [] diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index 33c97cc757..e223a5c259 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -1,5 +1,7 @@ -from typing import Iterator +from collections.abc import Iterator + import pytest + from griptape.structures import Agent from griptape.utils import Stream from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/unit/utils/test_structure_visualizer.py b/tests/unit/utils/test_structure_visualizer.py index e16275a5c4..f6e621b915 100644 --- a/tests/unit/utils/test_structure_visualizer.py +++ b/tests/unit/utils/test_structure_visualizer.py @@ -1,7 +1,7 @@ -from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.utils import StructureVisualizer +from griptape.structures import Agent, Pipeline, Workflow from griptape.tasks import PromptTask -from griptape.structures import Agent, Workflow, Pipeline +from griptape.utils import StructureVisualizer +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStructureVisualizer: diff --git a/tests/utils/code_blocks.py b/tests/utils/code_blocks.py index 9cfebb9877..ca5b193d1e 100644 --- a/tests/utils/code_blocks.py +++ b/tests/utils/code_blocks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import pathlib import textwrap @@ -6,7 +8,7 @@ def check_py_string(source: str) -> None: - """Exec the python source given in a new module namespace + """Exec the python source given in a new module namespace. Does not return anything, but exceptions raised by the source will propagate out unmodified diff --git a/tests/utils/defaults.py b/tests/utils/defaults.py index 5a9f6f958f..bad7f0d797 100644 --- a/tests/utils/defaults.py +++ b/tests/utils/defaults.py @@ -1,11 +1,11 @@ -from griptape.artifacts import TextArtifact, BlobArtifact +from griptape.artifacts import BlobArtifact, TextArtifact from griptape.drivers import LocalVectorStoreDriver -from griptape.engines import PromptSummaryEngine, CsvExtractionEngine, JsonExtractionEngine +from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine -from griptape.engines.rag.modules import VectorStoreRetrievalRagModule, PromptResponseRagModule -from griptape.engines.rag.stages import RetrievalRagStage, ResponseRagStage +from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule +from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from griptape.memory import TaskMemory -from griptape.memory.task.storage import TextArtifactStorage, BlobArtifactStorage +from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/utils/postgres.py b/tests/utils/postgres.py index 1e04153bd1..a320d9a053 100644 --- a/tests/utils/postgres.py +++ b/tests/utils/postgres.py @@ -1,4 +1,4 @@ -from psycopg2 import connect, OperationalError +from psycopg2 import OperationalError, connect def can_connect_to_postgres(user="postgres", password="postgres", host="localhost", port="5432", database="postgres"): diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 8d62bc835a..0abdda2767 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -1,25 +1,26 @@ from __future__ import annotations -import os -from attrs import field, define -from schema import Schema, Literal -import logging + import json -from griptape.artifacts.error_artifact import ErrorArtifact +import logging +import os -from griptape.structures import Agent -from griptape.rules import Rule, Ruleset -from griptape.tasks import PromptTask -from griptape.structures import Structure +from attrs import define, field +from schema import Literal, Schema + +from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import ( - BasePromptDriver, AmazonBedrockPromptDriver, + AmazonSageMakerJumpstartPromptDriver, AnthropicPromptDriver, - CoherePromptDriver, - OpenAiChatPromptDriver, AzureOpenAiChatPromptDriver, - AmazonSageMakerJumpstartPromptDriver, + BasePromptDriver, + CoherePromptDriver, GooglePromptDriver, + OpenAiChatPromptDriver, ) +from griptape.rules import Rule, Ruleset +from griptape.structures import Agent, Structure +from griptape.tasks import PromptTask def get_enabled_prompt_drivers(prompt_drivers_options) -> list[BasePromptDriver]: