From ce6ece17e99d8f31d65aa8eb1622a4e2b083ef74 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Wed, 7 Aug 2024 16:46:59 +0530 Subject: [PATCH] feat: Add enum and default value support in task processing Signed-off-by: Anupam Kumar --- nc_py_api/ex_app/providers/task_processing.py | 127 +++++++++++++----- 1 file changed, 97 insertions(+), 30 deletions(-) diff --git a/nc_py_api/ex_app/providers/task_processing.py b/nc_py_api/ex_app/providers/task_processing.py index 8f76d969..4582f742 100644 --- a/nc_py_api/ex_app/providers/task_processing.py +++ b/nc_py_api/ex_app/providers/task_processing.py @@ -3,35 +3,104 @@ import contextlib import dataclasses import typing +from enum import IntEnum + +from pydantic import RootModel +from pydantic.dataclasses import dataclass from ..._exceptions import NextcloudException, NextcloudExceptionNotFound -from ..._misc import clear_from_params_empty, require_capabilities +from ..._misc import require_capabilities from ..._session import AsyncNcSessionApp, NcSessionApp _EP_SUFFIX: str = "ai_provider/task_processing" -@dataclasses.dataclass -class TaskProcessingProvider: - """TaskProcessing provider description.""" +class ShapeType(IntEnum): + """Enum for shape types.""" + + NUMBER = 0 + TEXT = 1 + IMAGE = 2 + AUDIO = 3 + VIDEO = 4 + FILE = 5 + ENUM = 6 + LISTOFNUMBERS = 10 + LISTOFTEXTS = 11 + LISTOFIMAGES = 12 + LISTOFAUDIOS = 13 + LISTOFVIDEOS = 14 + LISTOFFILES = 15 + + +@dataclass +class ShapeEnumValue: + """Data object for input output shape enum slot value.""" + + name: str + """Name of the enum slot value which will be displayed in the UI""" + value: str + """Value of the enum slot value""" + + +@dataclass +class ShapeDescriptor: + """Data object for input output shape entries.""" - def __init__(self, raw_data: dict): - self._raw_data = raw_data + name: str + """Name of the shape entry""" + description: str + """Description of the shape entry""" + shape_type: ShapeType + """Type of the shape entry""" - @property - def name(self) -> str: - """Unique ID for the provider.""" - return self._raw_data["name"] - @property - def display_name(self) -> str: - """Providers display name.""" - return self._raw_data["display_name"] +@dataclass +class TaskType: + """TaskType description for the provider.""" - @property - def task_type(self) -> str: - """The TaskType provided by this provider.""" - return self._raw_data["task_type"] + id: str + """The unique ID for the task type.""" + name: str + """The localized name of the task type.""" + description: str + """The localized description of the task type.""" + input_shape: list[ShapeDescriptor] + """The input shape of the task.""" + output_shape: list[ShapeDescriptor] + """The output shape of the task.""" + + +@dataclass +class TaskProcessingProvider: + """TaskProcessing provider description.""" + + # pylint: disable=too-many-instance-attributes + + id: str + """Unique ID for the provider.""" + name: str + """The localized name of this provider""" + task_type: str + """The TaskType provided by this provider.""" + expected_runtime: int = dataclasses.field(default=0) + """Expected runtime of the task in seconds.""" + optional_input_shape: list[ShapeDescriptor] = dataclasses.field(default_factory=list) + """Optional input shape of the task.""" + optional_output_shape: list[ShapeDescriptor] = dataclasses.field(default_factory=list) + """Optional output shape of the task.""" + input_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict) + """The option dict for each input shape ENUM slot.""" + input_shape_defaults: dict[str, str | int | float] = dataclasses.field(default_factory=dict) + """The default values for input shape slots.""" + optional_input_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict) + """The option list for each optional input shape ENUM slot.""" + optional_input_shape_defaults: dict[str, str | int | float] = dataclasses.field(default_factory=dict) + """The default values for optional input shape slots.""" + output_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict) + """The option list for each output shape ENUM slot.""" + optional_output_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict) + """The option list for each optional output shape ENUM slot.""" def __repr__(self): return f"<{self.__class__.__name__} name={self.name}, type={self.task_type}>" @@ -44,17 +113,16 @@ def __init__(self, session: NcSessionApp): self._session = session def register( - self, name: str, display_name: str, task_type: str, custom_task_type: dict[str, typing.Any] | None = None + self, + provider: TaskProcessingProvider, + custom_task_type: TaskType | None = None, ) -> None: """Registers or edit the TaskProcessing provider.""" require_capabilities("app_api", self._session.capabilities) params = { - "name": name, - "displayName": display_name, - "taskType": task_type, - "customTaskType": custom_task_type, + "provider": RootModel(provider).model_dump(), + **({"customTaskType": RootModel(custom_task_type).model_dump()} if custom_task_type else {}), } - clear_from_params_empty(["customTaskType"], params) self._session.ocs("POST", f"{self._session.ae_url}/{_EP_SUFFIX}", json=params) def unregister(self, name: str, not_fail=True) -> None: @@ -123,17 +191,16 @@ def __init__(self, session: AsyncNcSessionApp): self._session = session async def register( - self, name: str, display_name: str, task_type: str, custom_task_type: dict[str, typing.Any] | None = None + self, + provider: TaskProcessingProvider, + custom_task_type: TaskType | None = None, ) -> None: """Registers or edit the TaskProcessing provider.""" require_capabilities("app_api", await self._session.capabilities) params = { - "name": name, - "displayName": display_name, - "taskType": task_type, - "customTaskType": custom_task_type, + "provider": RootModel(provider).model_dump(), + **({"customTaskType": RootModel(custom_task_type).model_dump()} if custom_task_type else {}), } - clear_from_params_empty(["customTaskType"], params) await self._session.ocs("POST", f"{self._session.ae_url}/{_EP_SUFFIX}", json=params) async def unregister(self, name: str, not_fail=True) -> None: