Merge pull request #44 from deepset-ai/fix-recursion-error
fix recursion error when handling unsupported types
mpangrazzi authored Dec 13, 2024
2 parents 5739ee6 + 296fd6f commit 1d96419
Showing 11 changed files with 452 additions and 50 deletions.
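Note: for context, a minimal sketch of the bug class the commit title describes (an assumed reconstruction, not code from this repository): a type-hint walker with no base case for plain classes recurses forever on self-referential annotations. The hypothetical Node type below is illustrative only.

from typing import Optional, get_args, get_type_hints

class Node:
    next: Optional["Node"] = None  # self-referential annotation

def naive_walk(type_):
    # Naive recursion with no termination for already-visited classes,
    # similar in spirit to the pre-fix behavior.
    if get_args(type_):                            # generic, e.g. Optional[Node]
        for arg in get_args(type_):
            naive_walk(arg)
    elif isinstance(type_, type) and type_ is not type(None):
        for hint in get_type_hints(type_).values():
            naive_walk(hint)                       # Node -> Optional[Node] -> Node -> ...

# naive_walk(Node)  # raises RecursionError: maximum recursion depth exceeded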
1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "uvicorn",
     "requests",
     "python-multipart",
+    "loguru",
 ]

 [project.urls]
8 changes: 8 additions & 0 deletions src/hayhooks/server/logger.py
@@ -0,0 +1,8 @@
+# logger.py
+
+import os
+import sys
+from loguru import logger as log
+
+log.remove()
+log.add(sys.stderr, level=os.getenv("LOG", "INFO").upper())
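Note: the new module configures Loguru once at import time, with the LOG environment variable selecting the level (INFO by default). A usage sketch; the import path follows the file above, the CLI invocation is an assumption:

# $ LOG=debug hayhooks run     # assumed invocation; any entry point works
from hayhooks.server.logger import log

log.debug("visible only when LOG=debug")
log.info("INFO is the default level")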
14 changes: 12 additions & 2 deletions src/hayhooks/server/pipelines/models.py
@@ -1,13 +1,20 @@
 from pandas import DataFrame
 from pydantic import BaseModel, ConfigDict, create_model
 from hayhooks.server.utils.create_valid_type import handle_unsupported_types
+from haystack import Document


 class PipelineDefinition(BaseModel):
     name: str
     source_code: str


+DEFAULT_TYPES_MAPPING = {
+    DataFrame: dict,
+    Document: dict,
+}
+
+
 def get_request_model(pipeline_name: str, pipeline_inputs):
     """
     Inputs have this form:
@@ -26,7 +33,7 @@ def get_request_model(pipeline_name: str, pipeline_inputs):
         component_model = {}
         for name, typedef in inputs.items():
             try:
-                input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict})
+                input_type = handle_unsupported_types(type_=typedef["type"], types_mapping=DEFAULT_TYPES_MAPPING)
             except TypeError as e:
                 print(f"ERROR at {component_name!r}, {name}: {typedef}")
                 raise e
@@ -56,7 +63,10 @@ def get_response_model(pipeline_name: str, pipeline_outputs):
         component_model = {}
         for name, typedef in outputs.items():
             output_type = typedef["type"]
-            component_model[name] = (handle_unsupported_types(output_type, {DataFrame: dict}), ...)
+            component_model[name] = (
+                handle_unsupported_types(type_=output_type, types_mapping=DEFAULT_TYPES_MAPPING),
+                ...,
+            )
         response_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)

     return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config)
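Note: why map DataFrame and Document to dict at all? Presumably because Pydantic cannot generate a JSON schema for arbitrary classes, so the auto-built request/response models expose such fields as plain dicts. A sketch of the resulting shape; the ComponentParams name mirrors the diff, the documents field is illustrative:

from pydantic import create_model

# With the mapping applied, a Document-typed field is modelled as a dict:
ComponentParams = create_model("ComponentParams", documents=(list[dict], ...))
print(ComponentParams(documents=[{"content": "hello"}]).documents)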
59 changes: 22 additions & 37 deletions src/hayhooks/server/utils/create_valid_type.py
@@ -1,7 +1,7 @@
 from collections.abc import Callable as CallableABC
 from inspect import isclass
 from types import GenericAlias
-from typing import Callable, Dict, Optional, Union, get_args, get_origin, get_type_hints
+from typing import Callable, Optional, Union, get_args, get_origin
+from loguru import logger


 def is_callable_type(t):
@@ -23,41 +23,26 @@ def is_callable_type(t):


 def handle_unsupported_types(
-    type_: type, types_mapping: Dict[type, type], skip_callables: bool = True
+    type_: type, types_mapping: dict, skip_callables: bool = True
 ) -> Union[GenericAlias, type, None]:
     """
     Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping.
     """
-
-    def handle_generics(t_) -> Union[GenericAlias, None]:
-        """Handle generics recursively"""
-        if is_callable_type(t_) and skip_callables:
-            return None
-
-        child_typing = []
-        for t in get_args(t_):
-            if t in types_mapping:
-                result = types_mapping[t]
-            elif isclass(t):
-                result = handle_unsupported_types(t, types_mapping)
-            else:
-                result = t
-            child_typing.append(result)
-
-        if len(child_typing) == 2 and child_typing[1] is type(None):
-            return Optional[child_typing[0]]
-        else:
-            return GenericAlias(get_origin(t_), tuple(child_typing))
-
-    if is_callable_type(type_) and skip_callables:
+    logger.debug(f"Handling unsupported type: {type_}")
+
+    if skip_callables and is_callable_type(type_):
+        logger.warning(f"Skipping callable type: {type_}")
         return None

-    if isclass(type_):
-        new_type = {}
-        for arg_name, arg_type in get_type_hints(type_).items():
-            if get_args(arg_type):
-                new_type[arg_name] = handle_generics(arg_type)
-            else:
-                new_type[arg_name] = arg_type
-        return type_
-    return handle_generics(type_)
+    # Handle generic types (like List, Optional, etc.)
+    origin = get_origin(type_)
+    if origin is not None:
+        args = get_args(type_)
+        # Map the inner types using the same mapping
+        mapped_args = tuple(handle_unsupported_types(arg, types_mapping, skip_callables) or arg for arg in args)
+        # Reconstruct the generic type with mapped arguments
+        return origin[mapped_args]
+
+    if type_ in types_mapping:
+        logger.debug(f"Mapping type: {type_} to {types_mapping[type_]}")
+        return types_mapping[type_]
+
+    logger.debug(f"Returning original type: {type_}")
+    return type_
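Note: a quick behavioral sketch of the rewritten function (assumed environment: Python 3.9+ with pandas installed; the assertions reflect a reading of the code above, not tests from this PR). Generics are rebuilt with mapped inner types, while plain classes are now returned as-is instead of being recursed into, which is what removes the RecursionError.

from typing import Optional
from pandas import DataFrame
from hayhooks.server.utils.create_valid_type import handle_unsupported_types

mapping = {DataFrame: dict}

assert handle_unsupported_types(type_=DataFrame, types_mapping=mapping) is dict
assert handle_unsupported_types(type_=list[DataFrame], types_mapping=mapping) == list[dict]
assert handle_unsupported_types(type_=Optional[DataFrame], types_mapping=mapping) == Optional[dict]
assert handle_unsupported_types(type_=str, types_mapping=mapping) is str  # unmapped types pass through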
72 changes: 72 additions & 0 deletions tests/test_files/basic_rag_pipeline.yml
@@ -0,0 +1,72 @@
+components:
+  llm:
+    init_parameters:
+      api_base_url: null
+      api_key:
+        env_vars:
+        - OPENAI_API_KEY
+        strict: true
+        type: env_var
+      generation_kwargs: {}
+      model: gpt-4o-mini
+      organization: null
+      streaming_callback: null
+      system_prompt: null
+    type: haystack.components.generators.openai.OpenAIGenerator
+  prompt_builder:
+    init_parameters:
+      required_variables: null
+      template: "\nGiven the following information, answer the question.\n\nContext:\n\
+        {% for document in documents %}\n {{ document.content }}\n{% endfor %}\n\
+        \nQuestion: {{question}}\nAnswer:\n"
+      variables: null
+    type: haystack.components.builders.prompt_builder.PromptBuilder
+  retriever:
+    init_parameters:
+      document_store:
+        init_parameters:
+          bm25_algorithm: BM25L
+          bm25_parameters: {}
+          bm25_tokenization_regex: (?u)\b\w\w+\b
+          embedding_similarity_function: dot_product
+          index: d8b1f58f-20e9-4a57-a84d-a44fc651de4e
+        type: haystack.document_stores.in_memory.document_store.InMemoryDocumentStore
+      filter_policy: replace
+      filters: null
+      return_embedding: false
+      scale_score: false
+      top_k: 10
+    type: haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever
+  text_embedder:
+    init_parameters:
+      batch_size: 32
+      config_kwargs: null
+      device:
+        device: mps
+        type: single
+      model: sentence-transformers/all-MiniLM-L6-v2
+      model_kwargs: null
+      normalize_embeddings: false
+      precision: float32
+      prefix: ''
+      progress_bar: true
+      suffix: ''
+      token:
+        env_vars:
+        - HF_API_TOKEN
+        - HF_TOKEN
+        strict: false
+        type: env_var
+      tokenizer_kwargs: null
+      truncate_dim: null
+      trust_remote_code: false
+    type: haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder
+connections:
+- receiver: retriever.query_embedding
+  sender: text_embedder.embedding
+- receiver: prompt_builder.documents
+  sender: retriever.documents
+- receiver: llm.prompt
+  sender: prompt_builder.prompt
+max_runs_per_component: 100
+metadata: {}
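Note: this fixture is presumably what the new tests load (the test files themselves sit in the part of the diff not rendered on this page). A sketch of how such a serialized pipeline is deserialized with Haystack 2.x; the dummy key is needed because api_key above is a strict env var:

import os
from pathlib import Path
from haystack import Pipeline

os.environ.setdefault("OPENAI_API_KEY", "sk-dummy")  # strict env_var above must resolve
yaml_source = Path("tests/test_files/basic_rag_pipeline.yml").read_text()
pipeline = Pipeline.loads(yaml_source)
print(pipeline.inputs())  # the input sockets hayhooks builds request models from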
@@ -1,28 +1,28 @@
 components:
   converter:
-    init_parameters:
-      extractor_type: DefaultExtractor
     type: haystack.components.converters.html.HTMLToDocument
+    init_parameters:
+      extraction_kwargs: null

   fetcher:
     init_parameters:
       raise_on_failure: true
       retry_attempts: 2
       timeout: 3
       user_agents:
-        - haystack/LinkContentFetcher/2.0.0b8
+      - haystack/LinkContentFetcher/2.0.0b8
     type: haystack.components.fetchers.link_content.LinkContentFetcher

   llm:
     init_parameters:
       api_base_url: null
       api_key:
         env_vars:
-          - OPENAI_API_KEY
+        - OPENAI_API_KEY
         strict: true
         type: env_var
       generation_kwargs: {}
-      model: gpt-3.5-turbo
+      model: gpt-4o-mini
       streaming_callback: null
       system_prompt: null
     type: haystack.components.generators.openai.OpenAIGenerator
@@ -40,11 +40,11 @@ components:
     type: haystack.components.builders.prompt_builder.PromptBuilder

 connections:
-  - receiver: converter.sources
-    sender: fetcher.streams
-  - receiver: prompt.documents
-    sender: converter.documents
-  - receiver: llm.prompt
-    sender: prompt.prompt
+- receiver: converter.sources
+  sender: fetcher.streams
+- receiver: prompt.documents
+  sender: converter.documents
+- receiver: llm.prompt
+  sender: prompt.prompt

 metadata: {}