diff --git a/temporalio/client.py b/temporalio/client.py index ccde4f80..498c35a0 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -1702,15 +1702,14 @@ async def start_update( update_name: str ret_type = result_type if callable(update): - defn = temporalio.workflow._UpdateDefinition.from_fn(update) - if not defn: + if not isinstance(update, temporalio.workflow.update): raise RuntimeError( f"Update definition not found on {update.__qualname__}, " "is it decorated with @workflow.update?" ) - elif not defn.name: + defn = update._defn + if not defn.name: raise RuntimeError("Cannot invoke dynamic update definition") - # TODO(cretz): Check count/type of args at runtime? update_name = defn.name ret_type = defn.ret_type else: diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index f25af44d..e244e7c1 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -439,7 +439,7 @@ async def run_update() -> None: job.input, defn.name, defn.arg_types, - defn.dynamic_vararg, + False, ) handler_input = HandleUpdateInput( # TODO: update id vs proto instance id @@ -1013,10 +1013,6 @@ def workflow_set_update_handler( if validator is not None: defn.set_validator(validator) self._updates[name] = defn - if defn.dynamic_vararg: - raise RuntimeError( - "Dynamic updates do not support a vararg third param, use Sequence[RawValue]", - ) else: self._updates.pop(name, None) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index f3f55907..4485955a 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -764,12 +764,8 @@ def time_ns() -> int: return _Runtime.current().workflow_time_ns() -def update( - fn: Optional[CallableSyncOrAsyncType] = None, - *, - name: Optional[str] = None, - dynamic: Optional[bool] = False, -): +# noinspection PyPep8Naming +class update(object): """Decorator for a workflow update handler method. This is set on any async or non-async method that you wish to be called upon @@ -795,33 +791,49 @@ def update( present. """ - def with_name( - name: Optional[str], fn: CallableSyncOrAsyncType - ) -> CallableSyncOrAsyncType: - defn = _UpdateDefinition(name=name, fn=fn, is_method=True) - if defn.dynamic_vararg: - raise RuntimeError( - "Dynamic updates do not support a vararg third param, use Sequence[RawValue]", - ) - setattr(fn, "__temporal_update_definition", defn) - setattr(fn, "validator", partial(_update_validator, defn)) - return fn - - if name is not None or dynamic: - if name is not None and dynamic: - raise RuntimeError("Cannot provide name and dynamic boolean") - return partial(with_name, name) - if fn is None: - raise RuntimeError("Cannot create update without function or name or dynamic") - return with_name(fn.__name__, fn) - - -def _update_validator( - update_def: _UpdateDefinition, fn: Optional[Callable[..., None]] = None -): - """Decorator for a workflow update validator method.""" - if fn is not None: - update_def.set_validator(fn) + def __init__( + self, + fn: Optional[CallableSyncOrAsyncType] = None, + *, + name: Optional[str] = None, + dynamic: Optional[bool] = False, + ): + """See :py:class:`update`.""" + if name is not None or dynamic: + if name is not None and dynamic: + raise RuntimeError("Cannot provide name and dynamic boolean") + self._fn = fn + self._name = name + self._dynamic = dynamic + if self._fn is not None: + # Only bother to assign the definition if we are given a function. The function is not provided when + # extra arguments are specified - in that case, the __call__ method is invoked instead. + self._assign_defn() + + def __call__(self, fn: CallableSyncOrAsyncType): + """Call the update decorator (as when passing optional arguments).""" + self._fn = fn + self._assign_defn() + return self + + def _assign_defn(self) -> None: + chosen_name = ( + self._name + if self._name is not None + else self._fn.__name__ + if self._fn + else None + ) + assert self._fn is not None + self._defn = _UpdateDefinition(name=chosen_name, fn=self._fn, is_method=True) + + def validator(self, fn: Callable[..., None]): + """Decorator for a workflow update validator method. Apply this decorator to a function to have it run before + the update handler. If it throws an error, the update will be rejected. The validator must not mutate workflow + state at all, and cannot call workflow functions which would schedule new commands (ex: starting an + activity). + """ + self._defn.set_validator(fn) def upsert_search_attributes(attributes: temporalio.common.SearchAttributes) -> None: @@ -1125,10 +1137,8 @@ def _apply_to_class( ) else: queries[query_defn.name] = query_defn - elif hasattr(member, "__temporal_update_definition"): - update_defn = cast( - _UpdateDefinition, getattr(member, "__temporal_update_definition") - ) + elif isinstance(member, update): + update_defn = member._defn if update_defn.name in updates: defn_name = update_defn.name or "" issues.append( @@ -1345,38 +1355,16 @@ class _UpdateDefinition: arg_types: Optional[List[Type]] = None ret_type: Optional[Type] = None validator: Optional[Callable[..., None]] = None - dynamic_vararg: bool = False - - @staticmethod - def from_fn(fn: Callable) -> Optional[_UpdateDefinition]: - return getattr(fn, "__temporal_update_definition", None) - - @staticmethod - def must_name_from_fn_or_str(update: Union[str, Callable]) -> str: - if callable(update): - defn = _UpdateDefinition.from_fn(update) - if not defn: - raise RuntimeError( - f"Update definition not found on {update.__qualname__}, " - "is it decorated with @workflow.update?" - ) - elif not defn.name: - raise RuntimeError("Cannot invoke dynamic update definition") - # TODO(cretz): Check count/type of args at runtime? - return defn.name - return str(update) def __post_init__(self) -> None: if self.arg_types is None: arg_types, ret_type = temporalio.common._type_hints_from_func(self.fn) - # If dynamic, assert it - if not self.name: - object.__setattr__( - self, - "dynamic_vararg", - not _assert_dynamic_handler_args( - self.fn, arg_types, self.is_method - ), + # Disallow dynamic varargs + if not self.name and not _assert_dynamic_handler_args( + self.fn, arg_types, self.is_method + ): + raise RuntimeError( + "Dynamic updates do not support a vararg third param, use Sequence[RawValue]", ) object.__setattr__(self, "arg_types", arg_types) object.__setattr__(self, "ret_type", ret_type) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index f5d8a249..12770940 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3579,6 +3579,10 @@ def dynavalidator(name: str, _args: Sequence[RawValue]) -> None: workflow.set_dynamic_update_handler(dynahandler, validator=dynavalidator) return "set" + @workflow.update(name="name_override") + async def not_the_name(self) -> str: + return "name_overridden" + async def test_workflow_update_handlers_happy(client: Client): async with new_worker( @@ -3607,6 +3611,10 @@ async def test_workflow_update_handlers_happy(client: Client): await handle.update(UpdateHandlersWorkflow.set_dynamic) assert "dynahandler - made_up" == await handle.update("made_up") + assert "name_overridden" == await handle.update( + UpdateHandlersWorkflow.not_the_name + ) + async def test_workflow_update_handlers_unhappy(client: Client): async with new_worker(client, UpdateHandlersWorkflow) as worker: