Skip to content

Commit de98531

Browse files
committed
Address various review comments
1 parent cc0871d commit de98531

File tree

4 files changed

+61
-69
lines changed

4 files changed

+61
-69
lines changed

temporalio/client.py

+41-58
Original file line numberDiff line numberDiff line change
@@ -1619,7 +1619,7 @@ async def terminate(
16191619
@overload
16201620
async def execute_update(
16211621
self,
1622-
update: temporalio.workflow.UpdateMethodMultiArg[[SelfType], LocalReturnType],
1622+
update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], LocalReturnType],
16231623
*,
16241624
id: Optional[str] = None,
16251625
rpc_metadata: Mapping[str, str] = {},
@@ -1631,7 +1631,7 @@ async def execute_update(
16311631
@overload
16321632
async def execute_update(
16331633
self,
1634-
update: temporalio.workflow.UpdateMethodMultiArg[
1634+
update: temporalio.workflow.UpdateMethodMultiParam[
16351635
[SelfType, ParamType], LocalReturnType
16361636
],
16371637
arg: ParamType,
@@ -1645,7 +1645,7 @@ async def execute_update(
16451645
@overload
16461646
async def execute_update(
16471647
self,
1648-
update: temporalio.workflow.UpdateMethodMultiArg[
1648+
update: temporalio.workflow.UpdateMethodMultiParam[
16491649
MultiParamSpec, LocalReturnType
16501650
],
16511651
*,
@@ -1784,14 +1784,9 @@ async def _start_update(
17841784
) -> WorkflowUpdateHandle:
17851785
update_name: str
17861786
ret_type = result_type
1787-
if isinstance(update, temporalio.workflow.UpdateMethodMultiArg):
1787+
if isinstance(update, temporalio.workflow.UpdateMethodMultiParam):
17881788
defn = update._defn
1789-
if not defn:
1790-
raise RuntimeError(
1791-
f"Update definition not found on {update.__qualname__}, "
1792-
"is it decorated with @workflow.update?"
1793-
)
1794-
elif not defn.name:
1789+
if not defn.name:
17951790
raise RuntimeError("Cannot invoke dynamic update definition")
17961791
# TODO(cretz): Check count/type of args at runtime?
17971792
update_name = defn.name
@@ -1801,9 +1796,9 @@ async def _start_update(
18011796

18021797
return await self._client._impl.start_workflow_update(
18031798
UpdateWorkflowInput(
1804-
workflow_id=self._id,
1799+
id=self._id,
18051800
run_id=self._run_id,
1806-
update_id=id or "",
1801+
update_id=id,
18071802
update=update_name,
18081803
args=temporalio.common._arg_or_args(arg, args),
18091804
headers={},
@@ -3878,7 +3873,7 @@ def __init__(
38783873
name: str,
38793874
workflow_id: str,
38803875
*,
3881-
run_id: Optional[str] = None,
3876+
workflow_run_id: Optional[str] = None,
38823877
result_type: Optional[Type] = None,
38833878
):
38843879
"""Create a workflow update handle.
@@ -3890,29 +3885,29 @@ def __init__(
38903885
self._id = id
38913886
self._name = name
38923887
self._workflow_id = workflow_id
3893-
self._run_id = run_id
3888+
self._workflow_run_id = workflow_run_id
38943889
self._result_type = result_type
38953890
self._known_result: Optional[temporalio.api.update.v1.Outcome] = None
38963891

38973892
@property
38983893
def id(self) -> str:
3899-
"""ID of this Update request"""
3894+
"""ID of this Update request."""
39003895
return self._id
39013896

39023897
@property
39033898
def name(self) -> str:
3904-
"""The name of the Update being invoked"""
3899+
"""The name of the Update being invoked."""
39053900
return self._name
39063901

39073902
@property
39083903
def workflow_id(self) -> str:
3909-
"""The ID of the Workflow targeted by this Update"""
3904+
"""The ID of the Workflow targeted by this Update."""
39103905
return self._workflow_id
39113906

39123907
@property
3913-
def run_id(self) -> Optional[str]:
3914-
"""If specified, the specific run of the Workflow targeted by this Update"""
3915-
return self._run_id
3908+
def workflow_run_id(self) -> Optional[str]:
3909+
"""If specified, the specific run of the Workflow targeted by this Update."""
3910+
return self._workflow_run_id
39163911

39173912
async def result(
39183913
self,
@@ -3934,7 +3929,6 @@ async def result(
39343929
TimeoutError: The specified timeout was reached when waiting for the update result.
39353930
RPCError: Update result could not be fetched for some other reason.
39363931
"""
3937-
outcome: temporalio.api.update.v1.Outcome
39383932
if self._known_result is not None:
39393933
outcome = self._known_result
39403934
return await _update_outcome_to_result(
@@ -3944,23 +3938,20 @@ async def result(
39443938
self._client.data_converter,
39453939
self._result_type,
39463940
)
3947-
else:
3948-
return await self._client._impl.poll_workflow_update(
3949-
PollUpdateWorkflowInput(
3950-
self.workflow_id,
3951-
self.run_id,
3952-
self.id,
3953-
self.name,
3954-
timeout,
3955-
{},
3956-
self._result_type,
3957-
rpc_metadata,
3958-
rpc_timeout,
3959-
)
3960-
)
39613941

3962-
def _set_known_result(self, result: temporalio.api.update.v1.Outcome) -> None:
3963-
self._known_result = result
3942+
return await self._client._impl.poll_workflow_update(
3943+
PollUpdateWorkflowInput(
3944+
self.workflow_id,
3945+
self.workflow_run_id,
3946+
self.id,
3947+
self.name,
3948+
timeout,
3949+
{},
3950+
self._result_type,
3951+
rpc_metadata,
3952+
rpc_timeout,
3953+
)
3954+
)
39643955

39653956

39663957
class WorkflowFailureError(temporalio.exceptions.TemporalError):
@@ -4023,11 +4014,9 @@ def message(self) -> str:
40234014
class WorkflowUpdateFailedError(temporalio.exceptions.TemporalError):
40244015
"""Error that occurs when an update fails."""
40254016

4026-
def __init__(self, update_id: str, update_name: str, cause: BaseException) -> None:
4017+
def __init__(self, cause: BaseException) -> None:
40274018
"""Create workflow update failed error."""
40284019
super().__init__("Workflow update failed")
4029-
self._update_id = update_id
4030-
self._update_name = update_name
40314020
self.__cause__ = cause
40324021

40334022
@property
@@ -4171,9 +4160,9 @@ class TerminateWorkflowInput:
41714160
class UpdateWorkflowInput:
41724161
"""Input for :py:meth:`OutboundInterceptor.start_workflow_update`."""
41734162

4174-
workflow_id: str
4163+
id: str
41754164
run_id: Optional[str]
4176-
update_id: str
4165+
update_id: Optional[str]
41774166
update: str
41784167
args: Sequence[Any]
41794168
wait_for_stage: Optional[
@@ -4787,12 +4776,12 @@ async def start_workflow_update(
47874776
req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest(
47884777
namespace=self._client.namespace,
47894778
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
4790-
workflow_id=input.workflow_id,
4779+
workflow_id=input.id,
47914780
run_id=input.run_id or "",
47924781
),
47934782
request=temporalio.api.update.v1.Request(
47944783
meta=temporalio.api.update.v1.Meta(
4795-
update_id=input.update_id,
4784+
update_id=input.update_id or "",
47964785
identity=self._client.identity,
47974786
),
47984787
input=temporalio.api.update.v1.Input(
@@ -4814,23 +4803,19 @@ async def start_workflow_update(
48144803
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
48154804
)
48164805
except RPCError as err:
4817-
# If the status is INVALID_ARGUMENT, we can assume it's an update
4818-
# failed error
4819-
if err.status == RPCStatusCode.INVALID_ARGUMENT:
4820-
raise WorkflowUpdateFailedError(input.workflow_id, input.update, err)
4821-
else:
4822-
raise
4806+
raise
48234807

4808+
determined_id = resp.update_ref.update_id
48244809
update_handle = WorkflowUpdateHandle(
48254810
client=self._client,
4826-
id=input.update_id,
4811+
id=determined_id,
48274812
name=input.update,
4828-
workflow_id=input.workflow_id,
4829-
run_id=input.run_id,
4813+
workflow_id=input.id,
4814+
workflow_run_id=input.run_id,
48304815
result_type=input.ret_type,
48314816
)
48324817
if resp.HasField("outcome"):
4833-
update_handle._set_known_result(resp.outcome)
4818+
update_handle._known_result = resp.outcome
48344819

48354820
return update_handle
48364821

@@ -4869,8 +4854,8 @@ async def poll_loop():
48694854
input.ret_type,
48704855
)
48714856
except RPCError as err:
4872-
if err.status == RPCStatusCode.DEADLINE_EXCEEDED:
4873-
continue
4857+
if err.status != RPCStatusCode.DEADLINE_EXCEEDED:
4858+
raise
48744859

48754860
# Wait for at most the *overall* timeout
48764861
return await asyncio.wait_for(
@@ -5415,8 +5400,6 @@ async def _update_outcome_to_result(
54155400
) -> Any:
54165401
if outcome.HasField("failure"):
54175402
raise WorkflowUpdateFailedError(
5418-
id,
5419-
name,
54205403
await converter.decode_failure(outcome.failure.cause),
54215404
)
54225405
if not outcome.success.payloads:

temporalio/contrib/opentelemetry.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ async def start_workflow_update(
249249
) -> temporalio.client.WorkflowUpdateHandle:
250250
with self.root._start_as_current_span(
251251
f"StartWorkflowUpdate:{input.update}",
252-
attributes={"temporalWorkflowID": input.workflow_id},
252+
attributes={"temporalWorkflowID": input.id},
253253
input=input,
254254
kind=opentelemetry.trace.SpanKind.CLIENT,
255255
):

temporalio/workflow.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ def time_ns() -> int:
773773

774774
# Needs to be defined here to avoid a circular import
775775
@runtime_checkable
776-
class UpdateMethodMultiArg(Protocol[MultiParamSpec, ProtocolReturnType]):
776+
class UpdateMethodMultiParam(Protocol[MultiParamSpec, ProtocolReturnType]):
777777
"""Decorated workflow update functions implement this."""
778778

779779
_defn: temporalio.workflow._UpdateDefinition
@@ -784,22 +784,24 @@ def __call__(
784784
"""Generic callable type callback."""
785785
...
786786

787-
def validator(self, vfunc: Callable[MultiParamSpec, None]) -> None:
787+
def validator(
788+
self, vfunc: Callable[MultiParamSpec, None]
789+
) -> Callable[MultiParamSpec, None]:
788790
"""Use to decorate a function to validate the arguments passed to the update handler."""
789791
...
790792

791793

792794
@overload
793795
def update(
794796
fn: Callable[MultiParamSpec, Awaitable[ReturnType]]
795-
) -> UpdateMethodMultiArg[MultiParamSpec, ReturnType]:
797+
) -> UpdateMethodMultiParam[MultiParamSpec, ReturnType]:
796798
...
797799

798800

799801
@overload
800802
def update(
801803
fn: Callable[MultiParamSpec, ReturnType]
802-
) -> UpdateMethodMultiArg[MultiParamSpec, ReturnType]:
804+
) -> UpdateMethodMultiParam[MultiParamSpec, ReturnType]:
803805
...
804806

805807

@@ -808,7 +810,7 @@ def update(
808810
*, name: str
809811
) -> Callable[
810812
[Callable[MultiParamSpec, ReturnType]],
811-
UpdateMethodMultiArg[MultiParamSpec, ReturnType],
813+
UpdateMethodMultiParam[MultiParamSpec, ReturnType],
812814
]:
813815
...
814816

@@ -818,7 +820,7 @@ def update(
818820
*, dynamic: Literal[True]
819821
) -> Callable[
820822
[Callable[MultiParamSpec, ReturnType]],
821-
UpdateMethodMultiArg[MultiParamSpec, ReturnType],
823+
UpdateMethodMultiParam[MultiParamSpec, ReturnType],
822824
]:
823825
...
824826

@@ -880,10 +882,11 @@ def with_name(
880882

881883
def _update_validator(
882884
update_def: _UpdateDefinition, fn: Optional[Callable[..., None]] = None
883-
):
885+
) -> Optional[Callable[..., None]]:
884886
"""Decorator for a workflow update validator method."""
885887
if fn is not None:
886888
update_def.set_validator(fn)
889+
return fn
887890

888891

889892
def upsert_search_attributes(attributes: temporalio.common.SearchAttributes) -> None:
@@ -1187,7 +1190,7 @@ def _apply_to_class(
11871190
)
11881191
else:
11891192
queries[query_defn.name] = query_defn
1190-
elif isinstance(member, UpdateMethodMultiArg):
1193+
elif isinstance(member, UpdateMethodMultiParam):
11911194
update_defn = member._defn
11921195
if update_defn.name in updates:
11931196
defn_name = update_defn.name or "<dynamic>"
@@ -1230,7 +1233,7 @@ def _apply_to_class(
12301233
issues.append(
12311234
f"@workflow.query defined on {base_member.__qualname__} but not on the override"
12321235
)
1233-
elif isinstance(base_member, UpdateMethodMultiArg):
1236+
elif isinstance(base_member, UpdateMethodMultiParam):
12341237
update_defn = base_member._defn
12351238
if update_defn.name not in updates:
12361239
issues.append(

tests/worker/test_workflow.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -3592,9 +3592,10 @@ async def test_workflow_update_handlers_happy(client: Client, env: WorkflowEnvir
35923592
async with new_worker(
35933593
client, UpdateHandlersWorkflow, activities=[say_hello]
35943594
) as worker:
3595+
wf_id = f"update-handlers-workflow-{uuid.uuid4()}"
35953596
handle = await client.start_workflow(
35963597
UpdateHandlersWorkflow.run,
3597-
id=f"update-handlers-workflow-{uuid.uuid4()}",
3598+
id=wf_id,
35983599
task_queue=worker.task_queue,
35993600
)
36003601

@@ -3622,6 +3623,11 @@ async def test_workflow_update_handlers_happy(client: Client, env: WorkflowEnvir
36223623
UpdateHandlersWorkflow.async_named
36233624
)
36243625

3626+
# Get untyped handle
3627+
assert "val3" == await client.get_workflow_handle(wf_id).execute_update(
3628+
UpdateHandlersWorkflow.last_event, "val4"
3629+
)
3630+
36253631

36263632
async def test_workflow_update_handlers_unhappy(
36273633
client: Client, env: WorkflowEnvironment

0 commit comments

Comments
 (0)