Skip to content

Commit

Permalink
feat: add input output abstraction to settings
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulKalho committed Nov 11, 2024
1 parent f3553ad commit bf89361
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 7 deletions.
23 changes: 23 additions & 0 deletions scystream/sdk/config/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@
CONFIG_FILE_DEFAULT_NAME = "cbc.yaml"


def _remove_empty_dicts(data):
"""
Remove keys with empty dictionaries from a nested structure.
"""
if isinstance(data, dict):
return {k: _remove_empty_dicts(v) for k, v in data.items() if v != {}}
elif isinstance(data, list):
return [_remove_empty_dicts(i) for i in data]
else:
return data


def load_config(
config_file_name: str = CONFIG_FILE_DEFAULT_NAME,
config_path: Union[str, Path] = None
Expand All @@ -18,11 +30,22 @@ def load_config(
try:
file = _find_and_load_config(config_file_name, config_path)
block = ComputeBlock(**file)
# TODO: Check if envs && input/output configs correspond to the
# loaded one
return block
except ValidationError as e:
raise ValueError(f"Configuration file validation error: {e}")


def generate_yaml_from_compute_block(
    compute_block: ComputeBlock,
    output_path: Union[str, Path]
):
    """
    Write ``compute_block`` to ``output_path`` as block-style YAML.

    The model is converted to plain Python data, keys holding empty
    dictionaries are removed via ``_remove_empty_dicts``, and the result
    is dumped with ``default_flow_style=False``. An existing file at
    ``output_path`` is overwritten.

    :param compute_block: the ComputeBlock model to serialise.
    :param output_path: target file, given as ``str`` or ``Path``.
    """
    # model_dump() is the pydantic v2 replacement for the deprecated .dict()
    cleaned_data = _remove_empty_dicts(compute_block.model_dump())
    with Path(output_path).open("w", encoding="utf-8") as file:
        yaml.dump(cleaned_data, file, default_flow_style=False)


def _find_and_load_config(
config_file_name: str,
config_path: Union[str, Path] = None
Expand Down
9 changes: 6 additions & 3 deletions scystream/sdk/config/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

FILE_TYPE_IDENTIFIER = "file"
DB_TABLE_TYPE_IDENTIFIER = "db_table"
# TODO: reevaluate the identifier
TODO_TYPE_IDENTIFIER = "TODO: SetType"

"""
This file contains the schema definition for the config file.
Expand All @@ -23,7 +25,8 @@ class InputOutputModel(BaseModel):
If a value is explicitly set to `null`, validation will fail unless the
ENV-Variable is manually set by the ComputeBlock user.
"""
type: Literal[FILE_TYPE_IDENTIFIER, DB_TABLE_TYPE_IDENTIFIER]
type: Literal[FILE_TYPE_IDENTIFIER,
DB_TABLE_TYPE_IDENTIFIER, TODO_TYPE_IDENTIFIER]
description: Optional[StrictStr] = None
config: Optional[
Dict[
Expand Down Expand Up @@ -59,8 +62,8 @@ class Entrypoint(BaseModel):
Optional[Union[StrictStr, StrictInt, StrictFloat, List, bool]]
]
] = None
inputs: Dict[StrictStr, InputOutputModel]
outputs: Dict[StrictStr, InputOutputModel]
inputs: Optional[Dict[StrictStr, InputOutputModel]] = None
outputs: Optional[Dict[StrictStr, InputOutputModel]] = None


class ComputeBlock(BaseModel):
Expand Down
81 changes: 78 additions & 3 deletions scystream/sdk/core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import functools

from typing import Callable, Type, Optional
from typing import Callable, Type, Optional, Union
from .env.settings import EnvSettings
from pydantic import ValidationError
from scystream.sdk.config.models import ComputeBlock, Entrypoint, \
InputOutputModel
from pydantic_core import PydanticUndefinedType
from scystream.sdk.env.settings import InputSettings, OutputSettings

_registered_functions = {}

Expand All @@ -18,6 +21,9 @@ def wrapper(*args, **kwargs):
if settings_class is not None:
# Load settings
try:
# TODO: 1. LoadSettings
# TODO: 2. Generate config from settings (only for the entrypoint)
# TODO: 3. Validate if generated config and given config are same
settings = settings_class.get_settings()
except ValidationError as e:
raise ValueError(f"Invalid environment configuration: {e}")
Expand All @@ -26,11 +32,80 @@ def wrapper(*args, **kwargs):
else:
return func(*args, **kwargs)

_registered_functions[func.__name__] = wrapper
_registered_functions[func.__name__] = {
"function": wrapper,
"settings": settings_class
}
return wrapper
return decorator


def get_registered_functions():
    """
    Return the module-level registry of entrypoints.

    The registry itself is returned, not a copy. No output is written:
    the previous debug ``print`` of the registry has been removed so that
    looking up entrypoints does not pollute stdout.
    """
    return _registered_functions


def _get_pydantic_default_value_or_none(value):
    """
    Return ``value.default``, or ``None`` when that default is an
    instance of exactly ``PydanticUndefinedType``.
    """
    declared_default = value.default
    has_no_default = type(declared_default) is PydanticUndefinedType
    return None if has_no_default else declared_default


def _build_input_output_dict_from_class(
    subject: Union[InputSettings, OutputSettings]
):
    """
    Build an InputOutputModel from the fields of ``subject``.

    Every entry of ``subject.model_fields`` becomes a key of the model's
    ``config``, mapped to the field's declared default (``None`` when the
    field declares none). ``type`` and ``description`` are filled with
    fixed placeholder strings.
    """
    field_defaults = {
        field_name: _get_pydantic_default_value_or_none(field_info)
        for field_name, field_info in subject.model_fields.items()
    }
    return InputOutputModel(
        type="TODO: SetType",
        description="<to-be-set>",
        config=field_defaults
    )


def generate_compute_block() -> ComputeBlock:
    """
    Build a ComputeBlock from every registered entrypoint.

    For each entry in ``_registered_functions`` the fields of its settings
    class (if any) are sorted into three groups:

    * fields whose ``default_factory`` is a subclass of InputSettings
      become inputs,
    * fields whose ``default_factory`` is a subclass of OutputSettings
      become outputs,
    * all remaining fields become envs, mapped to their declared default
      (``None`` when there is none).

    Block-level metadata and descriptions are filled with placeholder
    strings.
    """
    built_entrypoints = {}
    for name, registration in _registered_functions.items():
        env_defaults = {}
        input_models = {}
        output_models = {}

        settings_cls = registration["settings"]
        # Entrypoints registered without a settings class have no fields.
        fields = settings_cls.model_fields.items() if settings_cls else ()
        for field_name, field_info in fields:
            factory = field_info.default_factory
            # issubclass() raises on non-classes, so check for a class first.
            factory_is_class = isinstance(factory, type)
            if factory_is_class and issubclass(factory, InputSettings):
                input_models[field_name] = (
                    _build_input_output_dict_from_class(factory)
                )
            elif factory_is_class and issubclass(factory, OutputSettings):
                output_models[field_name] = (
                    _build_input_output_dict_from_class(factory)
                )
            else:
                env_defaults[field_name] = (
                    _get_pydantic_default_value_or_none(field_info)
                )

        built_entrypoints[name] = Entrypoint(
            description="<tbd>",
            envs=env_defaults,
            inputs=input_models,
            outputs=output_models
        )

    return ComputeBlock(
        name="<tbs>",
        description="<tbs>",
        author="<tbs>",
        entrypoints=built_entrypoints,
        docker_image="<tbs>"
    )
16 changes: 16 additions & 0 deletions scystream/sdk/env/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,19 @@ def __init__(self, propagate_kwargs=None, *args, **kwargs):
if propagate_kwargs:
kwargs = self._propagate_kwargs(propagate_kwargs)
super().__init__(*args, **kwargs)


class InputSettings(EnvSettings):
    """
    Abstraction layer for the settings of an input.

    Adds no fields or behaviour of its own on top of EnvSettings; it can
    be extended later.
    """


class OutputSettings(EnvSettings):
    """
    Abstraction layer for the settings of an output.

    Adds no fields or behaviour of its own on top of EnvSettings; it can
    be extended later.
    """
2 changes: 1 addition & 1 deletion scystream/sdk/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def list_entrypoints():
def execute_function(name, *args, **kwargs):
    """
    Call the registered entrypoint ``name`` and return its result.

    :param name: key of the entrypoint in the registry returned by
        ``get_registered_functions``.
    :param args: positional arguments forwarded to the entrypoint.
    :param kwargs: keyword arguments forwarded to the entrypoint.
    :raises Exception: if no entrypoint is registered under ``name``.
    """
    functions = get_registered_functions()
    if name not in functions:
        raise Exception(f"No entrypoint found with the name: {name}")
    # Each registry entry is a dict; the callable lives under "function".
    # Calling the entry itself (the stale pre-change line) would fail.
    return functions[name]["function"](*args, **kwargs)

0 comments on commit bf89361

Please sign in to comment.