Skip to content

Commit

Permalink
add api node
Browse files Browse the repository at this point in the history
  • Loading branch information
ProKil committed Nov 3, 2024
1 parent f742311 commit 869d5d3
Show file tree
Hide file tree
Showing 11 changed files with 360 additions and 6 deletions.
59 changes: 59 additions & 0 deletions examples/api_node_examples/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import AsyncIterator
from aact.messages.base import Message
from aact.messages.commons import (
RestRequest,
RestResponse,
Tick,
get_rest_request_class,
get_rest_response_class,
AnyDataModel,
)
from aact.nodes import Node, NodeFactory


@NodeFactory.register("api_client")
class APIClient(Node[Tick | RestResponse, RestRequest]):
def __init__(
self,
input_tick_channel: str,
input_response_channel: str,
output_channel: str,
redis_url: str,
):
response_class = get_rest_response_class(AnyDataModel)
request_class = get_rest_request_class(AnyDataModel)
super().__init__(
input_channel_types=[
(input_tick_channel, Tick),
(input_response_channel, response_class),
],
output_channel_types=[(output_channel, request_class)],
redis_url=redis_url,
)

self.request_class = request_class
self.response_class = response_class
self.input_tick_channel = input_tick_channel
self.input_response_channel = input_response_channel
self.output_channel = output_channel
self.output_message_type: type[Message[RestRequest]] = Message[request_class] # type: ignore[valid-type, assignment]

async def event_handler(
self, channel: str, message: Message[RestResponse | Tick]
) -> AsyncIterator[tuple[str, Message[RestRequest]]]:
if channel == self.input_response_channel:
print("Received response: ", message.data)
elif channel == self.input_tick_channel:
yield (
self.output_channel,
self.output_message_type(
data=self.request_class(
method="POST",
url="http://0.0.0.0:8080/spotify/auth/token",
data={"username": "test", "password": "test"},
content_type="application/x-www-form-urlencoded",
)
),
)
else:
raise ValueError(f"Invalid channel: {channel}")
27 changes: 27 additions & 0 deletions examples/api_node_examples/api_node_demo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
redis_url = "redis://localhost:6379/0"
extra_modules = ["examples.api_node_examples.api_client"]

[[nodes]]
node_name = "api_client"
node_class = "api_client"

[nodes.node_args]
input_tick_channel = "tick/secs/1"
input_response_channel = "rest_response"
output_channel = "rest_request"


[[nodes]]
node_name = "api_node"
node_class = "rest_api"

[nodes.node_args]
input_channel = "rest_request"
output_channel = "rest_response"
input_type_str = "any"
output_type_str = "any"


[[nodes]]
node_name = "tick"
node_class = "tick"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dev-dependencies = [
"types-pyaudio",
"types-aiofiles",
"types-requests",
"pytest"
]

[tool.mypy]
Expand Down
19 changes: 18 additions & 1 deletion src/aact/messages/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from .base import Message, DataModel
from .commons import Zero, Tick, Image, Float, Audio, Text
from .commons import (
Zero,
AnyDataModel,
Tick,
Image,
Float,
Audio,
Text,
RestRequest,
RestResponse,
get_rest_request_class,
get_rest_response_class,
)
from .registry import DataModelFactory

__all__ = [
"Zero",
"AnyDataModel",
"Message",
"Tick",
"Image",
Expand All @@ -12,4 +25,8 @@
"DataModel",
"Audio",
"Text",
"RestRequest",
"RestResponse",
"get_rest_request_class",
"get_rest_response_class",
]
47 changes: 46 additions & 1 deletion src/aact/messages/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,27 @@

from .registry import DataModelFactory
from .base import DataModel
from pydantic import Field, PlainValidator, PlainSerializer, WithJsonSchema, BaseModel
from pydantic import (
ConfigDict,
Field,
PlainValidator,
PlainSerializer,
WithJsonSchema,
BaseModel,
create_model,
)


@DataModelFactory.register("zero")
class Zero(DataModel):
pass


@DataModelFactory.register("any")
class AnyDataModel(DataModel):
model_config = ConfigDict(extra="allow")


@DataModelFactory.register("tick")
class Tick(DataModel):
tick: int
Expand Down Expand Up @@ -61,3 +74,35 @@ class DataEntry(BaseModel, Generic[T]):
timestamp: datetime = Field(default_factory=datetime.now)
channel: str
data: T


@DataModelFactory.register("rest_request")
class RestRequest(DataModel):
url: str
method: str
data: DataModel | None
content_type: str = Field(default="application/json")


@DataModelFactory.register("rest_response")
class RestResponse(DataModel):
status_code: int
data: DataModel | None


def get_rest_request_class(data_model: type[T]) -> type[RestRequest]:
new_class = create_model(
f"RestRequest[{data_model.__name__}]",
__base__=RestRequest,
data=(data_model | None, None),
)
return new_class


def get_rest_response_class(data_model: type[T]) -> type[RestResponse]:
new_class = create_model(
f"RestResponse[{data_model.__name__}]",
__base__=RestResponse,
data=(data_model | None, None),
)
return new_class
2 changes: 1 addition & 1 deletion src/aact/messages/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def inner_wrapper(wrapped_class: type[T]) -> type[T]:
if name in cls.registry:
logger.warning("DataModel %s already exists. Will replace it", name)
new_class = create_model(
cls.__name__,
wrapped_class.__name__,
__base__=wrapped_class,
data_type=(Literal[name], name),
)
Expand Down
2 changes: 2 additions & 0 deletions src/aact/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .print import PrintNode
from .tts import TTSNode
from .registry import NodeFactory
from .api import RestAPINode

