From 020f6635c3586fcefe03b1fcc78a4df4d80ed1e7 Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Tue, 24 Oct 2023 15:25:08 -0700 Subject: [PATCH] Add start_update overloads / update handle type --- temporalio/client.py | 68 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 2f989081..29f078b9 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -1642,6 +1642,7 @@ async def execute_update( ) -> LocalReturnType: ... + # Overload for multi-param update @overload async def execute_update( self, @@ -1721,6 +1722,63 @@ async def execute_update( ) return await handle.result() + # Overload for no-param start update + @overload + async def start_update( + self, + update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], LocalReturnType], + *, + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle[LocalReturnType]: + ... + + # Overload for single-param start update + @overload + async def start_update( + self, + update: temporalio.workflow.UpdateMethodMultiParam[ + [SelfType, ParamType], LocalReturnType + ], + arg: ParamType, + *, + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle[LocalReturnType]: + ... + + # Overload for multi-param start update + @overload + async def start_update( + self, + update: temporalio.workflow.UpdateMethodMultiParam[ + MultiParamSpec, LocalReturnType + ], + *, + args: MultiParamSpec.args, + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle[LocalReturnType]: + ... + + # Overload for string-name start update + @overload + async def start_update( + self, + update: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: Optional[str] = None, + result_type: Optional[Type] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle[Any]: + ... + async def start_update( self, update: Union[str, Callable], @@ -1731,7 +1789,7 @@ async def start_update( result_type: Optional[Type] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, - ) -> WorkflowUpdateHandle: + ) -> WorkflowUpdateHandle[Any]: """Send an update request to the workflow and return a handle to it. This will target the workflow with :py:attr:`run_id` if present. To use a @@ -1781,7 +1839,7 @@ async def _start_update( result_type: Optional[Type] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, - ) -> WorkflowUpdateHandle: + ) -> WorkflowUpdateHandle[Any]: update_name: str ret_type = result_type if isinstance(update, temporalio.workflow.UpdateMethodMultiParam): @@ -3863,7 +3921,7 @@ async def __anext__(self) -> ScheduleListDescription: return ret -class WorkflowUpdateHandle: +class WorkflowUpdateHandle(Generic[LocalReturnType]): """Handle for a workflow update execution request.""" def __init__( @@ -3915,7 +3973,7 @@ async def result( timeout: Optional[timedelta] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, - ) -> Any: + ) -> LocalReturnType: """Wait for and return the result of the update. The result may already be known in which case no call is made. Otherwise the result will be polled for until returned, or until the provided timeout is reached, if specified. @@ -4804,7 +4862,7 @@ async def start_workflow_update( raise determined_id = resp.update_ref.update_id - update_handle = WorkflowUpdateHandle( + update_handle: WorkflowUpdateHandle[Any] = WorkflowUpdateHandle( client=self._client, id=determined_id, name=input.update,