diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 4485955a..b3ec30a3 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -803,7 +803,9 @@ def __init__( if name is not None and dynamic: raise RuntimeError("Cannot provide name and dynamic boolean") self._fn = fn - self._name = name + self._name = ( + name if name is not None else self._fn.__name__ if self._fn else None + ) self._dynamic = dynamic if self._fn is not None: # Only bother to assign the definition if we are given a function. The function is not provided when @@ -817,15 +819,8 @@ def __call__(self, fn: CallableSyncOrAsyncType): return self def _assign_defn(self) -> None: - chosen_name = ( - self._name - if self._name is not None - else self._fn.__name__ - if self._fn - else None - ) assert self._fn is not None - self._defn = _UpdateDefinition(name=chosen_name, fn=self._fn, is_method=True) + self._defn = _UpdateDefinition(name=self._name, fn=self._fn, is_method=True) def validator(self, fn: Callable[..., None]): """Decorator for a workflow update validator method. Apply this decorator to a function to have it run before @@ -1240,9 +1235,9 @@ async def with_object(*args, **kwargs) -> Any: def _assert_dynamic_handler_args( fn: Callable, arg_types: Optional[List[Type]], is_method: bool ) -> bool: - # Dynamic query/signal must have three args: self, name, and - # Sequence[RawValue]. An older form accepted varargs for the third param so - # we will too (but will warn in the signal/query code) + # Dynamic query/signal/update must have three args: self, name, and + # Sequence[RawValue]. An older form accepted varargs for the third param for signals/queries so + # we will too (but will warn in the signal/query code). params = list(inspect.signature(fn).parameters.values()) total_expected_params = 3 if is_method else 2 if ( diff --git a/tests/test_workflow.py b/tests/test_workflow.py index e9851445..0b7ae6d2 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -19,6 +19,10 @@ def base_signal(self): def base_query(self): pass + @workflow.update + def base_update(self): + pass + @workflow.defn(name="workflow-custom") class GoodDefn(GoodDefnBase): @@ -50,6 +54,18 @@ def query2(self): def query3(self, name: str, args: Sequence[RawValue]): pass + @workflow.update + def update1(self): + pass + + @workflow.update(name="update-custom") + def update2(self): + pass + + @workflow.update(dynamic=True) + def update3(self, name: str, args: Sequence[RawValue]): + pass + def test_workflow_defn_good(): # Although the API is internal, we want to check the literal definition just @@ -87,8 +103,21 @@ def test_workflow_defn_good(): name="base_query", fn=GoodDefnBase.base_query, is_method=True ), }, - # TODO: Add - updates={}, + # Since updates use class-based decorators we need to pass the inner _fn for the fn param + updates={ + "update1": workflow._UpdateDefinition( + name="update1", fn=GoodDefn.update1._fn, is_method=True + ), + "update-custom": workflow._UpdateDefinition( + name="update-custom", fn=GoodDefn.update2._fn, is_method=True + ), + None: workflow._UpdateDefinition( + name=None, fn=GoodDefn.update3._fn, is_method=True + ), + "base_update": workflow._UpdateDefinition( + name="base_update", fn=GoodDefnBase.base_update._fn, is_method=True + ), + }, sandboxed=True, )