__all__ = [
"Node",
Expand All @@ -30,4 +31,5 @@
"PerformanceMeasureNode",
"PrintNode",
"TTSNode",
"RestAPINode",
]
145 changes: 145 additions & 0 deletions src/aact/nodes/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import asyncio
import json
import logging
from typing import AsyncIterator, TypeVar
from . import Node, NodeFactory

from ..messages.base import Message
from ..messages.commons import (
RestRequest,
RestResponse,
get_rest_response_class,
get_rest_request_class,
)

import aiohttp

from ..messages.registry import DataModelFactory

T = TypeVar("T", bound=RestResponse)


logger = logging.getLogger(__name__)


async def _parse_response(
response: aiohttp.ClientResponse, response_class: type[T]
) -> T:
status_code = response.status

try:
if response.content_type == "application/json":
json_data = await response.json()
# Only try to parse data into model if we have a successful response
if 200 <= status_code < 300 and json_data:
return response_class(status_code=status_code, data=json_data)
except (json.JSONDecodeError, ValueError) as e:
# Log error here if needed
logger.error(f"Error parsing response: {e}")
pass

logger.warning(f"{response}")
return response_class(status_code=status_code, data=None)


@NodeFactory.register("rest_api")
class RestAPINode(Node[RestRequest, RestResponse]):
"""
A node that sends a REST request to a given URL and sends the response to an output channel.
Args:
- `input_channel`: The input channel to listen for RestRequest messages.
- `output_channel`: The output channel to send RestResponse messages.
- `output_type_str`: The string identifier of the DataModel to parse the response into.
- `redis_url`: The URL of the Redis server to connect to.
REST API Node Example:
```toml
[[nodes]]
node_name = "rest_api"
node_class = "rest_api"
[[nodes.node_args]]
input_channel = "rest_request"
output_channel = "rest_response"
output_type_str = "float"
```
"""

def __init__(
self,
input_channel: str,
output_channel: str,
input_type_str: str,
output_type_str: str,
redis_url: str,
):
if input_type_str not in DataModelFactory.registry:
raise ValueError(
f"DataModel {input_type_str} is used as the input data model, but not found in registry"
)

if output_type_str not in DataModelFactory.registry:
raise ValueError(
f"DataModel {output_type_str} is used as the output data model, but not found in registry"
)

request_class = get_rest_request_class(
DataModelFactory.registry[input_type_str]
)
response_data_class = DataModelFactory.registry[output_type_str]
response_class = get_rest_response_class(response_data_class)

super().__init__(
input_channel_types=[(input_channel, request_class)],
output_channel_types=[(output_channel, response_class)],
redis_url=redis_url,
)

self.input_channel = input_channel
self.output_channel = output_channel
self.shutdown_event: asyncio.Event = asyncio.Event()
self.request_class = request_class
self.response_class = response_class
self.response_data_class = response_data_class

async def __aenter__(self) -> "RestAPINode":
await super().__aenter__()
return self

async def handle_request(self, message: RestRequest) -> None:
if message.content_type == "application/json":
async with aiohttp.request(
message.method,
message.url,
json=message.data.model_dump_json() if message.data else None,
) as response:
response_data = await _parse_response(response, self.response_class)
else:
async with aiohttp.request(
message.method,
message.url,
data=message.data.model_dump(exclude={"data_type"})
if message.data
else None,
headers={"Content-Type": message.content_type},
) as response:
response_data = await _parse_response(response, self.response_class)
await self.r.publish(
self.output_channel,
Message[self.response_class](data=response_data).model_dump_json(), # type: ignore[name-defined]
)

async def event_handler(
self, channel: str, message: Message[RestRequest]
) -> AsyncIterator[tuple[str, Message[RestResponse]]]:
if channel == self.input_channel:
await self.handle_request(message.data)
else:
raise ValueError(f"Unexpected channel {channel}")
yield # This is needed to make this function a generator
14 changes: 11 additions & 3 deletions src/aact/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,17 @@ async def _wait_for_input(
async for message in self.pubsub.listen():
channel = message["channel"].decode("utf-8")
if message["type"] == "message" and channel in self.input_channel_types:
data = Message[self.input_channel_types[channel]].model_validate_json( # type: ignore
message["data"]
)
try:
data = Message[
self.input_channel_types[channel]
].model_validate_json( # type: ignore
message["data"]
)
except ValidationError as e:
self.logger.error(
f"Failed to validate message from {channel}: {message['data']}. Error: {e}"
)
raise e
yield channel, data
raise Exception("Input channel closed unexpectedly")

Expand Down
13 changes: 13 additions & 0 deletions tests/messages/test_rest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from aact.messages import get_rest_request_class, get_rest_response_class, Text


def test_get_rest_request_class():
request_class = get_rest_request_class(Text)

assert request_class.__name__ == "RestRequest[Text]"
assert request_class.__annotations__["data"] == Text | None

response_class = get_rest_response_class(Text)

assert response_class.__name__ == "RestResponse[Text]"
assert response_class.__annotations__["data"] == Text | None
Loading

0 comments on commit 869d5d3

Please sign in to comment.