Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/core/interfaces/command_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from asyncio import iscoroutinefunction as asyncio_iscoroutinefunction
from collections.abc import Awaitable, Callable
from typing import Any

Expand All @@ -9,6 +10,27 @@
CommandServiceHandler = Callable[[list[Any], str], Awaitable[ProcessedResult]]


def _is_async_callable(candidate: Any) -> bool:
"""Return ``True`` when *candidate* is an awaitable callable."""

if asyncio_iscoroutinefunction(candidate): # Handles partials and decorated callables
return True

func_attr = getattr(candidate, "func", None)
if func_attr and asyncio_iscoroutinefunction(func_attr):
return True

call_method = getattr(candidate, "__call__", None)
if not call_method:
return False

if asyncio_iscoroutinefunction(call_method):
return True

bound_function = getattr(call_method, "__func__", None)
return bool(bound_function and asyncio_iscoroutinefunction(bound_function))


class FunctionCommandService(ICommandService):
"""Adapter that turns a coroutine function into an ``ICommandService``."""

Expand All @@ -19,6 +41,10 @@ def __init__(self, handler: CommandServiceHandler):
)
if not callable(handler):
raise TypeError("The command service handler must be callable.")
if not _is_async_callable(handler):
raise TypeError(
"The command service handler must be an async callable that returns an awaitable result."
)
self._handler = handler

async def process_commands(
Expand Down Expand Up @@ -51,6 +77,10 @@ def ensure_command_service(
return service

if callable(service):
if not _is_async_callable(service):
raise TypeError(
"The command service handler must be an async callable that returns an awaitable result."
)
return FunctionCommandService(service) # type: ignore[arg-type]

raise TypeError("The provided command service is not valid.")
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/core/test_command_service_module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import pytest
from src.core.domain.processed_result import ProcessedResult
from src.core.interfaces.command_service import ensure_command_service
Expand Down Expand Up @@ -48,6 +50,42 @@ async def handler(messages: list[str], session_id: str) -> ProcessedResult:
assert result.command_results == ["session"]


@pytest.mark.asyncio
async def test_ensure_command_service_accepts_partial_async_callable() -> None:
async def handler(
messages: list[str], session_id: str, prefix: str
) -> ProcessedResult:
return ProcessedResult(
modified_messages=[f"{prefix}:{value}" for value in messages],
command_executed=bool(messages),
command_results=[session_id],
)

partial_handler = partial(handler, prefix="partial")

validated_service = ensure_command_service(partial_handler)

result = await validated_service.process_commands(["message"], "session")

assert result.modified_messages == ["partial:message"]
assert result.command_executed is True
assert result.command_results == ["session"]


def test_ensure_command_service_rejects_sync_callable() -> None:
def handler(messages: list[str], session_id: str) -> ProcessedResult:
return ProcessedResult(
modified_messages=messages,
command_executed=False,
command_results=[session_id],
)

with pytest.raises(TypeError) as exc:
ensure_command_service(handler)

assert "async" in str(exc.value).lower()


def test_ensure_command_service_rejects_none() -> None:
with pytest.raises(ValueError) as exc:
ensure_command_service(None)
Expand Down
Loading