Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix recursion error when handling unsupported types #44

Merged
merged 6 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "hayhooks"
dynamic = ["version"]
description = 'Grab and deploy Haystack pipelines'
readme = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.7,<3.13"
license = "Apache-2.0"
keywords = []
authors = [
Expand All @@ -29,6 +29,7 @@ dependencies = [
"uvicorn",
"requests",
"python-multipart",
"loguru",
]

[project.urls]
Expand Down
8 changes: 8 additions & 0 deletions src/hayhooks/server/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# logger.py

import os
import sys
from loguru import logger as log

log.remove()
log.add(sys.stderr, level=os.getenv("LOG", "INFO").upper())
14 changes: 12 additions & 2 deletions src/hayhooks/server/pipelines/models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from pandas import DataFrame
from pydantic import BaseModel, ConfigDict, create_model
from hayhooks.server.utils.create_valid_type import handle_unsupported_types
from haystack import Document


class PipelineDefinition(BaseModel):
name: str
source_code: str


DEFAULT_TYPES_MAPPING = {
DataFrame: dict,
Document: dict,
}


def get_request_model(pipeline_name: str, pipeline_inputs):
"""
Inputs have this form:
Expand All @@ -26,7 +33,7 @@ def get_request_model(pipeline_name: str, pipeline_inputs):
component_model = {}
for name, typedef in inputs.items():
try:
input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict})
input_type = handle_unsupported_types(type_=typedef["type"], types_mapping=DEFAULT_TYPES_MAPPING)
except TypeError as e:
print(f"ERROR at {component_name!r}, {name}: {typedef}")
raise e
Expand Down Expand Up @@ -56,7 +63,10 @@ def get_response_model(pipeline_name: str, pipeline_outputs):
component_model = {}
for name, typedef in outputs.items():
output_type = typedef["type"]
component_model[name] = (handle_unsupported_types(output_type, {DataFrame: dict}), ...)
component_model[name] = (
handle_unsupported_types(type_=output_type, types_mapping=DEFAULT_TYPES_MAPPING),
...,
)
response_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)

return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config)
Expand Down
59 changes: 22 additions & 37 deletions src/hayhooks/server/utils/create_valid_type.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Callable as CallableABC
from inspect import isclass
from types import GenericAlias
from typing import Callable, Dict, Optional, Union, get_args, get_origin, get_type_hints
from typing import Callable, Optional, Union, get_args, get_origin
from loguru import logger


def is_callable_type(t):
Expand All @@ -23,41 +23,26 @@ def is_callable_type(t):


def handle_unsupported_types(
type_: type, types_mapping: Dict[type, type], skip_callables: bool = True
type_: type, types_mapping: dict, skip_callables: bool = True
) -> Union[GenericAlias, type, None]:
"""
Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping.
"""

def handle_generics(t_) -> Union[GenericAlias, None]:
"""Handle generics recursively"""
if is_callable_type(t_) and skip_callables:
return None

child_typing = []
for t in get_args(t_):
if t in types_mapping:
result = types_mapping[t]
elif isclass(t):
result = handle_unsupported_types(t, types_mapping)
else:
result = t
child_typing.append(result)

if len(child_typing) == 2 and child_typing[1] is type(None):
return Optional[child_typing[0]]
else:
return GenericAlias(get_origin(t_), tuple(child_typing))

if is_callable_type(type_) and skip_callables:
logger.debug(f"Handling unsupported type: {type_}")

if skip_callables and is_callable_type(type_):
logger.warning(f"Skipping callable type: {type_}")
return None

if isclass(type_):
new_type = {}
for arg_name, arg_type in get_type_hints(type_).items():
if get_args(arg_type):
new_type[arg_name] = handle_generics(arg_type)
else:
new_type[arg_name] = arg_type
return type_
return handle_generics(type_)
# Handle generic types (like List, Optional, etc.)
origin = get_origin(type_)
if origin is not None:
args = get_args(type_)
# Map the inner types using the same mapping
mapped_args = tuple(handle_unsupported_types(arg, types_mapping, skip_callables) or arg for arg in args)
# Reconstruct the generic type with mapped arguments
return origin[mapped_args]

if type_ in types_mapping:
logger.debug(f"Mapping type: {type_} to {types_mapping[type_]}")
return types_mapping[type_]

logger.debug(f"Returning original type: {type_}")
return type_
72 changes: 72 additions & 0 deletions tests/test_files/basic_rag_pipeline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
components:
llm:
init_parameters:
api_base_url: null
api_key:
env_vars:
- OPENAI_API_KEY
strict: true
type: env_var
generation_kwargs: {}
model: gpt-4o-mini
organization: null
streaming_callback: null
system_prompt: null
type: haystack.components.generators.openai.OpenAIGenerator
prompt_builder:
init_parameters:
required_variables: null
template: "\nGiven the following information, answer the question.\n\nContext:\n\
{% for document in documents %}\n {{ document.content }}\n{% endfor %}\n\
\nQuestion: {{question}}\nAnswer:\n"
variables: null
type: haystack.components.builders.prompt_builder.PromptBuilder
retriever:
init_parameters:
document_store:
init_parameters:
bm25_algorithm: BM25L
bm25_parameters: {}
bm25_tokenization_regex: (?u)\b\w\w+\b
embedding_similarity_function: dot_product
index: d8b1f58f-20e9-4a57-a84d-a44fc651de4e
type: haystack.document_stores.in_memory.document_store.InMemoryDocumentStore
filter_policy: replace
filters: null
return_embedding: false
scale_score: false
top_k: 10
type: haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever
text_embedder:
init_parameters:
batch_size: 32
config_kwargs: null
device:
device: mps
type: single
model: sentence-transformers/all-MiniLM-L6-v2
model_kwargs: null
normalize_embeddings: false
precision: float32
prefix: ''
progress_bar: true
suffix: ''
token:
env_vars:
- HF_API_TOKEN
- HF_TOKEN
strict: false
type: env_var
tokenizer_kwargs: null
truncate_dim: null
trust_remote_code: false
type: haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder
connections:
- receiver: retriever.query_embedding
sender: text_embedder.embedding
- receiver: prompt_builder.documents
sender: retriever.documents
- receiver: llm.prompt
sender: prompt_builder.prompt
max_runs_per_component: 100
metadata: {}
50 changes: 0 additions & 50 deletions tests/test_files/chat_with_website.yaml

This file was deleted.

Loading