diff --git a/pyproject.toml b/pyproject.toml index 68b32e3..607d7d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "uvicorn", "requests", "python-multipart", + "loguru", ] [project.urls] diff --git a/src/hayhooks/server/logger.py b/src/hayhooks/server/logger.py new file mode 100644 index 0000000..088458e --- /dev/null +++ b/src/hayhooks/server/logger.py @@ -0,0 +1,8 @@ +# logger.py + +import os +import sys +from loguru import logger as log + +log.remove() +log.add(sys.stderr, level=os.getenv("LOG", "INFO").upper()) diff --git a/src/hayhooks/server/pipelines/models.py b/src/hayhooks/server/pipelines/models.py index 3d36f4a..f4f5a92 100644 --- a/src/hayhooks/server/pipelines/models.py +++ b/src/hayhooks/server/pipelines/models.py @@ -1,6 +1,7 @@ from pandas import DataFrame from pydantic import BaseModel, ConfigDict, create_model from hayhooks.server.utils.create_valid_type import handle_unsupported_types +from haystack import Document class PipelineDefinition(BaseModel): @@ -8,6 +9,12 @@ class PipelineDefinition(BaseModel): source_code: str +DEFAULT_TYPES_MAPPING = { + DataFrame: dict, + Document: dict, +} + + def get_request_model(pipeline_name: str, pipeline_inputs): """ Inputs have this form: @@ -26,7 +33,7 @@ def get_request_model(pipeline_name: str, pipeline_inputs): component_model = {} for name, typedef in inputs.items(): try: - input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict}) + input_type = handle_unsupported_types(type_=typedef["type"], types_mapping=DEFAULT_TYPES_MAPPING) except TypeError as e: print(f"ERROR at {component_name!r}, {name}: {typedef}") raise e @@ -56,7 +63,10 @@ def get_response_model(pipeline_name: str, pipeline_outputs): component_model = {} for name, typedef in outputs.items(): output_type = typedef["type"] - component_model[name] = (handle_unsupported_types(output_type, {DataFrame: dict}), ...) + component_model[name] = ( + handle_unsupported_types(type_=output_type, types_mapping=DEFAULT_TYPES_MAPPING), + ..., + ) response_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...) return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config) diff --git a/src/hayhooks/server/utils/create_valid_type.py b/src/hayhooks/server/utils/create_valid_type.py index 6c02f9e..2c103c9 100644 --- a/src/hayhooks/server/utils/create_valid_type.py +++ b/src/hayhooks/server/utils/create_valid_type.py @@ -1,7 +1,7 @@ from collections.abc import Callable as CallableABC -from inspect import isclass from types import GenericAlias -from typing import Callable, Dict, Optional, Union, get_args, get_origin, get_type_hints +from typing import Callable, Optional, Union, get_args, get_origin +from loguru import logger def is_callable_type(t): @@ -23,41 +23,26 @@ def is_callable_type(t): def handle_unsupported_types( - type_: type, types_mapping: Dict[type, type], skip_callables: bool = True + type_: type, types_mapping: dict, skip_callables: bool = True ) -> Union[GenericAlias, type, None]: - """ - Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping. 
- """ - - def handle_generics(t_) -> Union[GenericAlias, None]: - """Handle generics recursively""" - if is_callable_type(t_) and skip_callables: - return None - - child_typing = [] - for t in get_args(t_): - if t in types_mapping: - result = types_mapping[t] - elif isclass(t): - result = handle_unsupported_types(t, types_mapping) - else: - result = t - child_typing.append(result) - - if len(child_typing) == 2 and child_typing[1] is type(None): - return Optional[child_typing[0]] - else: - return GenericAlias(get_origin(t_), tuple(child_typing)) - - if is_callable_type(type_) and skip_callables: + logger.debug(f"Handling unsupported type: {type_}") + + if skip_callables and is_callable_type(type_): + logger.warning(f"Skipping callable type: {type_}") return None - if isclass(type_): - new_type = {} - for arg_name, arg_type in get_type_hints(type_).items(): - if get_args(arg_type): - new_type[arg_name] = handle_generics(arg_type) - else: - new_type[arg_name] = arg_type - return type_ - return handle_generics(type_) + # Handle generic types (like List, Optional, etc.) + origin = get_origin(type_) + if origin is not None: + args = get_args(type_) + # Map the inner types using the same mapping + mapped_args = tuple(handle_unsupported_types(arg, types_mapping, skip_callables) or arg for arg in args) + # Reconstruct the generic type with mapped arguments + return origin[mapped_args] + + if type_ in types_mapping: + logger.debug(f"Mapping type: {type_} to {types_mapping[type_]}") + return types_mapping[type_] + + logger.debug(f"Returning original type: {type_}") + return type_ diff --git a/tests/test_files/basic_rag_pipeline.yml b/tests/test_files/basic_rag_pipeline.yml new file mode 100644 index 0000000..472b377 --- /dev/null +++ b/tests/test_files/basic_rag_pipeline.yml @@ -0,0 +1,72 @@ +components: + llm: + init_parameters: + api_base_url: null + api_key: + env_vars: + - OPENAI_API_KEY + strict: true + type: env_var + generation_kwargs: {} + model: gpt-4o-mini + organization: null + streaming_callback: null + system_prompt: null + type: haystack.components.generators.openai.OpenAIGenerator + prompt_builder: + init_parameters: + required_variables: null + template: "\nGiven the following information, answer the question.\n\nContext:\n\ + {% for document in documents %}\n {{ document.content }}\n{% endfor %}\n\ + \nQuestion: {{question}}\nAnswer:\n" + variables: null + type: haystack.components.builders.prompt_builder.PromptBuilder + retriever: + init_parameters: + document_store: + init_parameters: + bm25_algorithm: BM25L + bm25_parameters: {} + bm25_tokenization_regex: (?u)\b\w\w+\b + embedding_similarity_function: dot_product + index: d8b1f58f-20e9-4a57-a84d-a44fc651de4e + type: haystack.document_stores.in_memory.document_store.InMemoryDocumentStore + filter_policy: replace + filters: null + return_embedding: false + scale_score: false + top_k: 10 + type: haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever + text_embedder: + init_parameters: + batch_size: 32 + config_kwargs: null + device: + device: mps + type: single + model: sentence-transformers/all-MiniLM-L6-v2 + model_kwargs: null + normalize_embeddings: false + precision: float32 + prefix: '' + progress_bar: true + suffix: '' + token: + env_vars: + - HF_API_TOKEN + - HF_TOKEN + strict: false + type: env_var + tokenizer_kwargs: null + truncate_dim: null + trust_remote_code: false + type: haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder +connections: +- 
receiver: retriever.query_embedding + sender: text_embedder.embedding +- receiver: prompt_builder.documents + sender: retriever.documents +- receiver: llm.prompt + sender: prompt_builder.prompt +max_runs_per_component: 100 +metadata: {} diff --git a/tests/test_files/chat_with_website.yaml b/tests/test_files/chat_with_website.yml similarity index 75% rename from tests/test_files/chat_with_website.yaml rename to tests/test_files/chat_with_website.yml index 1cb3869..db4063f 100644 --- a/tests/test_files/chat_with_website.yaml +++ b/tests/test_files/chat_with_website.yml @@ -1,8 +1,8 @@ components: converter: - init_parameters: - extractor_type: DefaultExtractor type: haystack.components.converters.html.HTMLToDocument + init_parameters: + extraction_kwargs: null fetcher: init_parameters: @@ -10,7 +10,7 @@ components: retry_attempts: 2 timeout: 3 user_agents: - - haystack/LinkContentFetcher/2.0.0b8 + - haystack/LinkContentFetcher/2.0.0b8 type: haystack.components.fetchers.link_content.LinkContentFetcher llm: @@ -18,11 +18,11 @@ components: api_base_url: null api_key: env_vars: - - OPENAI_API_KEY + - OPENAI_API_KEY strict: true type: env_var generation_kwargs: {} - model: gpt-3.5-turbo + model: gpt-4o-mini streaming_callback: null system_prompt: null type: haystack.components.generators.openai.OpenAIGenerator @@ -40,11 +40,11 @@ components: type: haystack.components.builders.prompt_builder.PromptBuilder connections: -- receiver: converter.sources - sender: fetcher.streams -- receiver: prompt.documents - sender: converter.documents -- receiver: llm.prompt - sender: prompt.prompt + - receiver: converter.sources + sender: fetcher.streams + - receiver: prompt.documents + sender: converter.documents + - receiver: llm.prompt + sender: prompt.prompt metadata: {} diff --git a/tests/test_files/pipeline_qdrant.yml b/tests/test_files/pipeline_qdrant.yml new file mode 100644 index 0000000..84642a4 --- /dev/null +++ b/tests/test_files/pipeline_qdrant.yml @@ -0,0 +1,169 @@ +components: + embedder: + init_parameters: + batch_size: 32 + config_kwargs: null + device: + device: cpu + type: single + model: sentence-transformers/all-MiniLM-L6-v2 + model_kwargs: null + normalize_embeddings: false + precision: float32 + prefix: '' + progress_bar: true + suffix: '' + token: + env_vars: + - HF_API_TOKEN + - HF_TOKEN + strict: false + type: env_var + tokenizer_kwargs: null + truncate_dim: null + trust_remote_code: false + type: haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder + list_to_str_adapter: + init_parameters: + custom_filters: {} + output_type: str + template: '{{ replies[0] }}' + unsafe: false + type: haystack.components.converters.output_adapter.OutputAdapter + llm: + init_parameters: + api_base_url: http://localhost:8000/v1 + api_key: + env_vars: + - OPENAI_API_KEY + strict: true + type: env_var + generation_kwargs: {} + model: mistralai/Mistral-Nemo-Instruct-2407 + organization: null + streaming_callback: null + type: haystack.components.generators.chat.openai.OpenAIChatGenerator + memory_joiner: + init_parameters: + type_: list[haystack.dataclasses.chat_message.ChatMessage] + type: haystack.components.joiners.branch.BranchJoiner + memory_retriever: + init_parameters: + last_k: 10 + message_store: + init_parameters: {} + type: haystack_experimental.chat_message_stores.in_memory.InMemoryChatMessageStore + type: haystack_experimental.components.retrievers.chat_message_retriever.ChatMessageRetriever + memory_writer: + init_parameters: + message_store: + 
init_parameters: {} + type: haystack_experimental.chat_message_stores.in_memory.InMemoryChatMessageStore + type: haystack_experimental.components.writers.chat_message_writer.ChatMessageWriter + prompt_builder: + init_parameters: + required_variables: &id001 !!python/tuple + - query + - documents + - memories + template: null + variables: *id001 + type: haystack.components.builders.chat_prompt_builder.ChatPromptBuilder + query_rephrase_llm: + init_parameters: + api_base_url: http://localhost:8000/v1 + api_key: + env_vars: + - OPENAI_API_KEY + strict: true + type: env_var + generation_kwargs: {} + model: mistralai/Mistral-Nemo-Instruct-2407 + organization: null + streaming_callback: null + system_prompt: null + type: haystack.components.generators.openai.OpenAIGenerator + query_rephrase_prompt_builder: + init_parameters: + required_variables: null + template: "\nRewrite the question for semantic search while keeping its meaning\ + \ and key terms intact.\nIf the conversation history is empty, DO NOT change\ + \ the query.\nDo not translate the question.\nUse conversation history only\ + \ if necessary, and avoid extending the query with your own knowledge.\nIf\ + \ no changes are needed, output the current question as is.\n\nConversation\ + \ history:\n{% for memory in memories %}\n {{ memory.content }}\n{% endfor\ + \ %}\n\nUser Query: {{query}}\nRewritten Query:\n" + variables: null + type: haystack.components.builders.prompt_builder.PromptBuilder + retriever: + init_parameters: + document_store: + init_parameters: + api_key: null + embedding_dim: 768 + force_disable_check_same_thread: false + grpc_port: 6334 + hnsw_config: null + host: null + https: null + index: Document + init_from: null + location: null + metadata: {} + on_disk: false + on_disk_payload: null + optimizers_config: null + path: null + payload_fields_to_index: null + port: 6333 + prefer_grpc: false + prefix: null + progress_bar: false + quantization_config: null + recreate_index: false + replication_factor: null + return_embedding: false + scroll_size: 10000 + shard_number: null + similarity: cosine + sparse_idf: false + timeout: null + url: http://localhost:6333 + use_sparse_embeddings: false + wait_result_from_api: true + wal_config: null + write_batch_size: 100 + write_consistency_factor: null + type: haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore + filter_policy: replace + filters: null + group_by: null + group_size: null + return_embedding: false + scale_score: false + score_threshold: null + top_k: 3 + type: haystack_integrations.components.retrievers.qdrant.retriever.QdrantEmbeddingRetriever +connections: +- receiver: query_rephrase_llm.prompt + sender: query_rephrase_prompt_builder.prompt +- receiver: list_to_str_adapter.replies + sender: query_rephrase_llm.replies +- receiver: embedder.text + sender: list_to_str_adapter.output +- receiver: retriever.query_embedding + sender: embedder.embedding +- receiver: prompt_builder.documents + sender: retriever.documents +- receiver: llm.messages + sender: prompt_builder.prompt +- receiver: memory_joiner.value + sender: llm.replies +- receiver: query_rephrase_prompt_builder.memories + sender: memory_retriever.messages +- receiver: prompt_builder.memories + sender: memory_retriever.messages +- receiver: memory_writer.messages + sender: memory_joiner.value +max_runs_per_component: 100 +metadata: {} \ No newline at end of file diff --git a/tests/test_files/pipeline_qdrant_2.yml b/tests/test_files/pipeline_qdrant_2.yml new file mode 100644 index 
0000000..9ae3ff3 --- /dev/null +++ b/tests/test_files/pipeline_qdrant_2.yml @@ -0,0 +1,46 @@ +components: + document_embedder: + init_parameters: + batch_size: 32 + config_kwargs: null + device: + device: cpu + type: single + model: sentence-transformers/paraphrase-MiniLM-L3-v2 + model_kwargs: null + normalize_embeddings: false + precision: float32 + prefix: '' + progress_bar: true + suffix: '' + token: + env_vars: + - HF_API_TOKEN + - HF_TOKEN + strict: false + type: env_var + tokenizer_kwargs: null + truncate_dim: null + trust_remote_code: false + type: haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder + document_retriever: + init_parameters: + document_store: + init_parameters: + bm25_algorithm: BM25L + bm25_parameters: {} + bm25_tokenization_regex: (?u)\b\w\w+\b + embedding_similarity_function: dot_product + index: b39f1fea-7c83-4fdc-a9e0-928e3d5e4ae7 + type: haystack.document_stores.in_memory.document_store.InMemoryDocumentStore + filter_policy: replace + filters: null + return_embedding: false + scale_score: false + top_k: 3 + type: haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever +connections: +- receiver: document_retriever.query_embedding + sender: document_embedder.embedding +max_runs_per_component: 100 +metadata: {} diff --git a/tests/test_files/st_retriever.yml b/tests/test_files/st_retriever.yml new file mode 100644 index 0000000..9ae3ff3 --- /dev/null +++ b/tests/test_files/st_retriever.yml @@ -0,0 +1,46 @@ +components: + document_embedder: + init_parameters: + batch_size: 32 + config_kwargs: null + device: + device: cpu + type: single + model: sentence-transformers/paraphrase-MiniLM-L3-v2 + model_kwargs: null + normalize_embeddings: false + precision: float32 + prefix: '' + progress_bar: true + suffix: '' + token: + env_vars: + - HF_API_TOKEN + - HF_TOKEN + strict: false + type: env_var + tokenizer_kwargs: null + truncate_dim: null + trust_remote_code: false + type: haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder + document_retriever: + init_parameters: + document_store: + init_parameters: + bm25_algorithm: BM25L + bm25_parameters: {} + bm25_tokenization_regex: (?u)\b\w\w+\b + embedding_similarity_function: dot_product + index: b39f1fea-7c83-4fdc-a9e0-928e3d5e4ae7 + type: haystack.document_stores.in_memory.document_store.InMemoryDocumentStore + filter_policy: replace + filters: null + return_embedding: false + scale_score: false + top_k: 3 + type: haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever +connections: +- receiver: document_retriever.query_embedding + sender: document_embedder.embedding +max_runs_per_component: 100 +metadata: {} diff --git a/tests/test_handle_unsupported_types.py b/tests/test_handle_unsupported_types.py new file mode 100644 index 0000000..cb6d979 --- /dev/null +++ b/tests/test_handle_unsupported_types.py @@ -0,0 +1,41 @@ +from typing import Optional, List +from hayhooks.server.utils.create_valid_type import handle_unsupported_types + + +def test_handle_simple_type(): + result = handle_unsupported_types(int, {}) + assert result == int + + +def test_handle_generic_type(): + result = handle_unsupported_types(List[int], {}) + assert result == list[int] + + +def test_handle_recursive_type(): + class Node: + def __init__(self, value: int, next: Optional['Node'] = None): + self.value = value + self.next = next + + result = handle_unsupported_types(Node, {}) + assert result == Node + 
+ +def test_handle_circular_reference(): + class A: + def __init__(self, b: 'B'): + self.b = b + + class B: + def __init__(self, a: 'A'): + self.a = a + + result = handle_unsupported_types(A, {}) + assert result == A # Adjust assertion based on expected behavior + + +def test_handle_nested_generics(): + nested_type = dict[str, list[Optional[int]]] + result = handle_unsupported_types(nested_type, {}) + assert result == nested_type diff --git a/tests/test_it_deploy.py b/tests/test_it_deploy.py new file mode 100644 index 0000000..2e5ae88 --- /dev/null +++ b/tests/test_it_deploy.py @@ -0,0 +1,24 @@ +import pytest +from fastapi.testclient import TestClient +from hayhooks.server import app +from pathlib import Path + +client = TestClient(app) + +# Load pipeline definitions from test_files +test_files = Path(__file__).parent / "test_files" +pipeline_names = [file.stem for file in test_files.glob("*.yml")] + + +@pytest.mark.parametrize("pipeline_name", pipeline_names) +def test_deploy_pipeline_def(pipeline_name: str): + pipeline_def = (Path(__file__).parent / "test_files" / f"{pipeline_name}.yml").read_text() + + deploy_response = client.post("/deploy", json={"name": pipeline_name, "source_code": pipeline_def}) + assert deploy_response.status_code == 200 + + status_response = client.get("/status") + assert pipeline_name in status_response.json()["pipelines"] + + docs_response = client.get("/docs") + assert docs_response.status_code == 200
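
Reviewer note (not part of the patch): a minimal sketch of how the refactored handle_unsupported_types is expected to behave with the DEFAULT_TYPES_MAPPING introduced in models.py — unsupported classes are swapped for their mapped counterpart, generics are rebuilt with their arguments mapped recursively, and everything else passes through unchanged. With the new logger.py, running with LOG=DEBUG surfaces the per-type debug messages emitted during this traversal. Assumes the packages added or already listed in pyproject.toml are installed.

    from typing import List, Optional

    from haystack import Document
    from pandas import DataFrame

    from hayhooks.server.pipelines.models import DEFAULT_TYPES_MAPPING
    from hayhooks.server.utils.create_valid_type import handle_unsupported_types

    # Plain unsupported classes are replaced by their mapped counterpart.
    assert handle_unsupported_types(DataFrame, DEFAULT_TYPES_MAPPING) is dict
    assert handle_unsupported_types(Document, DEFAULT_TYPES_MAPPING) is dict

    # Generics are rebuilt with their inner types mapped recursively.
    assert handle_unsupported_types(List[Document], DEFAULT_TYPES_MAPPING) == list[dict]
    assert handle_unsupported_types(Optional[DataFrame], DEFAULT_TYPES_MAPPING) == Optional[dict]

    # Already-supported types pass through untouched.
    assert handle_unsupported_types(int, DEFAULT_TYPES_MAPPING) is int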
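
A usage sketch for the new test fixtures: the same /deploy and /status endpoints exercised by test_it_deploy.py, but against a running server via requests (already a dependency). The base URL is an assumption (a local default), so adjust it to wherever the server is actually listening.

    from pathlib import Path

    import requests

    BASE_URL = "http://localhost:1416"  # assumed local default, adjust as needed

    # Deploy one of the new test pipelines by name + serialized YAML source.
    pipeline_def = Path("tests/test_files/basic_rag_pipeline.yml").read_text()
    resp = requests.post(
        f"{BASE_URL}/deploy",
        json={"name": "basic_rag_pipeline", "source_code": pipeline_def},
    )
    resp.raise_for_status()

    # The pipeline should now be listed by /status and documented under /docs.
    print(requests.get(f"{BASE_URL}/status").json()["pipelines"])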