From bf8936164c819a8bc1a47e5043ab35b2820fef40 Mon Sep 17 00:00:00 2001 From: PaulKalho Date: Mon, 11 Nov 2024 22:00:58 +0100 Subject: [PATCH] feat: add input output abstraction to settings --- scystream/sdk/config/config_loader.py | 23 ++++++++ scystream/sdk/config/models.py | 9 ++- scystream/sdk/core.py | 81 ++++++++++++++++++++++++++- scystream/sdk/env/settings.py | 16 ++++++ scystream/sdk/scheduler.py | 2 +- 5 files changed, 124 insertions(+), 7 deletions(-) diff --git a/scystream/sdk/config/config_loader.py b/scystream/sdk/config/config_loader.py index f68e565..13760e2 100644 --- a/scystream/sdk/config/config_loader.py +++ b/scystream/sdk/config/config_loader.py @@ -7,6 +7,18 @@ CONFIG_FILE_DEFAULT_NAME = "cbc.yaml" +def _remove_empty_dicts(data): + """ + Remove keys with empty dictionaries from a nested structure. + """ + if isinstance(data, dict): + return {k: _remove_empty_dicts(v) for k, v in data.items() if v != {}} + elif isinstance(data, list): + return [_remove_empty_dicts(i) for i in data] + else: + return data + + def load_config( config_file_name: str = CONFIG_FILE_DEFAULT_NAME, config_path: Union[str, Path] = None @@ -18,11 +30,22 @@ def load_config( try: file = _find_and_load_config(config_file_name, config_path) block = ComputeBlock(**file) + # TODO: Check if envs && input/output configs correspond to the + # loaded one return block except ValidationError as e: raise ValueError(f"Configuration file validation error: {e}") +def generate_yaml_from_compute_block( + compute_block: ComputeBlock, + output_path: Path +): + cleaned_data = _remove_empty_dicts(compute_block.dict()) + with output_path.open("w") as file: + yaml.dump(cleaned_data, file, default_flow_style=False) + + def _find_and_load_config( config_file_name: str, config_path: Union[str, Path] = None diff --git a/scystream/sdk/config/models.py b/scystream/sdk/config/models.py index cca2008..bcc67de 100644 --- a/scystream/sdk/config/models.py +++ b/scystream/sdk/config/models.py @@ -4,6 +4,8 @@ FILE_TYPE_IDENTIFIER = "file" DB_TABLE_TYPE_IDENTIFIER = "db_table" +# TODO: reevaluate the identifier +TODO_TYPE_IDENTIFIER = "TODO: SetType" """ This file contains the schema definition for the config file. @@ -23,7 +25,8 @@ class InputOutputModel(BaseModel): If a value is explicitly set to `null`, validation will fail unless the ENV-Variable is manually set by the ComputeBlock user. """ - type: Literal[FILE_TYPE_IDENTIFIER, DB_TABLE_TYPE_IDENTIFIER] + type: Literal[FILE_TYPE_IDENTIFIER, + DB_TABLE_TYPE_IDENTIFIER, TODO_TYPE_IDENTIFIER] description: Optional[StrictStr] = None config: Optional[ Dict[ @@ -59,8 +62,8 @@ class Entrypoint(BaseModel): Optional[Union[StrictStr, StrictInt, StrictFloat, List, bool]] ] ] = None - inputs: Dict[StrictStr, InputOutputModel] - outputs: Dict[StrictStr, InputOutputModel] + inputs: Optional[Dict[StrictStr, InputOutputModel]] = None + outputs: Optional[Dict[StrictStr, InputOutputModel]] = None class ComputeBlock(BaseModel): diff --git a/scystream/sdk/core.py b/scystream/sdk/core.py index 5dd8b12..58cf18d 100644 --- a/scystream/sdk/core.py +++ b/scystream/sdk/core.py @@ -1,8 +1,11 @@ import functools - -from typing import Callable, Type, Optional +from typing import Callable, Type, Optional, Union from .env.settings import EnvSettings from pydantic import ValidationError +from scystream.sdk.config.models import ComputeBlock, Entrypoint, \ + InputOutputModel +from pydantic_core import PydanticUndefinedType +from scystream.sdk.env.settings import InputSettings, OutputSettings _registered_functions = {} @@ -18,6 +21,9 @@ def wrapper(*args, **kwargs): if settings_class is not None: # Load settings try: + # TODO: 1. LoadSettings + # TODO: 2. Generate config from settings (only for the entrypoint) + # TODO: 3. Validate if generated config and given config are same settings = settings_class.get_settings() except ValidationError as e: raise ValueError(f"Invalid environment configuration: {e}") @@ -26,11 +32,80 @@ def wrapper(*args, **kwargs): else: return func(*args, **kwargs) - _registered_functions[func.__name__] = wrapper + _registered_functions[func.__name__] = { + "function": wrapper, + "settings": settings_class + } return wrapper return decorator def get_registered_functions(): """Returns a dictionary of registered entrypoint functions.""" + print(_registered_functions) return _registered_functions + + +def _get_pydantic_default_value_or_none(value): + if type(value.default) is PydanticUndefinedType: + return None + return value.default + + +def _build_input_output_dict_from_class( + subject: Union[InputSettings, OutputSettings] +): + config_dict = {} + for key, value in subject.model_fields.items(): + config_dict[key] = _get_pydantic_default_value_or_none(value) + return InputOutputModel( + type="TODO: SetType", + description="", + config=config_dict + ) + + +def generate_compute_block() -> ComputeBlock: + """ + Converts the Settings to a ComputeBlock + """ + entrypoints = {} + for entrypoint, func in _registered_functions.items(): + envs = {} + inputs = {} + outputs = {} + + if func["settings"]: + entrypoint_settings_class = func["settings"] + for key, value in entrypoint_settings_class.model_fields.items(): + if ( + isinstance(value.default_factory, type) and + issubclass(value.default_factory, InputSettings) + ): + inputs[key] = _build_input_output_dict_from_class( + value.default_factory + ) + elif ( + isinstance(value.default_factory, type) and + issubclass(value.default_factory, OutputSettings) + ): + outputs[key] = _build_input_output_dict_from_class( + value.default_factory + ) + else: + envs[key] = _get_pydantic_default_value_or_none(value) + + entrypoints[entrypoint] = Entrypoint( + description="", + envs=envs, + inputs=inputs, + outputs=outputs + ) + + return ComputeBlock( + name="", + description="", + author="", + entrypoints=entrypoints, + docker_image="" + ) diff --git a/scystream/sdk/env/settings.py b/scystream/sdk/env/settings.py index b217b93..e0c27c8 100644 --- a/scystream/sdk/env/settings.py +++ b/scystream/sdk/env/settings.py @@ -72,3 +72,19 @@ def __init__(self, propagate_kwargs=None, *args, **kwargs): if propagate_kwargs: kwargs = self._propagate_kwargs(propagate_kwargs) super().__init__(*args, **kwargs) + + +class InputSettings(EnvSettings): + """ + Abstraction-Layer for inputs + could be extended + """ + pass + + +class OutputSettings(EnvSettings): + """ + Abstraction-Layer for outputs + could be exended + """ + pass diff --git a/scystream/sdk/scheduler.py b/scystream/sdk/scheduler.py index f403c83..c610897 100644 --- a/scystream/sdk/scheduler.py +++ b/scystream/sdk/scheduler.py @@ -13,6 +13,6 @@ def list_entrypoints(): def execute_function(name, *args, **kwargs): functions = get_registered_functions() if name in functions: - return functions[name](*args, **kwargs) + return functions[name]["function"](*args, **kwargs) else: raise Exception(f"No entrypoint found with the name: {name}")