From 8832a6b14fd8d3e56b797250eb6897ee7fede60e Mon Sep 17 00:00:00 2001 From: Michele Pangrazzi Date: Fri, 15 Nov 2024 18:18:02 +0100 Subject: [PATCH 1/3] Handle Callable-like types in get_request_model (eg streaming_callback) --- src/hayhooks/server/pipelines/models.py | 3 +- .../server/utils/create_valid_type.py | 39 ++++++++++++------- tests/test_handle_callable_type.py | 32 +++++++++++++++ 3 files changed, 59 insertions(+), 15 deletions(-) create mode 100644 tests/test_handle_callable_type.py diff --git a/src/hayhooks/server/pipelines/models.py b/src/hayhooks/server/pipelines/models.py index 51d37ad..ed5c6e2 100644 --- a/src/hayhooks/server/pipelines/models.py +++ b/src/hayhooks/server/pipelines/models.py @@ -1,5 +1,6 @@ from pandas import DataFrame from pydantic import BaseModel, ConfigDict, create_model +from typing import Callable from hayhooks.server.utils.create_valid_type import handle_unsupported_types @@ -27,7 +28,7 @@ def get_request_model(pipeline_name: str, pipeline_inputs): component_model = {} for name, typedef in inputs.items(): try: - input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict}) + input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict, Callable: dict}) except TypeError as e: print(f"ERROR at {component_name!r}, {name}: {typedef}") raise e diff --git a/src/hayhooks/server/utils/create_valid_type.py b/src/hayhooks/server/utils/create_valid_type.py index 906a84b..d0f5d33 100644 --- a/src/hayhooks/server/utils/create_valid_type.py +++ b/src/hayhooks/server/utils/create_valid_type.py @@ -1,24 +1,36 @@ +from collections.abc import Callable as CallableABC from inspect import isclass from types import GenericAlias -from typing import Dict, Optional, Union, get_args, get_origin, get_type_hints +from typing import Dict, Optional, Union, get_args, get_origin, get_type_hints, Callable def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]: """ Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping. - - :param type_: Type to replace if not supported - :param types_mapping: Mapping of types to replace """ - def _handle_generics(t_) -> GenericAlias: - """ - Handle generics recursively - """ + def is_callable_type(t): + """Check if a type is any form of callable""" + origin = get_origin(t) + return ( + t is Callable + or origin is Callable + or origin is CallableABC + or (origin is not None and isinstance(origin, type) and issubclass(origin, CallableABC)) + or (isinstance(t, type) and issubclass(t, CallableABC)) + ) + + def handle_generics(t_) -> GenericAlias: + """Handle generics recursively""" + if is_callable_type(t_): + return types_mapping[Callable] + child_typing = [] for t in get_args(t_): if t in types_mapping: result = types_mapping[t] + elif is_callable_type(t): + result = types_mapping[Callable] elif isclass(t): result = handle_unsupported_types(t, types_mapping) else: @@ -26,20 +38,19 @@ def _handle_generics(t_) -> GenericAlias: child_typing.append(result) if len(child_typing) == 2 and child_typing[1] is type(None): - # because TypedDict can't handle union types with None - # rewrite them as Optional[type] return Optional[child_typing[0]] else: return GenericAlias(get_origin(t_), tuple(child_typing)) + if is_callable_type(type_): + return types_mapping[Callable] + if isclass(type_): new_type = {} for arg_name, arg_type in get_type_hints(type_).items(): if get_args(arg_type): - new_type[arg_name] = _handle_generics(arg_type) + new_type[arg_name] = handle_generics(arg_type) else: new_type[arg_name] = arg_type - return type_ - - return _handle_generics(type_) + return handle_generics(type_) diff --git a/tests/test_handle_callable_type.py b/tests/test_handle_callable_type.py new file mode 100644 index 0000000..a05ff8a --- /dev/null +++ b/tests/test_handle_callable_type.py @@ -0,0 +1,32 @@ +import typing +import haystack +from types import NoneType +from hayhooks.server.pipelines.models import get_request_model + + +def test_handle_callable_type_when_creating_pipeline_models(): + pipeline_name = "test_pipeline" + pipeline_inputs = { + 'generator': { + 'system_prompt': {'type': typing.Optional[str], 'is_mandatory': False, 'default_value': None}, + 'streaming_callback': { + 'type': typing.Optional[ + typing.Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType] + ], + 'is_mandatory': False, + 'default_value': None, + }, + 'generation_kwargs': { + 'type': typing.Optional[typing.Dict[str, typing.Any]], + 'is_mandatory': False, + 'default_value': None, + }, + } + } + + request_model = get_request_model(pipeline_name, pipeline_inputs) + + # This line used to throw an error because the Callable type was not handled correctly + # by the handle_unsupported_types function + assert request_model.model_json_schema() is not None + assert request_model.__name__ == "Test_pipelineRunRequest" From 8ebd76c9405255c65a69a4bcacbcc171de14f412 Mon Sep 17 00:00:00 2001 From: Michele Pangrazzi Date: Fri, 15 Nov 2024 21:14:58 +0100 Subject: [PATCH 2/3] refactoring ; add test --- .../server/utils/create_valid_type.py | 31 ++++++----- src/hayhooks/server/utils/deploy_utils.py | 5 +- tests/test_handle_callable_type.py | 51 +++++++++++++------ 3 files changed, 58 insertions(+), 29 deletions(-) diff --git a/src/hayhooks/server/utils/create_valid_type.py b/src/hayhooks/server/utils/create_valid_type.py index d0f5d33..5072a78 100644 --- a/src/hayhooks/server/utils/create_valid_type.py +++ b/src/hayhooks/server/utils/create_valid_type.py @@ -1,7 +1,25 @@ from collections.abc import Callable as CallableABC from inspect import isclass from types import GenericAlias -from typing import Dict, Optional, Union, get_args, get_origin, get_type_hints, Callable +from typing import Callable, Dict, Optional, Union, get_args, get_origin, get_type_hints + + +def is_callable_type(t): + """Check if a type is any form of callable""" + if t in (Callable, CallableABC): + return True + + # Check origin type + origin = get_origin(t) + if origin in (Callable, CallableABC): + return True + + # Handle Optional/Union types + if origin in (Union, type(Optional[int])): # type(Optional[int]) handles runtime Optional type + args = get_args(t) + return any(is_callable_type(arg) for arg in args) + + return False def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]: @@ -9,17 +27,6 @@ def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Un Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping. """ - def is_callable_type(t): - """Check if a type is any form of callable""" - origin = get_origin(t) - return ( - t is Callable - or origin is Callable - or origin is CallableABC - or (origin is not None and isinstance(origin, type) and issubclass(origin, CallableABC)) - or (isinstance(t, type) and issubclass(t, CallableABC)) - ) - def handle_generics(t_) -> GenericAlias: """Handle generics recursively""" if is_callable_type(t_): diff --git a/src/hayhooks/server/utils/deploy_utils.py b/src/hayhooks/server/utils/deploy_utils.py index 9a5d762..2d42742 100644 --- a/src/hayhooks/server/utils/deploy_utils.py +++ b/src/hayhooks/server/utils/deploy_utils.py @@ -1,15 +1,16 @@ from fastapi import HTTPException -from fastapi.responses import JSONResponse from fastapi.concurrency import run_in_threadpool +from fastapi.responses import JSONResponse from hayhooks.server.pipelines import registry from hayhooks.server.pipelines.models import ( PipelineDefinition, + convert_component_output, get_request_model, get_response_model, - convert_component_output, ) + def deploy_pipeline_def(app, pipeline_def: PipelineDefinition): try: pipe = registry.add(pipeline_def.name, pipeline_def.source_code) diff --git a/tests/test_handle_callable_type.py b/tests/test_handle_callable_type.py index a05ff8a..9f158f4 100644 --- a/tests/test_handle_callable_type.py +++ b/tests/test_handle_callable_type.py @@ -1,25 +1,46 @@ -import typing -import haystack +from collections.abc import Callable as CallableABC from types import NoneType +from typing import Any, Callable, Optional, Union + +import haystack +import pytest + from hayhooks.server.pipelines.models import get_request_model +from hayhooks.server.utils.create_valid_type import is_callable_type + + +@pytest.mark.parametrize( + "t, expected", + [ + (Callable, True), + (CallableABC, True), + (Callable[[int], str], True), + (Callable[..., Any], True), + (int, False), + (str, False), + (Any, False), + (Union[int, str], False), + (Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]], True), + ], +) +def test_is_callable_type(t, expected): + assert is_callable_type(t) == expected def test_handle_callable_type_when_creating_pipeline_models(): pipeline_name = "test_pipeline" pipeline_inputs = { - 'generator': { - 'system_prompt': {'type': typing.Optional[str], 'is_mandatory': False, 'default_value': None}, - 'streaming_callback': { - 'type': typing.Optional[ - typing.Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType] - ], - 'is_mandatory': False, - 'default_value': None, + "generator": { + "system_prompt": {"type": Optional[str], "is_mandatory": False, "default_value": None}, + "streaming_callback": { + "type": Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]], + "is_mandatory": False, + "default_value": None, }, - 'generation_kwargs': { - 'type': typing.Optional[typing.Dict[str, typing.Any]], - 'is_mandatory': False, - 'default_value': None, + "generation_kwargs": { + "type": Optional[dict[str, Any]], + "is_mandatory": False, + "default_value": None, }, } } @@ -27,6 +48,6 @@ def test_handle_callable_type_when_creating_pipeline_models(): request_model = get_request_model(pipeline_name, pipeline_inputs) # This line used to throw an error because the Callable type was not handled correctly - # by the handle_unsupported_types function + # by the handle_unsupported_types function assert request_model.model_json_schema() is not None assert request_model.__name__ == "Test_pipelineRunRequest" From d1adc2b7735b50a51909e9731f68708deaf88058 Mon Sep 17 00:00:00 2001 From: Michele Pangrazzi Date: Tue, 19 Nov 2024 10:32:11 +0100 Subject: [PATCH 3/3] skip callable types when creating pipeline ; recursively serialize component output to avoid pydantic serialization errors (eg on OpenAI responses) --- src/hayhooks/server/pipelines/models.py | 62 ++++++++++--------- .../server/utils/create_valid_type.py | 16 ++--- tests/test_convert_component_output.py | 43 +++++++++++++ tests/test_handle_callable_type.py | 3 +- 4 files changed, 86 insertions(+), 38 deletions(-) create mode 100644 tests/test_convert_component_output.py diff --git a/src/hayhooks/server/pipelines/models.py b/src/hayhooks/server/pipelines/models.py index ed5c6e2..3d36f4a 100644 --- a/src/hayhooks/server/pipelines/models.py +++ b/src/hayhooks/server/pipelines/models.py @@ -1,7 +1,5 @@ from pandas import DataFrame from pydantic import BaseModel, ConfigDict, create_model -from typing import Callable - from hayhooks.server.utils.create_valid_type import handle_unsupported_types @@ -28,14 +26,16 @@ def get_request_model(pipeline_name: str, pipeline_inputs): component_model = {} for name, typedef in inputs.items(): try: - input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict, Callable: dict}) + input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict}) except TypeError as e: print(f"ERROR at {component_name!r}, {name}: {typedef}") raise e - component_model[name] = ( - input_type, - typedef.get("default_value", ...), - ) + + if input_type is not None: + component_model[name] = ( + input_type, + typedef.get("default_value", ...), + ) request_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...) return create_model(f"{pipeline_name.capitalize()}RunRequest", **request_model, __config__=config) @@ -62,30 +62,34 @@ def get_response_model(pipeline_name: str, pipeline_outputs): return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config) +def convert_value_to_dict(value): + """Convert a single value to a dictionary if possible""" + if hasattr(value, "to_dict"): + if "init_parameters" in value.to_dict(): + return value.to_dict()["init_parameters"] + return value.to_dict() + elif hasattr(value, "model_dump"): + return value.model_dump() + elif isinstance(value, dict): + return {k: convert_value_to_dict(v) for k, v in value.items()} + elif isinstance(value, list): + return [convert_value_to_dict(item) for item in value] + else: + return value + + def convert_component_output(component_output): """ - Converts outputs from a component as a dict so that it can be validated against response model - - Component output has this form: + Converts component outputs to dictionaries that can be validated against response model. + Handles nested structures recursively. - "documents":[ - {"id":"818170...", "content":"RapidAPI for Mac is a full-featured HTTP client."} - ] + Args: + component_output: Dict with component outputs + Returns: + Dict with all nested objects converted to dictionaries """ - result = {} - for output_name, data in component_output.items(): - - def get_value(data): - if hasattr(data, "to_dict") and "init_parameters" in data.to_dict(): - return data.to_dict()["init_parameters"] - elif hasattr(data, "to_dict"): - return data.to_dict() - else: - return data - - if type(data) is list: - result[output_name] = [get_value(d) for d in data] - else: - result[output_name] = get_value(data) - return result + if isinstance(component_output, dict): + return {name: convert_value_to_dict(data) for name, data in component_output.items()} + + return convert_value_to_dict(component_output) diff --git a/src/hayhooks/server/utils/create_valid_type.py b/src/hayhooks/server/utils/create_valid_type.py index 5072a78..6c02f9e 100644 --- a/src/hayhooks/server/utils/create_valid_type.py +++ b/src/hayhooks/server/utils/create_valid_type.py @@ -22,22 +22,22 @@ def is_callable_type(t): return False -def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]: +def handle_unsupported_types( + type_: type, types_mapping: Dict[type, type], skip_callables: bool = True +) -> Union[GenericAlias, type, None]: """ Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping. """ - def handle_generics(t_) -> GenericAlias: + def handle_generics(t_) -> Union[GenericAlias, None]: """Handle generics recursively""" - if is_callable_type(t_): - return types_mapping[Callable] + if is_callable_type(t_) and skip_callables: + return None child_typing = [] for t in get_args(t_): if t in types_mapping: result = types_mapping[t] - elif is_callable_type(t): - result = types_mapping[Callable] elif isclass(t): result = handle_unsupported_types(t, types_mapping) else: @@ -49,8 +49,8 @@ def handle_generics(t_) -> GenericAlias: else: return GenericAlias(get_origin(t_), tuple(child_typing)) - if is_callable_type(type_): - return types_mapping[Callable] + if is_callable_type(type_) and skip_callables: + return None if isclass(type_): new_type = {} diff --git a/tests/test_convert_component_output.py b/tests/test_convert_component_output.py new file mode 100644 index 0000000..df5cb68 --- /dev/null +++ b/tests/test_convert_component_output.py @@ -0,0 +1,43 @@ +from hayhooks.server.pipelines.models import convert_component_output +from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails + + +def test_convert_component_output_with_nested_models(): + sample_response = [ + { + 'model': 'gpt-4o-mini-2024-07-18', + 'index': 0, + 'finish_reason': 'stop', + 'usage': { + 'completion_tokens': 52, + 'prompt_tokens': 29, + 'total_tokens': 81, + 'completion_tokens_details': CompletionTokensDetails( + accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0 + ), + 'prompt_tokens_details': PromptTokensDetails(audio_tokens=0, cached_tokens=0), + }, + } + ] + + converted_output = convert_component_output(sample_response) + + assert converted_output == [ + { + 'model': 'gpt-4o-mini-2024-07-18', + 'index': 0, + 'finish_reason': 'stop', + 'usage': { + 'completion_tokens': 52, + 'prompt_tokens': 29, + 'total_tokens': 81, + 'completion_tokens_details': { + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + }, + 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}, + }, + } + ] diff --git a/tests/test_handle_callable_type.py b/tests/test_handle_callable_type.py index 9f158f4..529b6cf 100644 --- a/tests/test_handle_callable_type.py +++ b/tests/test_handle_callable_type.py @@ -27,7 +27,7 @@ def test_is_callable_type(t, expected): assert is_callable_type(t) == expected -def test_handle_callable_type_when_creating_pipeline_models(): +def test_skip_callables_when_creating_pipeline_models(): pipeline_name = "test_pipeline" pipeline_inputs = { "generator": { @@ -51,3 +51,4 @@ def test_handle_callable_type_when_creating_pipeline_models(): # by the handle_unsupported_types function assert request_model.model_json_schema() is not None assert request_model.__name__ == "Test_pipelineRunRequest" + assert "streaming_callback" not in request_model.model_json_schema()["$defs"]["ComponentParams"]["properties"]