-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
360 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from typing import AsyncIterator | ||
from aact.messages.base import Message | ||
from aact.messages.commons import ( | ||
RestRequest, | ||
RestResponse, | ||
Tick, | ||
get_rest_request_class, | ||
get_rest_response_class, | ||
AnyDataModel, | ||
) | ||
from aact.nodes import Node, NodeFactory | ||
|
||
|
||
@NodeFactory.register("api_client") | ||
class APIClient(Node[Tick | RestResponse, RestRequest]): | ||
def __init__( | ||
self, | ||
input_tick_channel: str, | ||
input_response_channel: str, | ||
output_channel: str, | ||
redis_url: str, | ||
): | ||
response_class = get_rest_response_class(AnyDataModel) | ||
request_class = get_rest_request_class(AnyDataModel) | ||
super().__init__( | ||
input_channel_types=[ | ||
(input_tick_channel, Tick), | ||
(input_response_channel, response_class), | ||
], | ||
output_channel_types=[(output_channel, request_class)], | ||
redis_url=redis_url, | ||
) | ||
|
||
self.request_class = request_class | ||
self.response_class = response_class | ||
self.input_tick_channel = input_tick_channel | ||
self.input_response_channel = input_response_channel | ||
self.output_channel = output_channel | ||
self.output_message_type: type[Message[RestRequest]] = Message[request_class] # type: ignore[valid-type, assignment] | ||
|
||
async def event_handler( | ||
self, channel: str, message: Message[RestResponse | Tick] | ||
) -> AsyncIterator[tuple[str, Message[RestRequest]]]: | ||
if channel == self.input_response_channel: | ||
print("Received response: ", message.data) | ||
elif channel == self.input_tick_channel: | ||
yield ( | ||
self.output_channel, | ||
self.output_message_type( | ||
data=self.request_class( | ||
method="POST", | ||
url="http://0.0.0.0:8080/spotify/auth/token", | ||
data={"username": "test", "password": "test"}, | ||
content_type="application/x-www-form-urlencoded", | ||
) | ||
), | ||
) | ||
else: | ||
raise ValueError(f"Invalid channel: {channel}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
redis_url = "redis://localhost:6379/0" | ||
extra_modules = ["examples.api_node_examples.api_client"] | ||
|
||
[[nodes]] | ||
node_name = "api_client" | ||
node_class = "api_client" | ||
|
||
[nodes.node_args] | ||
input_tick_channel = "tick/secs/1" | ||
input_response_channel = "rest_response" | ||
output_channel = "rest_request" | ||
|
||
|
||
[[nodes]] | ||
node_name = "api_node" | ||
node_class = "rest_api" | ||
|
||
[nodes.node_args] | ||
input_channel = "rest_request" | ||
output_channel = "rest_response" | ||
input_type_str = "any" | ||
output_type_str = "any" | ||
|
||
|
||
[[nodes]] | ||
node_name = "tick" | ||
node_class = "tick" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,6 +57,7 @@ dev-dependencies = [ | |
"types-pyaudio", | ||
"types-aiofiles", | ||
"types-requests", | ||
"pytest" | ||
] | ||
|
||
[tool.mypy] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import asyncio | ||
import json | ||
import logging | ||
from typing import AsyncIterator, TypeVar | ||
from . import Node, NodeFactory | ||
|
||
from ..messages.base import Message | ||
from ..messages.commons import ( | ||
RestRequest, | ||
RestResponse, | ||
get_rest_response_class, | ||
get_rest_request_class, | ||
) | ||
|
||
import aiohttp | ||
|
||
from ..messages.registry import DataModelFactory | ||
|
||
T = TypeVar("T", bound=RestResponse) | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
async def _parse_response( | ||
response: aiohttp.ClientResponse, response_class: type[T] | ||
) -> T: | ||
status_code = response.status | ||
|
||
try: | ||
if response.content_type == "application/json": | ||
json_data = await response.json() | ||
# Only try to parse data into model if we have a successful response | ||
if 200 <= status_code < 300 and json_data: | ||
return response_class(status_code=status_code, data=json_data) | ||
except (json.JSONDecodeError, ValueError) as e: | ||
# Log error here if needed | ||
logger.error(f"Error parsing response: {e}") | ||
pass | ||
|
||
logger.warning(f"{response}") | ||
return response_class(status_code=status_code, data=None) | ||
|
||
|
||
@NodeFactory.register("rest_api") | ||
class RestAPINode(Node[RestRequest, RestResponse]): | ||
""" | ||
A node that sends a REST request to a given URL and sends the response to an output channel. | ||
Args: | ||
- `input_channel`: The input channel to listen for RestRequest messages. | ||
- `output_channel`: The output channel to send RestResponse messages. | ||
- `output_type_str`: The string identifier of the DataModel to parse the response into. | ||
- `redis_url`: The URL of the Redis server to connect to. | ||
REST API Node Example: | ||
```toml | ||
[[nodes]] | ||
node_name = "rest_api" | ||
node_class = "rest_api" | ||
[[nodes.node_args]] | ||
input_channel = "rest_request" | ||
output_channel = "rest_response" | ||
output_type_str = "float" | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
input_channel: str, | ||
output_channel: str, | ||
input_type_str: str, | ||
output_type_str: str, | ||
redis_url: str, | ||
): | ||
if input_type_str not in DataModelFactory.registry: | ||
raise ValueError( | ||
f"DataModel {input_type_str} is used as the input data model, but not found in registry" | ||
) | ||
|
||
if output_type_str not in DataModelFactory.registry: | ||
raise ValueError( | ||
f"DataModel {output_type_str} is used as the output data model, but not found in registry" | ||
) | ||
|
||
request_class = get_rest_request_class( | ||
DataModelFactory.registry[input_type_str] | ||
) | ||
response_data_class = DataModelFactory.registry[output_type_str] | ||
response_class = get_rest_response_class(response_data_class) | ||
|
||
super().__init__( | ||
input_channel_types=[(input_channel, request_class)], | ||
output_channel_types=[(output_channel, response_class)], | ||
redis_url=redis_url, | ||
) | ||
|
||
self.input_channel = input_channel | ||
self.output_channel = output_channel | ||
self.shutdown_event: asyncio.Event = asyncio.Event() | ||
self.request_class = request_class | ||
self.response_class = response_class | ||
self.response_data_class = response_data_class | ||
|
||
async def __aenter__(self) -> "RestAPINode": | ||
await super().__aenter__() | ||
return self | ||
|
||
async def handle_request(self, message: RestRequest) -> None: | ||
if message.content_type == "application/json": | ||
async with aiohttp.request( | ||
message.method, | ||
message.url, | ||
json=message.data.model_dump_json() if message.data else None, | ||
) as response: | ||
response_data = await _parse_response(response, self.response_class) | ||
else: | ||
async with aiohttp.request( | ||
message.method, | ||
message.url, | ||
data=message.data.model_dump(exclude={"data_type"}) | ||
if message.data | ||
else None, | ||
headers={"Content-Type": message.content_type}, | ||
) as response: | ||
response_data = await _parse_response(response, self.response_class) | ||
await self.r.publish( | ||
self.output_channel, | ||
Message[self.response_class](data=response_data).model_dump_json(), # type: ignore[name-defined] | ||
) | ||
|
||
async def event_handler( | ||
self, channel: str, message: Message[RestRequest] | ||
) -> AsyncIterator[tuple[str, Message[RestResponse]]]: | ||
if channel == self.input_channel: | ||
await self.handle_request(message.data) | ||
else: | ||
raise ValueError(f"Unexpected channel {channel}") | ||
yield # This is needed to make this function a generator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from aact.messages import get_rest_request_class, get_rest_response_class, Text | ||
|
||
|
||
def test_get_rest_request_class(): | ||
request_class = get_rest_request_class(Text) | ||
|
||
assert request_class.__name__ == "RestRequest[Text]" | ||
assert request_class.__annotations__["data"] == Text | None | ||
|
||
response_class = get_rest_response_class(Text) | ||
|
||
assert response_class.__name__ == "RestResponse[Text]" | ||
assert response_class.__annotations__["data"] == Text | None |
Oops, something went wrong.