Skip to content

Commit

Permalink
Finally! Type system defeated!
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Oct 19, 2023
1 parent 668532d commit 3a9dc2c
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 59 deletions.
65 changes: 61 additions & 4 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,62 @@ async def terminate(
)
)

# Overload for no-param update
@overload
async def update(
self,
update: temporalio.workflow.UpdateMethodMultiArg[[SelfType], LocalReturnType],
*,
id: Optional[str] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> LocalReturnType:
...

# Overload for single-param update
@overload
async def update(
self,
update: temporalio.workflow.UpdateMethodMultiArg[
[SelfType, ParamType], LocalReturnType
],
arg: ParamType,
*,
id: Optional[str] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> LocalReturnType:
...

@overload
async def update(
self,
update: temporalio.workflow.UpdateMethodMultiArg[
MultiParamSpec, LocalReturnType
],
*,
args: MultiParamSpec.args,
id: Optional[str] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> LocalReturnType:
...

# Overload for string-name update
@overload
async def update(
self,
update: str,
arg: Any = temporalio.common._arg_unset,
*,
args: Sequence[Any] = [],
id: Optional[str] = None,
result_type: Optional[Type] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> Any:
...

async def update(
self,
update: Union[str, Callable],
Expand Down Expand Up @@ -1701,15 +1757,16 @@ async def start_update(
"""
update_name: str
ret_type = result_type
if callable(update):
if not isinstance(update, temporalio.workflow.update):
if isinstance(update, temporalio.workflow.UpdateMethodMultiArg):
defn = update._defn
if not defn:
raise RuntimeError(
f"Update definition not found on {update.__qualname__}, "
"is it decorated with @workflow.update?"
)
defn = update._defn
if not defn.name:
elif not defn.name:
raise RuntimeError("Cannot invoke dynamic update definition")
# TODO(cretz): Check count/type of args at runtime?
update_name = defn.name
ret_type = defn.ret_type
else:
Expand Down
6 changes: 5 additions & 1 deletion temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ async def run_update() -> None:
job.input,
defn.name,
defn.arg_types,
False,
defn.dynamic_vararg,
)
handler_input = HandleUpdateInput(
# TODO: update id vs proto instance id
Expand Down Expand Up @@ -1013,6 +1013,10 @@ def workflow_set_update_handler(
if validator is not None:
defn.set_validator(validator)
self._updates[name] = defn
if defn.dynamic_vararg:
raise RuntimeError(
"Dynamic updates do not support a vararg third param, use Sequence[RawValue]",
)
else:
self._updates.pop(name, None)

Expand Down
137 changes: 95 additions & 42 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@
overload,
)

from typing_extensions import Concatenate, Literal, TypedDict
from typing_extensions import (
Concatenate,
Literal,
Protocol,
TypedDict,
runtime_checkable,
)

import temporalio.api.common.v1
import temporalio.bridge.proto.child_workflow
Expand Down Expand Up @@ -64,6 +70,7 @@
MethodSyncSingleParam,
MultiParamSpec,
ParamType,
ProtocolReturnType,
ReturnType,
SelfType,
)
Expand Down Expand Up @@ -764,8 +771,64 @@ def time_ns() -> int:
return _Runtime.current().workflow_time_ns()


# noinspection PyPep8Naming
class update(object):
# Needs to be defined here to avoid a circular import
@runtime_checkable
class UpdateMethodMultiArg(Protocol[MultiParamSpec, ProtocolReturnType]):
"""Decorated workflow update functions implement this."""

_defn: temporalio.workflow._UpdateDefinition

def __call__(
self, *args: MultiParamSpec.args, **kwargs: MultiParamSpec.kwargs
) -> Union[ProtocolReturnType, Awaitable[ProtocolReturnType]]:
"""Generic callable type callback."""
...

def validator(self, vfunc: Callable[MultiParamSpec, None]) -> None:
"""Use to decorate a function to validate the arguments passed to the update handler."""
...


@overload
def update(
fn: Callable[MultiParamSpec, Awaitable[ReturnType]]
) -> UpdateMethodMultiArg[MultiParamSpec, ReturnType]:
...


@overload
def update(
fn: Callable[MultiParamSpec, ReturnType]
) -> UpdateMethodMultiArg[MultiParamSpec, ReturnType]:
...


@overload
def update(
*, name: str
) -> Callable[
[Callable[MultiParamSpec, ReturnType]],
UpdateMethodMultiArg[MultiParamSpec, ReturnType],
]:
...


@overload
def update(
*, dynamic: Literal[True]
) -> Callable[
[Callable[MultiParamSpec, ReturnType]],
UpdateMethodMultiArg[MultiParamSpec, ReturnType],
]:
...


def update(
fn: Optional[CallableSyncOrAsyncType] = None,
*,
name: Optional[str] = None,
dynamic: Optional[bool] = False,
):
"""Decorator for a workflow update handler method.
This is set on any async or non-async method that you wish to be called upon
Expand All @@ -791,44 +854,33 @@ class update(object):
present.
"""

def __init__(
self,
fn: Optional[CallableSyncOrAsyncType] = None,
*,
name: Optional[str] = None,
dynamic: Optional[bool] = False,
):
"""See :py:class:`update`."""
if name is not None or dynamic:
if name is not None and dynamic:
raise RuntimeError("Cannot provide name and dynamic boolean")
self._fn = fn
self._name = (
name if name is not None else self._fn.__name__ if self._fn else None
)
self._dynamic = dynamic
if self._fn is not None:
# Only bother to assign the definition if we are given a function. The function is not provided when
# extra arguments are specified - in that case, the __call__ method is invoked instead.
self._assign_defn()

def __call__(self, fn: CallableSyncOrAsyncType):
"""Call the update decorator (as when passing optional arguments)."""
self._fn = fn
self._assign_defn()
return self

def _assign_defn(self) -> None:
assert self._fn is not None
self._defn = _UpdateDefinition(name=self._name, fn=self._fn, is_method=True)

def validator(self, fn: Callable[..., None]):
"""Decorator for a workflow update validator method. Apply this decorator to a function to have it run before
the update handler. If it throws an error, the update will be rejected. The validator must not mutate workflow
state at all, and cannot call workflow functions which would schedule new commands (ex: starting an
activity).
"""
self._defn.set_validator(fn)
def with_name(
name: Optional[str], fn: CallableSyncOrAsyncType
) -> CallableSyncOrAsyncType:
defn = _UpdateDefinition(name=name, fn=fn, is_method=True)
if defn.dynamic_vararg:
raise RuntimeError(
"Dynamic updates do not support a vararg third param, use Sequence[RawValue]",
)
setattr(fn, "_defn", defn)
setattr(fn, "validator", partial(_update_validator, defn))
return fn

if name is not None or dynamic:
if name is not None and dynamic:
raise RuntimeError("Cannot provide name and dynamic boolean")
return partial(with_name, name)
if fn is None:
raise RuntimeError("Cannot create update without function or name or dynamic")
return with_name(fn.__name__, fn)


def _update_validator(
update_def: _UpdateDefinition, fn: Optional[Callable[..., None]] = None
):
"""Decorator for a workflow update validator method."""
if fn is not None:
update_def.set_validator(fn)


def upsert_search_attributes(attributes: temporalio.common.SearchAttributes) -> None:
Expand Down Expand Up @@ -1132,7 +1184,7 @@ def _apply_to_class(
)
else:
queries[query_defn.name] = query_defn
elif isinstance(member, update):
elif isinstance(member, UpdateMethodMultiArg):
update_defn = member._defn
if update_defn.name in updates:
defn_name = update_defn.name or "<dynamic>"
Expand Down Expand Up @@ -1350,6 +1402,7 @@ class _UpdateDefinition:
arg_types: Optional[List[Type]] = None
ret_type: Optional[Type] = None
validator: Optional[Callable[..., None]] = None
dynamic_vararg: bool = False

def __post_init__(self) -> None:
if self.arg_types is None:
Expand Down
9 changes: 4 additions & 5 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,18 @@ def test_workflow_defn_good():
name="base_query", fn=GoodDefnBase.base_query, is_method=True
),
},
# Since updates use class-based decorators we need to pass the inner _fn for the fn param
updates={
"update1": workflow._UpdateDefinition(
name="update1", fn=GoodDefn.update1._fn, is_method=True
name="update1", fn=GoodDefn.update1, is_method=True
),
"update-custom": workflow._UpdateDefinition(
name="update-custom", fn=GoodDefn.update2._fn, is_method=True
name="update-custom", fn=GoodDefn.update2, is_method=True
),
None: workflow._UpdateDefinition(
name=None, fn=GoodDefn.update3._fn, is_method=True
name=None, fn=GoodDefn.update3, is_method=True
),
"base_update": workflow._UpdateDefinition(
name="base_update", fn=GoodDefnBase.base_update._fn, is_method=True
name="base_update", fn=GoodDefnBase.base_update, is_method=True
),
},
sandboxed=True,
Expand Down
13 changes: 6 additions & 7 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3553,6 +3553,10 @@ async def runs_activity(self, name: str) -> str:
await act
return "done"

@workflow.update(name="renamed")
async def async_named(self) -> str:
return "named"

@workflow.update
async def bad_validator(self) -> str:
return "done"
Expand All @@ -3579,10 +3583,6 @@ def dynavalidator(name: str, _args: Sequence[RawValue]) -> None:
workflow.set_dynamic_update_handler(dynahandler, validator=dynavalidator)
return "set"

@workflow.update(name="name_override")
async def not_the_name(self) -> str:
return "name_overridden"


async def test_workflow_update_handlers_happy(client: Client):
async with new_worker(
Expand Down Expand Up @@ -3611,9 +3611,8 @@ async def test_workflow_update_handlers_happy(client: Client):
await handle.update(UpdateHandlersWorkflow.set_dynamic)
assert "dynahandler - made_up" == await handle.update("made_up")

assert "name_overridden" == await handle.update(
UpdateHandlersWorkflow.not_the_name
)
# Name overload
assert "named" == await handle.update(UpdateHandlersWorkflow.async_named)


async def test_workflow_update_handlers_unhappy(client: Client):
Expand Down

0 comments on commit 3a9dc2c

Please sign in to comment.