diff --git a/temporalio/client.py b/temporalio/client.py index 10d1d1f2..95c18143 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import copy import dataclasses import inspect @@ -1716,7 +1717,7 @@ async def start_update( return await self._client._impl.start_workflow_update( UpdateWorkflowInput( - id=self._id, + workflow_id=self._id, run_id=self._run_id, update_id=id or "", update=update_name, @@ -3829,31 +3830,41 @@ async def result( *, timeout: Optional[timedelta] = None, rpc_metadata: Mapping[str, str] = None, + rpc_timeout: Optional[timedelta] = None, ) -> Any: + """Wait for and return the result of the update. The result may already be known in which case no call is made. + Otherwise the result will be polled for until returned, or until the provided timeout is reached, if specified. + + Args: + timeout: Optional timeout specifying maximum wait time for the result. + rpc_metadata: Headers used on the RPC call. Keys here override client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for the RPC call. If this elapses, the poll is retried until the + overall timeout has been reached. + """ outcome: temporalio.api.update.v1.Outcome if self._known_result is not None: outcome = self._known_result - else: - # TODO: This - raise NotImplementedError - - if outcome.HasField("failure"): - raise WorkflowUpdateFailedError( + return await _update_outcome_to_result( + outcome, self.id, self.name, - await self._client.data_converter.decode_failure(outcome.failure.cause), + self._client.data_converter, + self._result_type, + ) + else: + return await self._client._impl.poll_workflow_update( + PollUpdateWorkflowInput( + self.workflow_id, + self.run_id, + self.id, + self.name, + timeout, + {}, + self._result_type, + rpc_metadata, + rpc_timeout, + ) ) - if not outcome.success.payloads: - return None - type_hints = [self._result_type] if self._result_type else None - results = await self._client.data_converter.decode( - outcome.success.payloads, type_hints - ) - if not results: - return None - elif len(results) > 1: - warnings.warn(f"Expected single update result, got {len(results)}") - return results[0] def _set_known_result(self, result: temporalio.api.update.v1.Outcome) -> None: self._known_result = result @@ -4065,9 +4076,9 @@ class TerminateWorkflowInput: @dataclass class UpdateWorkflowInput: - """Input for :py:meth:`OutboundInterceptor.update_workflow`.""" + """Input for :py:meth:`OutboundInterceptor.start_workflow_update`.""" - id: str + workflow_id: str run_id: Optional[str] update_id: str update: str @@ -4076,7 +4087,21 @@ class UpdateWorkflowInput: temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage ] headers: Mapping[str, temporalio.api.common.v1.Payload] - # Type may be absent + ret_type: Optional[Type] + rpc_metadata: Mapping[str, str] + rpc_timeout: Optional[timedelta] + + +@dataclass +class PollUpdateWorkflowInput: + """Input for :py:meth:`OutboundInterceptor.poll_workflow_update`.""" + + workflow_id: str + run_id: Optional[str] + update_id: str + update: str + timeout: Optional[timedelta] + headers: Mapping[str, temporalio.api.common.v1.Payload] ret_type: Optional[Type] rpc_metadata: Mapping[str, str] rpc_timeout: Optional[timedelta] @@ -4329,9 +4354,13 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: async def start_workflow_update( self, input: UpdateWorkflowInput ) -> WorkflowUpdateHandle: - """Called for every :py:meth:`WorkflowHandle.signal` call.""" + """Called for every :py:meth:`WorkflowHandle.update` and :py:meth:`WorkflowHandle.start_update` call.""" return await self.next.start_workflow_update(input) + async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any: + """May be called when calling :py:math:`WorkflowUpdateHandle.result`.""" + return await self.next.poll_workflow_update(input) + ### Async activity calls async def heartbeat_async_activity( @@ -4665,7 +4694,7 @@ async def start_workflow_update( req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest( namespace=self._client.namespace, workflow_execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=input.id, + workflow_id=input.workflow_id, run_id=input.run_id or "", ), request=temporalio.api.update.v1.Request( @@ -4695,7 +4724,9 @@ async def start_workflow_update( # If the status is INVALID_ARGUMENT, we can assume it's an update # failed error if err.status == RPCStatusCode.INVALID_ARGUMENT: - raise WorkflowUpdateFailedError(input.id, input.update, err.cause) + raise WorkflowUpdateFailedError( + input.workflow_id, input.update, err.cause + ) else: raise @@ -4703,7 +4734,7 @@ async def start_workflow_update( client=self._client, id=input.update_id, name=input.update, - workflow_id=input.id, + workflow_id=input.workflow_id, run_id=input.run_id, result_type=input.ret_type, ) @@ -4712,6 +4743,47 @@ async def start_workflow_update( return update_handle + async def poll_workflow_update(self, input: PollUpdateWorkflowInput) -> Any: + req = temporalio.api.workflowservice.v1.PollWorkflowExecutionUpdateRequest( + namespace=self._client.namespace, + update_ref=temporalio.api.update.v1.UpdateRef( + workflow_execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=input.workflow_id, + run_id=input.run_id or "", + ), + update_id=input.update_id, + ), + identity=self._client.identity, + wait_policy=temporalio.api.update.v1.WaitPolicy( + lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED + ), + ) + try: + # Wait for at most the *overall* timeout + async with asyncio.timeout(input.timeout.total_seconds()): + # Continue polling as long as we have either an empty response, or an *rpc* timeout + while True: + try: + res = await self._client.workflow_service.poll_workflow_execution_update( + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + if res.HasField("outcome"): + return await _update_outcome_to_result( + res.outcome, + input.update_id, + input.update, + self._client.data_converter, + input.ret_type, + ) + except RPCError as err: + if err.status == RPCStatusCode.DEADLINE_EXCEEDED: + continue + except TimeoutError: + pass + ### Async activity calls async def heartbeat_async_activity( @@ -5240,6 +5312,30 @@ def _fix_history_enum(prefix: str, parent: Dict[str, Any], *attrs: str) -> None: _fix_history_enum(prefix, child_item, *attrs[1:]) +async def _update_outcome_to_result( + outcome: temporalio.api.update.v1.Outcome, + id: str, + name: str, + converter: temporalio.converter.DataConverter, + rtype: Optional[Type], +) -> Any: + if outcome.HasField("failure"): + raise WorkflowUpdateFailedError( + id, + name, + await converter.decode_failure(outcome.failure.cause), + ) + if not outcome.success.payloads: + return None + type_hints = [rtype] if rtype else None + results = await converter.decode(outcome.success.payloads, type_hints) + if not results: + return None + elif len(results) > 1: + warnings.warn(f"Expected single update result, got {len(results)}") + return results[0] + + @dataclass(frozen=True) class WorkerBuildIdVersionSets: """Represents the sets of compatible Build ID versions associated with some Task Queue, as diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 6a4f42a2..b1e1df6b 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -249,7 +249,7 @@ async def update_workflow( ) -> Any: with self.root._start_as_current_span( f"UpdateWorkflow:{input.update}", - attributes={"temporalWorkflowID": input.id}, + attributes={"temporalWorkflowID": input.workflow_id}, input=input, kind=opentelemetry.trace.SpanKind.CLIENT, ): diff --git a/tests/test_client.py b/tests/test_client.py index a535beea..03d77d9f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -36,6 +36,7 @@ Client, Interceptor, OutboundInterceptor, + PollUpdateWorkflowInput, QueryWorkflowInput, RPCError, RPCStatusCode, @@ -408,6 +409,12 @@ async def start_workflow_update( self._parent.traces.append(("start_workflow_update", input)) return await super().start_workflow_update(input) + async def poll_workflow_update( + self, input: PollUpdateWorkflowInput + ) -> WorkflowUpdateHandle: + self._parent.traces.append(("poll_workflow_update", input)) + return await super().poll_workflow_update(input) + async def test_interceptor(client: Client, worker: ExternalWorker): # Create new client from existing client but with a tracing interceptor