diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a88fb709f..daa6b674b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Updated `EventListener.handler` return value behavior. - If `EventListener.handler` returns `None`, the event will not be published to the `event_listener_driver`. - If `EventListener.handler` is None, the event will be published to the `event_listener_driver` as-is. +- **BREAKING**: `RestApiTool` now returns a `JsonArtifact` instead of a `TextArtifact`. +- **BREAKING**: Removed `RestApiTool.response_body`, `RestApiTool.request_path_params_schema`. +- **BREAKING**: Changed `RestApiTool` fields from `str` to `dict`: + - `RestApiTool.request_query_params_schema` (`dict`) + - `RestApiTool.request_body_schema` (`dict`) - Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`. - `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`. - `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`. @@ -43,6 +48,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Structures not flushing events when not listening for `FinishStructureRunEvent`. - `EventListener.event_types` and the argument to `BaseEventListenerDriver.handler` being out of sync. +- `RestApiTool` failing with native tool calling due to schemas being in schema description. ## \[0.33.1\] - 2024-10-11 diff --git a/MIGRATION.md b/MIGRATION.md index 956611de8c..7db0da8fbf 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -49,6 +49,117 @@ EventListener(handler=handler_fn_return_dict, event_listener_driver=driver) EventListener(handler=handler_fn_return_base_event, event_listener_driver=driver) ``` +### Changed `RestApiTool` fields + +The following fields are no longer stringified full schemas. They are now just the properties of the schema. + +- `RestApiTool.request_path_params_schema` (`list`) +- `RestApiTool.request_query_params_schema` (`dict`) +- `RestApiTool.request_body_schema` (`dict`) + +#### Before +```python +posts_client = RestApiTool( + base_url="https://jsonplaceholder.typicode.com", + path="posts", + description="Allows for creating, updating, deleting, patching, and getting posts.", + request_body_schema=dumps( + { + "$schema": "https://json-schema.org/draft/2019-09/schema", + "$id": "http://example.com/example.json", + "type": "object", + "default": {}, + "title": "Root Schema", + "required": ["title", "body", "userId"], + "properties": { + "title": { + "type": "string", + "default": "", + "title": "The title Schema", + }, + "body": { + "type": "string", + "default": "", + "title": "The body Schema", + }, + "userId": { + "type": "integer", + "default": 0, + "title": "The userId Schema", + }, + }, + } + ), + request_query_params_schema=dumps( + { + "$schema": "https://json-schema.org/draft/2019-09/schema", + "$id": "http://example.com/example.json", + "type": "object", + "default": {}, + "title": "Root Schema", + "required": ["userId"], + "properties": { + "userId": { + "type": "string", + "default": "", + "title": "The userId Schema", + }, + }, + } + ), + response_body_schema=dumps( + { + "$schema": "https://json-schema.org/draft/2019-09/schema", + "$id": "http://example.com/example.json", + "type": "object", + "default": {}, + "title": "Root Schema", + "required": ["id", "title", "body", "userId"], + "properties": { + "id": { + "type": "integer", + "default": 0, + "title": "The id Schema", + }, + "title": { + "type": "string", + "default": "", + "title": "The title Schema", + }, + "body": { + "type": "string", + "default": "", + "title": "The body Schema", + }, + "userId": { + "type": "integer", + "default": 0, + "title": "The userId Schema", + }, + }, + } + ), +) +``` + +#### After +```python +posts_client = RestApiTool( + base_url="https://jsonplaceholder.typicode.com", + path="posts", + description="Allows for creating, updating, deleting, patching, and getting posts.", + request_body_schema={ + "title": str, + "body": str, + "userId": int, + }, + request_query_params_schema={ + "userId": str, + }, +) +``` + + ## 0.32.X to 0.33.X ### Removed `DataframeLoader` diff --git a/docs/griptape-tools/official-tools/src/rest_api_tool_1.py b/docs/griptape-tools/official-tools/src/rest_api_tool_1.py index 3f6b3b6631..a7ce21e1a3 100644 --- a/docs/griptape-tools/official-tools/src/rest_api_tool_1.py +++ b/docs/griptape-tools/official-tools/src/rest_api_tool_1.py @@ -1,5 +1,3 @@ -from json import dumps - from griptape.configs import Defaults from griptape.configs.drivers import DriversConfig from griptape.drivers import OpenAiChatPromptDriver @@ -15,100 +13,21 @@ posts_client = RestApiTool( base_url="https://jsonplaceholder.typicode.com", path="posts", - description="Allows for creating, updating, deleting, patching, and getting posts.", - request_body_schema=dumps( - { - "$schema": "https://json-schema.org/draft/2019-09/schema", - "$id": "http://example.com/example.json", - "type": "object", - "default": {}, - "title": "Root Schema", - "required": ["title", "body", "userId"], - "properties": { - "title": { - "type": "string", - "default": "", - "title": "The title Schema", - }, - "body": { - "type": "string", - "default": "", - "title": "The body Schema", - }, - "userId": { - "type": "integer", - "default": 0, - "title": "The userId Schema", - }, - }, - } - ), - request_query_params_schema=dumps( - { - "$schema": "https://json-schema.org/draft/2019-09/schema", - "$id": "http://example.com/example.json", - "type": "object", - "default": {}, - "title": "Root Schema", - "required": ["userId"], - "properties": { - "userId": { - "type": "string", - "default": "", - "title": "The userId Schema", - }, - }, - } - ), - request_path_params_schema=dumps( - { - "$schema": "https://json-schema.org/draft/2019-09/schema", - "$id": "http://example.com/example.json", - "type": "array", - "default": [], - "title": "Root Schema", - "items": { - "anyOf": [ - { - "type": "string", - "title": "Post id", - }, - ] - }, - } - ), - response_body_schema=dumps( - { - "$schema": "https://json-schema.org/draft/2019-09/schema", - "$id": "http://example.com/example.json", - "type": "object", - "default": {}, - "title": "Root Schema", - "required": ["id", "title", "body", "userId"], - "properties": { - "id": { - "type": "integer", - "default": 0, - "title": "The id Schema", - }, - "title": { - "type": "string", - "default": "", - "title": "The title Schema", - }, - "body": { - "type": "string", - "default": "", - "title": "The body Schema", - }, - "userId": { - "type": "integer", - "default": 0, - "title": "The userId Schema", - }, - }, - } - ), + description="Allows for creating, updating, deleting, patching, and getting posts. Can also be used to access subresources.", + request_body_schema={ + "title": str, + "body": str, + "userId": int, + }, +) + +comments_client = RestApiTool( + base_url="https://jsonplaceholder.typicode.com", + path="comments", + description="Allows for getting comments for a post.", + request_query_params_schema={ + "postId": str, + }, ) pipeline = Pipeline( @@ -140,6 +59,10 @@ "Output the body of all the comments for post 1.", tools=[posts_client], ), + ToolkitTask( + "Get the comments for post 1.", + tools=[comments_client], + ), ) pipeline.run() diff --git a/griptape/tools/rest_api/tool.py b/griptape/tools/rest_api/tool.py index f5e233e576..1257aaf22c 100644 --- a/griptape/tools/rest_api/tool.py +++ b/griptape/tools/rest_api/tool.py @@ -8,7 +8,7 @@ from attrs import define, field from schema import Literal, Schema -from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact, JsonArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -24,17 +24,14 @@ class RestApiTool(BaseTool): request_body_schema: A JSON schema string describing the request body. Recommended for PUT, POST, and PATCH requests. request_query_params_schema: A JSON schema string describing the available query parameters. request_path_params_schema: A JSON schema string describing the available path parameters. The schema must describe an array of string values. - response_body_schema: A JSON schema string describing the response body. request_headers: Headers to include in the requests. """ base_url: str = field(kw_only=True) path: Optional[str] = field(default=None, kw_only=True) description: str = field(kw_only=True) - request_path_params_schema: Optional[str] = field(default=None, kw_only=True) - request_query_params_schema: Optional[str] = field(default=None, kw_only=True) - request_body_schema: Optional[str] = field(default=None, kw_only=True) - response_body_schema: Optional[str] = field(default=None, kw_only=True) + request_query_params_schema: Optional[dict] = field(default=None, kw_only=True) + request_body_schema: Optional[dict] = field(default=None, kw_only=True) request_headers: Optional[dict[str, str]] = field(default=None, kw_only=True) @property @@ -47,11 +44,13 @@ def full_url(self) -> str: """ This tool can be used to make a put request to the rest api url: {{ _self.full_url }} This rest api has the following description: {{ _self.description }} - {% if _self.request_body_schema %}The request body must follow this JSON schema: {{ _self.request_body_schema }}{% endif %} - {% if _self.response_body_schema %}The response body must follow this JSON schema: {{ _self.response_body_schema }}{% endif %} """, ), - "schema": Schema({Literal("body", description="The request body."): dict}), + "schema": lambda _self: Schema( + { + Literal("body", description="The request body."): Schema(_self.request_body_schema), + } + ), }, ) def put(self, params: dict) -> BaseArtifact: @@ -66,7 +65,7 @@ def put(self, params: dict) -> BaseArtifact: try: response = put(url, json=body, timeout=30, headers=self.request_headers) - return TextArtifact(response.text) + return JsonArtifact(response.json(), meta={"status_code": response.status_code}) except exceptions.RequestException as err: return ErrorArtifact(str(err)) @@ -76,15 +75,12 @@ def put(self, params: dict) -> BaseArtifact: """ This tool can be used to make a patch request to the rest api url: {{ _self.full_url }} This rest api has the following description: {{ _self.description }} - {% if _self.request_path_parameters %}The request path parameters must follow this JSON schema: {{ _self.request_path_params_schema }}{% endif %} - {% if _self.request_body_schema %}The request body must follow this JSON schema: {{ _self.request_body_schema }}{% endif %} - {% if _self.response_body_schema %}The response body must follow this JSON schema: {{ _self.response_body_schema }}{% endif %} """, ), - "schema": Schema( + "schema": lambda _self: Schema( { - Literal("path_params", description="The request path parameters."): Schema([str]), - Literal("body", description="The request body."): dict, + Literal("path_params", description="The request path parameters."): [str], + Literal("body", description="The request body."): Schema(_self.request_body_schema), }, ), }, @@ -101,7 +97,7 @@ def patch(self, params: dict) -> BaseArtifact: try: response = patch(url, json=body, timeout=30, headers=self.request_headers) - return TextArtifact(response.text) + return JsonArtifact(response.json(), meta={"status_code": response.status_code}) except exceptions.RequestException as err: return ErrorArtifact(str(err)) @@ -111,11 +107,13 @@ def patch(self, params: dict) -> BaseArtifact: """ This tool can be used to make a post request to the rest api url: {{ _self.full_url }} This rest api has the following description: {{ _self.description }} - {% if _self.request_body_schema %}The request body must follow this JSON schema: {{ _self.request_body_schema }}{% endif %} - {% if _self.response_body_schema %}The response body must follow this JSON schema: {{ _self.response_body_schema }}{% endif %} """, ), - "schema": Schema({Literal("body", description="The request body."): dict}), + "schema": lambda _self: Schema( + { + Literal("body", description="The request body."): schema.Schema(_self.request_body_schema), + } + ), }, ) def post(self, params: dict) -> BaseArtifact: @@ -129,7 +127,7 @@ def post(self, params: dict) -> BaseArtifact: try: response = post(url, json=body, timeout=30, headers=self.request_headers) - return TextArtifact(response.text) + return JsonArtifact(response.json(), meta={"status_code": response.status_code}) except exceptions.RequestException as err: return ErrorArtifact(str(err)) @@ -139,20 +137,15 @@ def post(self, params: dict) -> BaseArtifact: """ This tool can be used to make a get request to the rest api url: {{ _self.full_url }} This rest api has the following description: {{ _self.description }} - {% if _self.request_path_parameters %}The request path parameters must follow this JSON schema: {{ _self.request_path_params_schema }}{% endif %} - {% if _self.request_query_parameters %}The request query parameters must follow this JSON schema: {{ _self.request_path_params_schema }}{% endif %} - {% if _self.response_body_schema %}The response body must follow this JSON schema: {{ _self.response_body_schema }}{% endif %} """, ), - "schema": schema.Optional( - Schema( - { - schema.Optional(Literal("query_params", description="The request query parameters.")): dict, - schema.Optional(Literal("path_params", description="The request path parameters.")): Schema( - [str] - ), - }, - ), + "schema": lambda _self: Schema( + { + schema.Optional(Literal("query_params", description="The request query parameters.")): Schema( + _self.request_query_params_schema + ), + schema.Optional(Literal("path_params", description="The request path parameters.")): [str], + }, ), }, ) @@ -172,7 +165,7 @@ def get(self, params: dict) -> BaseArtifact: try: response = get(url, params=query_params, timeout=30, headers=self.request_headers) - return TextArtifact(response.text) + return JsonArtifact(response.json(), meta={"status_code": response.status_code}) except exceptions.RequestException as err: return ErrorArtifact(str(err)) @@ -182,14 +175,14 @@ def get(self, params: dict) -> BaseArtifact: """ This tool can be used to make a delete request to the rest api url: {{ _self.full_url }} This rest api has the following description: {{ _self.description }} - {% if _self.request_path_parameters %}The request path parameters must follow this JSON schema: {{ _self.request_path_params_schema }}{% endif %} - {% if _self.request_query_parameters %}The request query parameters must follow this JSON schema: {{ _self.request_path_params_schema }}{% endif %} """, ), - "schema": Schema( + "schema": lambda _self: Schema( { - schema.Optional(Literal("query_params", description="The request query parameters.")): dict, - schema.Optional(Literal("path_params", description="The request path parameters.")): Schema([str]), + schema.Optional(Literal("query_params", description="The request query parameters.")): Schema( + _self.request_query_params_schema + ), + schema.Optional(Literal("path_params", description="The request path parameters.")): [str], }, ), }, @@ -207,7 +200,7 @@ def delete(self, params: dict) -> BaseArtifact: try: response = delete(url, params=query_params, timeout=30, headers=self.request_headers) - return TextArtifact(response.text) + return JsonArtifact(response.json(), meta={"status_code": response.status_code}) except exceptions.RequestException as err: return ErrorArtifact(str(err)) diff --git a/tests/unit/tools/test_rest_api_tool.py b/tests/unit/tools/test_rest_api_tool.py index 70d63478ee..6dd3d19cfc 100644 --- a/tests/unit/tools/test_rest_api_tool.py +++ b/tests/unit/tools/test_rest_api_tool.py @@ -1,38 +1,79 @@ import pytest -from griptape.artifacts import BaseArtifact - class TestRestApi: @pytest.fixture() - def client(self): + def client(self, mocker): from griptape.tools import RestApiTool + mock_return_value = {"value": "foo bar"} + + mock_response = mocker.Mock() + mock_response.status_code = 201 + mock_response.json.return_value = mock_return_value + mocker.patch("requests.put", return_value=mock_response) + + mock_response = mocker.Mock() + mock_response.status_code = 201 + mock_response.json.return_value = mock_return_value + mocker.patch("requests.post", return_value=mock_response) + + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_return_value + mocker.patch("requests.get", return_value=mock_response) + + mock_response = mocker.Mock() + mock_response.status_code = 204 + mock_response.json.return_value = mock_return_value + mocker.patch("requests.delete", return_value=mock_response) + + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_return_value + mocker.patch("requests.patch", return_value=mock_response) + return RestApiTool(base_url="http://www.griptape.ai", description="Griptape website.") def test_put(self, client): - assert isinstance(client.post({"values": {"body": {}}}), BaseArtifact) + response = client.put({"values": {"body": {}}}) + assert response.value == {"value": "foo bar"} + assert response.meta["status_code"] == 201 def test_post(self, client): - assert isinstance(client.post({"values": {"body": {}}}), BaseArtifact) + response = client.post({"values": {"body": {}}}) + assert response.value == {"value": "foo bar"} + assert response.meta["status_code"] == 201 def test_get_one(self, client): - assert isinstance(client.get({"values": {"path_params": ["1"]}}), BaseArtifact) + response = client.get({"values": {"path_params": ["1"]}}) + assert response.value == {"value": "foo bar"} + assert response.meta["status_code"] == 200 def test_get_all(self, client): - assert isinstance(client.get({"values": {}}), BaseArtifact) + response = client.get({"values": {}}) + assert response.value == {"value": "foo bar"} + assert response.meta["status_code"] == 200 def test_get_filtered(self, client): - assert isinstance(client.get({"values": {"query_params": {"limit": 10}}}), BaseArtifact) + response = client.get({"values": {"query_params": {"limit": 10}}}) + assert response.value == {"value": "foo bar"} + assert response.meta["status_code"] == 200 def test_delete_one(self, client): - assert isinstance(client.delete({"values": {"path_params": ["1"]}}), BaseArtifact) + response = client.delete({"values": {"path_params": ["1"]}}) + assert response.value == {"value": "foo bar"} + assert response.meta["status_code"] == 204 def test_delete_multiple(self, client): - assert isinstance(client.delete({"values": {"query_params": {"ids": [1, 2]}}}), BaseArtifact) + response = client.delete({"values": {"query_params": {"ids": [1, 2]}}}) + assert response.value == {"value": "foo bar"} + assert response.meta["status_code"] == 204 def test_patch(self, client): - assert isinstance(client.patch({"values": {"path_params": ["1"], "body": {}}}), BaseArtifact) + response = client.patch({"values": {"path_params": ["1"], "body": {}}}) + assert response.value == {"value": "foo bar"} + assert response.meta["status_code"] == 200 def test_build_url(self, client): url = client._build_url("https://foo.bar")