Skip to content

Commit

Permalink
Add update definitions to workflow definition test
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Oct 17, 2023
1 parent 5daec88 commit 668532d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
19 changes: 7 additions & 12 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,9 @@ def __init__(
if name is not None and dynamic:
raise RuntimeError("Cannot provide name and dynamic boolean")
self._fn = fn
self._name = name
self._name = (
name if name is not None else self._fn.__name__ if self._fn else None
)
self._dynamic = dynamic
if self._fn is not None:
# Only bother to assign the definition if we are given a function. The function is not provided when
Expand All @@ -817,15 +819,8 @@ def __call__(self, fn: CallableSyncOrAsyncType):
return self

def _assign_defn(self) -> None:
chosen_name = (
self._name
if self._name is not None
else self._fn.__name__
if self._fn
else None
)
assert self._fn is not None
self._defn = _UpdateDefinition(name=chosen_name, fn=self._fn, is_method=True)
self._defn = _UpdateDefinition(name=self._name, fn=self._fn, is_method=True)

def validator(self, fn: Callable[..., None]):
"""Decorator for a workflow update validator method. Apply this decorator to a function to have it run before
Expand Down Expand Up @@ -1240,9 +1235,9 @@ async def with_object(*args, **kwargs) -> Any:
def _assert_dynamic_handler_args(
fn: Callable, arg_types: Optional[List[Type]], is_method: bool
) -> bool:
# Dynamic query/signal must have three args: self, name, and
# Sequence[RawValue]. An older form accepted varargs for the third param so
# we will too (but will warn in the signal/query code)
# Dynamic query/signal/update must have three args: self, name, and
# Sequence[RawValue]. An older form accepted varargs for the third param for signals/queries so
# we will too (but will warn in the signal/query code).
params = list(inspect.signature(fn).parameters.values())
total_expected_params = 3 if is_method else 2
if (
Expand Down
33 changes: 31 additions & 2 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def base_signal(self):
def base_query(self):
pass

@workflow.update
def base_update(self):
pass


@workflow.defn(name="workflow-custom")
class GoodDefn(GoodDefnBase):
Expand Down Expand Up @@ -50,6 +54,18 @@ def query2(self):
def query3(self, name: str, args: Sequence[RawValue]):
pass

@workflow.update
def update1(self):
pass

@workflow.update(name="update-custom")
def update2(self):
pass

@workflow.update(dynamic=True)
def update3(self, name: str, args: Sequence[RawValue]):
pass


def test_workflow_defn_good():
# Although the API is internal, we want to check the literal definition just
Expand Down Expand Up @@ -87,8 +103,21 @@ def test_workflow_defn_good():
name="base_query", fn=GoodDefnBase.base_query, is_method=True
),
},
# TODO: Add
updates={},
# Since updates use class-based decorators we need to pass the inner _fn for the fn param
updates={
"update1": workflow._UpdateDefinition(
name="update1", fn=GoodDefn.update1._fn, is_method=True
),
"update-custom": workflow._UpdateDefinition(
name="update-custom", fn=GoodDefn.update2._fn, is_method=True
),
None: workflow._UpdateDefinition(
name=None, fn=GoodDefn.update3._fn, is_method=True
),
"base_update": workflow._UpdateDefinition(
name="base_update", fn=GoodDefnBase.base_update._fn, is_method=True
),
},
sandboxed=True,
)

Expand Down

0 comments on commit 668532d

Please sign in to comment.