From 8ebd76c9405255c65a69a4bcacbcc171de14f412 Mon Sep 17 00:00:00 2001 From: Michele Pangrazzi Date: Fri, 15 Nov 2024 21:14:58 +0100 Subject: [PATCH] refactoring ; add test --- .../server/utils/create_valid_type.py | 31 ++++++----- src/hayhooks/server/utils/deploy_utils.py | 5 +- tests/test_handle_callable_type.py | 51 +++++++++++++------ 3 files changed, 58 insertions(+), 29 deletions(-) diff --git a/src/hayhooks/server/utils/create_valid_type.py b/src/hayhooks/server/utils/create_valid_type.py index d0f5d33..5072a78 100644 --- a/src/hayhooks/server/utils/create_valid_type.py +++ b/src/hayhooks/server/utils/create_valid_type.py @@ -1,7 +1,25 @@ from collections.abc import Callable as CallableABC from inspect import isclass from types import GenericAlias -from typing import Dict, Optional, Union, get_args, get_origin, get_type_hints, Callable +from typing import Callable, Dict, Optional, Union, get_args, get_origin, get_type_hints + + +def is_callable_type(t): + """Check if a type is any form of callable""" + if t in (Callable, CallableABC): + return True + + # Check origin type + origin = get_origin(t) + if origin in (Callable, CallableABC): + return True + + # Handle Optional/Union types + if origin in (Union, type(Optional[int])): # type(Optional[int]) handles runtime Optional type + args = get_args(t) + return any(is_callable_type(arg) for arg in args) + + return False def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]: @@ -9,17 +27,6 @@ def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Un Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping. """ - def is_callable_type(t): - """Check if a type is any form of callable""" - origin = get_origin(t) - return ( - t is Callable - or origin is Callable - or origin is CallableABC - or (origin is not None and isinstance(origin, type) and issubclass(origin, CallableABC)) - or (isinstance(t, type) and issubclass(t, CallableABC)) - ) - def handle_generics(t_) -> GenericAlias: """Handle generics recursively""" if is_callable_type(t_): diff --git a/src/hayhooks/server/utils/deploy_utils.py b/src/hayhooks/server/utils/deploy_utils.py index 9a5d762..2d42742 100644 --- a/src/hayhooks/server/utils/deploy_utils.py +++ b/src/hayhooks/server/utils/deploy_utils.py @@ -1,15 +1,16 @@ from fastapi import HTTPException -from fastapi.responses import JSONResponse from fastapi.concurrency import run_in_threadpool +from fastapi.responses import JSONResponse from hayhooks.server.pipelines import registry from hayhooks.server.pipelines.models import ( PipelineDefinition, + convert_component_output, get_request_model, get_response_model, - convert_component_output, ) + def deploy_pipeline_def(app, pipeline_def: PipelineDefinition): try: pipe = registry.add(pipeline_def.name, pipeline_def.source_code) diff --git a/tests/test_handle_callable_type.py b/tests/test_handle_callable_type.py index a05ff8a..9f158f4 100644 --- a/tests/test_handle_callable_type.py +++ b/tests/test_handle_callable_type.py @@ -1,25 +1,46 @@ -import typing -import haystack +from collections.abc import Callable as CallableABC from types import NoneType +from typing import Any, Callable, Optional, Union + +import haystack +import pytest + from hayhooks.server.pipelines.models import get_request_model +from hayhooks.server.utils.create_valid_type import is_callable_type + + +@pytest.mark.parametrize( + "t, expected", + [ + (Callable, True), + (CallableABC, True), + (Callable[[int], str], True), + (Callable[..., Any], True), + (int, False), + (str, False), + (Any, False), + (Union[int, str], False), + (Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]], True), + ], +) +def test_is_callable_type(t, expected): + assert is_callable_type(t) == expected def test_handle_callable_type_when_creating_pipeline_models(): pipeline_name = "test_pipeline" pipeline_inputs = { - 'generator': { - 'system_prompt': {'type': typing.Optional[str], 'is_mandatory': False, 'default_value': None}, - 'streaming_callback': { - 'type': typing.Optional[ - typing.Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType] - ], - 'is_mandatory': False, - 'default_value': None, + "generator": { + "system_prompt": {"type": Optional[str], "is_mandatory": False, "default_value": None}, + "streaming_callback": { + "type": Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]], + "is_mandatory": False, + "default_value": None, }, - 'generation_kwargs': { - 'type': typing.Optional[typing.Dict[str, typing.Any]], - 'is_mandatory': False, - 'default_value': None, + "generation_kwargs": { + "type": Optional[dict[str, Any]], + "is_mandatory": False, + "default_value": None, }, } } @@ -27,6 +48,6 @@ def test_handle_callable_type_when_creating_pipeline_models(): request_model = get_request_model(pipeline_name, pipeline_inputs) # This line used to throw an error because the Callable type was not handled correctly - # by the handle_unsupported_types function + # by the handle_unsupported_types function assert request_model.model_json_schema() is not None assert request_model.__name__ == "Test_pipelineRunRequest"