diff --git a/temporalio/client.py b/temporalio/client.py index 498c35a0..d2d2a2cb 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -1615,6 +1615,62 @@ async def terminate( ) ) + # Overload for no-param update + @overload + async def update( + self, + update: temporalio.workflow.UpdateMethodMultiArg[[SelfType], LocalReturnType], + *, + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> LocalReturnType: + ... + + # Overload for single-param update + @overload + async def update( + self, + update: temporalio.workflow.UpdateMethodMultiArg[ + [SelfType, ParamType], LocalReturnType + ], + arg: ParamType, + *, + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> LocalReturnType: + ... + + @overload + async def update( + self, + update: temporalio.workflow.UpdateMethodMultiArg[ + MultiParamSpec, LocalReturnType + ], + *, + args: MultiParamSpec.args, + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> LocalReturnType: + ... + + # Overload for string-name update + @overload + async def update( + self, + update: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: Optional[str] = None, + result_type: Optional[Type] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> Any: + ... + async def update( self, update: Union[str, Callable], @@ -1701,15 +1757,16 @@ async def start_update( """ update_name: str ret_type = result_type - if callable(update): - if not isinstance(update, temporalio.workflow.update): + if isinstance(update, temporalio.workflow.UpdateMethodMultiArg): + defn = update._defn + if not defn: raise RuntimeError( f"Update definition not found on {update.__qualname__}, " "is it decorated with @workflow.update?" ) - defn = update._defn - if not defn.name: + elif not defn.name: raise RuntimeError("Cannot invoke dynamic update definition") + # TODO(cretz): Check count/type of args at runtime? update_name = defn.name ret_type = defn.ret_type else: diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index e244e7c1..f25af44d 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -439,7 +439,7 @@ async def run_update() -> None: job.input, defn.name, defn.arg_types, - False, + defn.dynamic_vararg, ) handler_input = HandleUpdateInput( # TODO: update id vs proto instance id @@ -1013,6 +1013,10 @@ def workflow_set_update_handler( if validator is not None: defn.set_validator(validator) self._updates[name] = defn + if defn.dynamic_vararg: + raise RuntimeError( + "Dynamic updates do not support a vararg third param, use Sequence[RawValue]", + ) else: self._updates.pop(name, None) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index b3ec30a3..c9ad2a23 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -36,7 +36,13 @@ overload, ) -from typing_extensions import Concatenate, Literal, TypedDict +from typing_extensions import ( + Concatenate, + Literal, + Protocol, + TypedDict, + runtime_checkable, +) import temporalio.api.common.v1 import temporalio.bridge.proto.child_workflow @@ -64,6 +70,7 @@ MethodSyncSingleParam, MultiParamSpec, ParamType, + ProtocolReturnType, ReturnType, SelfType, ) @@ -764,8 +771,64 @@ def time_ns() -> int: return _Runtime.current().workflow_time_ns() -# noinspection PyPep8Naming -class update(object): +# Needs to be defined here to avoid a circular import +@runtime_checkable +class UpdateMethodMultiArg(Protocol[MultiParamSpec, ProtocolReturnType]): + """Decorated workflow update functions implement this.""" + + _defn: temporalio.workflow._UpdateDefinition + + def __call__( + self, *args: MultiParamSpec.args, **kwargs: MultiParamSpec.kwargs + ) -> Union[ProtocolReturnType, Awaitable[ProtocolReturnType]]: + """Generic callable type callback.""" + ... + + def validator(self, vfunc: Callable[MultiParamSpec, None]) -> None: + """Use to decorate a function to validate the arguments passed to the update handler.""" + ... + + +@overload +def update( + fn: Callable[MultiParamSpec, Awaitable[ReturnType]] +) -> UpdateMethodMultiArg[MultiParamSpec, ReturnType]: + ... + + +@overload +def update( + fn: Callable[MultiParamSpec, ReturnType] +) -> UpdateMethodMultiArg[MultiParamSpec, ReturnType]: + ... + + +@overload +def update( + *, name: str +) -> Callable[ + [Callable[MultiParamSpec, ReturnType]], + UpdateMethodMultiArg[MultiParamSpec, ReturnType], +]: + ... + + +@overload +def update( + *, dynamic: Literal[True] +) -> Callable[ + [Callable[MultiParamSpec, ReturnType]], + UpdateMethodMultiArg[MultiParamSpec, ReturnType], +]: + ... + + +def update( + fn: Optional[CallableSyncOrAsyncType] = None, + *, + name: Optional[str] = None, + dynamic: Optional[bool] = False, +): """Decorator for a workflow update handler method. This is set on any async or non-async method that you wish to be called upon @@ -791,44 +854,33 @@ class update(object): present. """ - def __init__( - self, - fn: Optional[CallableSyncOrAsyncType] = None, - *, - name: Optional[str] = None, - dynamic: Optional[bool] = False, - ): - """See :py:class:`update`.""" - if name is not None or dynamic: - if name is not None and dynamic: - raise RuntimeError("Cannot provide name and dynamic boolean") - self._fn = fn - self._name = ( - name if name is not None else self._fn.__name__ if self._fn else None - ) - self._dynamic = dynamic - if self._fn is not None: - # Only bother to assign the definition if we are given a function. The function is not provided when - # extra arguments are specified - in that case, the __call__ method is invoked instead. - self._assign_defn() - - def __call__(self, fn: CallableSyncOrAsyncType): - """Call the update decorator (as when passing optional arguments).""" - self._fn = fn - self._assign_defn() - return self - - def _assign_defn(self) -> None: - assert self._fn is not None - self._defn = _UpdateDefinition(name=self._name, fn=self._fn, is_method=True) - - def validator(self, fn: Callable[..., None]): - """Decorator for a workflow update validator method. Apply this decorator to a function to have it run before - the update handler. If it throws an error, the update will be rejected. The validator must not mutate workflow - state at all, and cannot call workflow functions which would schedule new commands (ex: starting an - activity). - """ - self._defn.set_validator(fn) + def with_name( + name: Optional[str], fn: CallableSyncOrAsyncType + ) -> CallableSyncOrAsyncType: + defn = _UpdateDefinition(name=name, fn=fn, is_method=True) + if defn.dynamic_vararg: + raise RuntimeError( + "Dynamic updates do not support a vararg third param, use Sequence[RawValue]", + ) + setattr(fn, "_defn", defn) + setattr(fn, "validator", partial(_update_validator, defn)) + return fn + + if name is not None or dynamic: + if name is not None and dynamic: + raise RuntimeError("Cannot provide name and dynamic boolean") + return partial(with_name, name) + if fn is None: + raise RuntimeError("Cannot create update without function or name or dynamic") + return with_name(fn.__name__, fn) + + +def _update_validator( + update_def: _UpdateDefinition, fn: Optional[Callable[..., None]] = None +): + """Decorator for a workflow update validator method.""" + if fn is not None: + update_def.set_validator(fn) def upsert_search_attributes(attributes: temporalio.common.SearchAttributes) -> None: @@ -1132,7 +1184,7 @@ def _apply_to_class( ) else: queries[query_defn.name] = query_defn - elif isinstance(member, update): + elif isinstance(member, UpdateMethodMultiArg): update_defn = member._defn if update_defn.name in updates: defn_name = update_defn.name or "" @@ -1350,6 +1402,7 @@ class _UpdateDefinition: arg_types: Optional[List[Type]] = None ret_type: Optional[Type] = None validator: Optional[Callable[..., None]] = None + dynamic_vararg: bool = False def __post_init__(self) -> None: if self.arg_types is None: diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 0b7ae6d2..b1f1378a 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -103,19 +103,18 @@ def test_workflow_defn_good(): name="base_query", fn=GoodDefnBase.base_query, is_method=True ), }, - # Since updates use class-based decorators we need to pass the inner _fn for the fn param updates={ "update1": workflow._UpdateDefinition( - name="update1", fn=GoodDefn.update1._fn, is_method=True + name="update1", fn=GoodDefn.update1, is_method=True ), "update-custom": workflow._UpdateDefinition( - name="update-custom", fn=GoodDefn.update2._fn, is_method=True + name="update-custom", fn=GoodDefn.update2, is_method=True ), None: workflow._UpdateDefinition( - name=None, fn=GoodDefn.update3._fn, is_method=True + name=None, fn=GoodDefn.update3, is_method=True ), "base_update": workflow._UpdateDefinition( - name="base_update", fn=GoodDefnBase.base_update._fn, is_method=True + name="base_update", fn=GoodDefnBase.base_update, is_method=True ), }, sandboxed=True, diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 12770940..51846994 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3553,6 +3553,10 @@ async def runs_activity(self, name: str) -> str: await act return "done" + @workflow.update(name="renamed") + async def async_named(self) -> str: + return "named" + @workflow.update async def bad_validator(self) -> str: return "done" @@ -3579,10 +3583,6 @@ def dynavalidator(name: str, _args: Sequence[RawValue]) -> None: workflow.set_dynamic_update_handler(dynahandler, validator=dynavalidator) return "set" - @workflow.update(name="name_override") - async def not_the_name(self) -> str: - return "name_overridden" - async def test_workflow_update_handlers_happy(client: Client): async with new_worker( @@ -3611,9 +3611,8 @@ async def test_workflow_update_handlers_happy(client: Client): await handle.update(UpdateHandlersWorkflow.set_dynamic) assert "dynahandler - made_up" == await handle.update("made_up") - assert "name_overridden" == await handle.update( - UpdateHandlersWorkflow.not_the_name - ) + # Name overload + assert "named" == await handle.update(UpdateHandlersWorkflow.async_named) async def test_workflow_update_handlers_unhappy(client: Client):