Skip to content

Commit

Permalink
refactoring ; add test
Browse files Browse the repository at this point in the history
  • Loading branch information
mpangrazzi committed Nov 15, 2024
1 parent 8832a6b commit 8ebd76c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 29 deletions.
31 changes: 19 additions & 12 deletions src/hayhooks/server/utils/create_valid_type.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,32 @@
from collections.abc import Callable as CallableABC
from inspect import isclass
from types import GenericAlias
from typing import Dict, Optional, Union, get_args, get_origin, get_type_hints, Callable
from typing import Callable, Dict, Optional, Union, get_args, get_origin, get_type_hints


def is_callable_type(t):
"""Check if a type is any form of callable"""
if t in (Callable, CallableABC):
return True

# Check origin type
origin = get_origin(t)
if origin in (Callable, CallableABC):
return True

# Handle Optional/Union types
if origin in (Union, type(Optional[int])): # type(Optional[int]) handles runtime Optional type
args = get_args(t)
return any(is_callable_type(arg) for arg in args)

return False


def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]:
"""
Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping.
"""

def is_callable_type(t):
"""Check if a type is any form of callable"""
origin = get_origin(t)
return (
t is Callable
or origin is Callable
or origin is CallableABC
or (origin is not None and isinstance(origin, type) and issubclass(origin, CallableABC))
or (isinstance(t, type) and issubclass(t, CallableABC))
)

def handle_generics(t_) -> GenericAlias:
"""Handle generics recursively"""
if is_callable_type(t_):
Expand Down
5 changes: 3 additions & 2 deletions src/hayhooks/server/utils/deploy_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse

from hayhooks.server.pipelines import registry
from hayhooks.server.pipelines.models import (
PipelineDefinition,
convert_component_output,
get_request_model,
get_response_model,
convert_component_output,
)


def deploy_pipeline_def(app, pipeline_def: PipelineDefinition):
try:
pipe = registry.add(pipeline_def.name, pipeline_def.source_code)
Expand Down
51 changes: 36 additions & 15 deletions tests/test_handle_callable_type.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,53 @@
import typing
import haystack
from collections.abc import Callable as CallableABC
from types import NoneType
from typing import Any, Callable, Optional, Union

import haystack
import pytest

from hayhooks.server.pipelines.models import get_request_model
from hayhooks.server.utils.create_valid_type import is_callable_type


@pytest.mark.parametrize(
"t, expected",
[
(Callable, True),
(CallableABC, True),
(Callable[[int], str], True),
(Callable[..., Any], True),
(int, False),
(str, False),
(Any, False),
(Union[int, str], False),
(Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]], True),
],
)
def test_is_callable_type(t, expected):
assert is_callable_type(t) == expected


def test_handle_callable_type_when_creating_pipeline_models():
pipeline_name = "test_pipeline"
pipeline_inputs = {
'generator': {
'system_prompt': {'type': typing.Optional[str], 'is_mandatory': False, 'default_value': None},
'streaming_callback': {
'type': typing.Optional[
typing.Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]
],
'is_mandatory': False,
'default_value': None,
"generator": {
"system_prompt": {"type": Optional[str], "is_mandatory": False, "default_value": None},
"streaming_callback": {
"type": Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]],
"is_mandatory": False,
"default_value": None,
},
'generation_kwargs': {
'type': typing.Optional[typing.Dict[str, typing.Any]],
'is_mandatory': False,
'default_value': None,
"generation_kwargs": {
"type": Optional[dict[str, Any]],
"is_mandatory": False,
"default_value": None,
},
}
}

request_model = get_request_model(pipeline_name, pipeline_inputs)

# This line used to throw an error because the Callable type was not handled correctly
# by the handle_unsupported_types function
# by the handle_unsupported_types function
assert request_model.model_json_schema() is not None
assert request_model.__name__ == "Test_pipelineRunRequest"

0 comments on commit 8ebd76c

Please sign in to comment.