diff --git a/temporalio/client.py b/temporalio/client.py index 4b6948fc..1e5a4146 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -11,6 +11,7 @@ import uuid import warnings from abc import ABC, abstractmethod +from asyncio import Future from dataclasses import dataclass from datetime import datetime, timedelta, timezone from enum import Enum, IntEnum @@ -487,7 +488,7 @@ async def start_workflow( static_details: General fixed details for this workflow execution that may appear in UI/CLI. This can be in Temporal markdown format and can span multiple lines. This is a fixed value on the workflow that cannot be updated. For details that can be - updated, use `Workflow.CurrentDetails` within the workflow. + updated, use :py:meth:`temporalio.workflow.get_current_details` within the workflow. start_delay: Amount of time to wait before starting the workflow. This does not work with ``cron_schedule``. start_signal: If present, this signal is sent as signal-with-start @@ -510,22 +511,12 @@ async def start_workflow( already been started. RPCError: Workflow could not be started for some other reason. 
""" - # Use definition if callable - name: str - if isinstance(workflow, str): - name = workflow - elif callable(workflow): - defn = temporalio.workflow._Definition.must_from_run_fn(workflow) - if not defn.name: - raise ValueError("Cannot invoke dynamic workflow explicitly") - name = defn.name - if result_type is None: - result_type = defn.ret_type - else: - raise TypeError("Workflow must be a string or callable") temporalio.common._warn_on_deprecated_search_attributes( search_attributes, stack_level=stack_level ) + name, result_type_from_run_fn = ( + temporalio.workflow._Definition.get_name_and_result_type(workflow) + ) return await self._impl.start_workflow( StartWorkflowInput( @@ -548,7 +539,7 @@ async def start_workflow( static_details=static_details, start_signal=start_signal, start_signal_args=start_signal_args, - ret_type=result_type, + ret_type=result_type or result_type_from_run_fn, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, request_eager_start=request_eager_start, @@ -820,6 +811,318 @@ def get_workflow_handle_for( result_type=defn.ret_type, ) + # Overload for no-param update + @overload + async def execute_update_with_start_workflow( + self, + update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], LocalReturnType], + *, + start_workflow_operation: WithStartWorkflowOperation[SelfType, Any], + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> LocalReturnType: ... + + # Overload for single-param update + @overload + async def execute_update_with_start_workflow( + self, + update: temporalio.workflow.UpdateMethodMultiParam[ + [SelfType, ParamType], LocalReturnType + ], + arg: ParamType, + *, + start_workflow_operation: WithStartWorkflowOperation[SelfType, Any], + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> LocalReturnType: ... 
+ + # Overload for multi-param update + @overload + async def execute_update_with_start_workflow( + self, + update: temporalio.workflow.UpdateMethodMultiParam[ + MultiParamSpec, LocalReturnType + ], + *, + args: MultiParamSpec.args, # pyright: ignore + start_workflow_operation: WithStartWorkflowOperation[SelfType, Any], + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> LocalReturnType: ... + + # Overload for string-name update + @overload + async def execute_update_with_start_workflow( + self, + update: str, + arg: Any = temporalio.common._arg_unset, + *, + start_workflow_operation: WithStartWorkflowOperation[Any, Any], + args: Sequence[Any] = [], + id: Optional[str] = None, + result_type: Optional[Type] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> Any: ... + + async def execute_update_with_start_workflow( + self, + update: Union[str, Callable], + arg: Any = temporalio.common._arg_unset, + *, + start_workflow_operation: WithStartWorkflowOperation[Any, Any], + args: Sequence[Any] = [], + id: Optional[str] = None, + result_type: Optional[Type] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> Any: + """Send an update-with-start request and wait for the update to complete. + + A WorkflowIDConflictPolicy must be set in the start_workflow_operation. If the + specified workflow execution is not running, a new workflow execution is started + and the update is sent in the first workflow task. Alternatively if the specified + workflow execution is running then, if the WorkflowIDConflictPolicy is + USE_EXISTING, the update is issued against the specified workflow, and if the + WorkflowIDConflictPolicy is FAIL, an error is returned. This call will block until + the update has completed, and return the update result. 
Note that this means that + the call will not return successfully until the update has been delivered to a + worker. + + .. warning:: + This API is experimental + + Args: + update: Update function or name on the workflow. arg: Single argument to the + update. + args: Multiple arguments to the update. Cannot be set if arg is. + start_workflow_operation: a WithStartWorkflowOperation defining the + WorkflowIDConflictPolicy and how to start the workflow in the event that a + workflow is started. + id: ID of the update. If not set, the default is a new UUID. + result_type: For string updates, this can set the specific result + type hint to deserialize into. + rpc_metadata: Headers used on the RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for the RPC call. + + Raises: + WorkflowUpdateFailedError: If the update failed. + WorkflowUpdateRPCTimeoutOrCancelledError: This update call timed out + or was cancelled. This doesn't mean the update itself was timed out or + cancelled. + + RPCError: There was some issue starting the workflow or sending the update to + the workflow. + """ + handle = await self._start_update_with_start( + update, + arg, + args=args, + start_workflow_operation=start_workflow_operation, + wait_for_stage=WorkflowUpdateStage.COMPLETED, + id=id, + result_type=result_type, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + return await handle.result() + + # Overload for no-param start update + @overload + async def start_update_with_start_workflow( + self, + update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], LocalReturnType], + *, + start_workflow_operation: WithStartWorkflowOperation[SelfType, Any], + wait_for_stage: WorkflowUpdateStage, + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle[LocalReturnType]: ... 
+ + # Overload for single-param start update + @overload + async def start_update_with_start_workflow( + self, + update: temporalio.workflow.UpdateMethodMultiParam[ + [SelfType, ParamType], LocalReturnType + ], + arg: ParamType, + *, + start_workflow_operation: WithStartWorkflowOperation[SelfType, Any], + wait_for_stage: WorkflowUpdateStage, + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle[LocalReturnType]: ... + + # Overload for multi-param start update + @overload + async def start_update_with_start_workflow( + self, + update: temporalio.workflow.UpdateMethodMultiParam[ + MultiParamSpec, LocalReturnType + ], + *, + args: MultiParamSpec.args, # pyright: ignore + start_workflow_operation: WithStartWorkflowOperation[SelfType, Any], + wait_for_stage: WorkflowUpdateStage, + id: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle[LocalReturnType]: ... + + # Overload for string-name start update + @overload + async def start_update_with_start_workflow( + self, + update: str, + arg: Any = temporalio.common._arg_unset, + *, + start_workflow_operation: WithStartWorkflowOperation[Any, Any], + wait_for_stage: WorkflowUpdateStage, + args: Sequence[Any] = [], + id: Optional[str] = None, + result_type: Optional[Type] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle[Any]: ... 
+ + async def start_update_with_start_workflow( + self, + update: Union[str, Callable], + arg: Any = temporalio.common._arg_unset, + *, + start_workflow_operation: WithStartWorkflowOperation[Any, Any], + wait_for_stage: WorkflowUpdateStage, + args: Sequence[Any] = [], + id: Optional[str] = None, + result_type: Optional[Type] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle[Any]: + """Send an update-with-start request and wait for it to be accepted. + + A WorkflowIDConflictPolicy must be set in the start_workflow_operation. If the + specified workflow execution is not running, a new workflow execution is started + and the update is sent in the first workflow task. Alternatively if the specified + workflow execution is running then, if the WorkflowIDConflictPolicy is + USE_EXISTING, the update is issued against the specified workflow, and if the + WorkflowIDConflictPolicy is FAIL, an error is returned. This call will block until + the update has been accepted, and return a WorkflowUpdateHandle. Note that this + means that the call will not return successfully until the update has been + delivered to a worker. + + .. warning:: + This API is experimental + + Args: + update: Update function or name on the workflow. arg: Single argument to the + update. + args: Multiple arguments to the update. Cannot be set if arg is. + start_workflow_operation: a WithStartWorkflowOperation defining the + WorkflowIDConflictPolicy and how to start the workflow in the event that a + workflow is started. + wait_for_stage: Required stage to wait until returning: either ACCEPTED or + COMPLETED. ADMITTED is not currently supported. See + https://docs.temporal.io/workflows#update for more details. + id: ID of the update. If not set, the default is a new UUID. + result_type: For string updates, this can set the specific result + type hint to deserialize into. + rpc_metadata: Headers used on the RPC call. 
Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for the RPC call. + + Raises: + WorkflowUpdateFailedError: If the update failed. + WorkflowUpdateRPCTimeoutOrCancelledError: This update call timed out + or was cancelled. This doesn't mean the update itself was timed out or + cancelled. + + RPCError: There was some issue starting the workflow or sending the update to + the workflow. + """ + return await self._start_update_with_start( + update, + arg, + wait_for_stage=wait_for_stage, + args=args, + id=id, + result_type=result_type, + start_workflow_operation=start_workflow_operation, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + + async def _start_update_with_start( + self, + update: Union[str, Callable], + arg: Any = temporalio.common._arg_unset, + *, + wait_for_stage: WorkflowUpdateStage, + args: Sequence[Any] = [], + id: Optional[str] = None, + result_type: Optional[Type] = None, + start_workflow_operation: WithStartWorkflowOperation[SelfType, ReturnType], + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowUpdateHandle[Any]: + if wait_for_stage == WorkflowUpdateStage.ADMITTED: + raise ValueError("ADMITTED wait stage not supported") + update_name: str + ret_type = result_type + if isinstance(update, temporalio.workflow.UpdateMethodMultiParam): + defn = update._defn + if not defn.name: + raise RuntimeError("Cannot invoke dynamic update definition") + # TODO(cretz): Check count/type of args at runtime? 
+ update_name = defn.name + ret_type = defn.ret_type + else: + update_name = str(update) + + update_input = UpdateWithStartUpdateWorkflowInput( + update_id=id, + update=update_name, + args=temporalio.common._arg_or_args(arg, args), + headers={}, + ret_type=ret_type, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + wait_for_stage=wait_for_stage, + ) + + def on_start_success( + start_response: temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse, + ): + start_workflow_operation._workflow_handle.set_result( + WorkflowHandle( + self, + start_workflow_operation._start_workflow_input.id, + first_execution_run_id=start_response.run_id, + result_run_id=start_response.run_id, + result_type=result_type, + ) + ) + + def on_start_failure( + error: BaseException, + ): + start_workflow_operation._workflow_handle.set_exception(error) + + input = StartWorkflowUpdateWithStartInput( + start_workflow_input=start_workflow_operation._start_workflow_input, + update_workflow_input=update_input, + _on_start=on_start_success, + _on_start_error=on_start_failure, + ) + + return await self._impl.start_update_with_start_workflow(input) + def list_workflows( self, query: Optional[str] = None, @@ -980,7 +1283,7 @@ async def create_schedule( static_details: General fixed details for this workflow execution that may appear in UI/CLI. This can be in Temporal markdown format and can span multiple lines. This is a fixed value on the workflow that cannot be updated. For details that can be - updated, use `Workflow.CurrentDetails` within the workflow. + updated, use :py:meth:`temporalio.workflow.get_current_details` within the workflow. rpc_metadata: Headers used on the RPC call. Keys here override client-level RPC metadata keys. rpc_timeout: Optional RPC deadline to set for the RPC call. 
@@ -1214,9 +1517,8 @@ def id(self) -> str: @property def run_id(self) -> Optional[str]: - """Run ID used for :py:meth:`signal`, :py:meth:`query`, and - :py:meth:`update` calls if present to ensure the signal/query/update - happen on this exact run. + """If present, run ID used to ensure that requested operations apply + to this exact run. This is only created via :py:meth:`Client.get_workflow_handle`. :py:meth:`Client.start_workflow` will not set this value. @@ -1243,8 +1545,7 @@ def result_run_id(self) -> Optional[str]: @property def first_execution_run_id(self) -> Optional[str]: - """Run ID used for :py:meth:`cancel` and :py:meth:`terminate` calls if - present to ensure the cancel and terminate happen for a workflow ID + """Run ID used to ensure requested operations apply to a workflow ID started with this run ID. This can be set when using :py:meth:`Client.get_workflow_handle`. When @@ -1990,11 +2291,11 @@ async def start_update( This API is experimental Args: - update: Update function or name on the workflow. - arg: Single argument to the update. - wait_for_stage: Required stage to wait until returning. ADMITTED is - not currently supported. See https://docs.temporal.io/workflows#update - for more details. + update: Update function or name on the workflow. arg: Single argument to the + update. + wait_for_stage: Required stage to wait until returning: either ACCEPTED or + COMPLETED. ADMITTED is not currently supported. See + https://docs.temporal.io/workflows#update for more details. args: Multiple arguments to the update. Cannot be set if arg is. id: ID of the update. If not set, the default is a new UUID. result_type: For string updates, this can set the specific result @@ -2005,8 +2306,8 @@ async def start_update( Raises: WorkflowUpdateRPCTimeoutOrCancelledError: This update call timed out - or was cancelled. This doesn't mean the update itself was timed - out or cancelled. + or was cancelled. This doesn't mean the update itself was timed out or + cancelled. 
RPCError: There was some issue sending the update to the workflow. """ return await self._start_update( @@ -2115,7 +2416,6 @@ def get_update_handle_for( id: Update ID to get a handle to. workflow_run_id: Run ID to tie the handle to. If this is not set, the :py:attr:`run_id` will be used. - result_type: The result type to deserialize into if known. Returns: The update handle. @@ -2125,37 +2425,250 @@ def get_update_handle_for( ) -@dataclass(frozen=True) -class AsyncActivityIDReference: - """Reference to an async activity by its qualified ID.""" - - workflow_id: str - run_id: Optional[str] - activity_id: str +class WithStartWorkflowOperation(Generic[SelfType, ReturnType]): + """Defines a start-workflow operation used by update-with-start requests. + Update-With-Start allows you to send an update to a workflow, while starting the + workflow if necessary. -class AsyncActivityHandle: - """Handle representing an external activity for completion and heartbeat.""" + .. warning:: + This API is experimental + """ + # Overload for no-param workflow, with_start + @overload def __init__( - self, client: Client, id_or_token: Union[AsyncActivityIDReference, bytes] - ) -> None: - """Create an async activity handle.""" - self._client = client - self._id_or_token = id_or_token - - async def heartbeat( self, - *details: Any, - rpc_metadata: Mapping[str, str] = {}, - rpc_timeout: Optional[timedelta] = None, - ) -> None: - """Record a heartbeat for the activity. - - Args: - details: Details of the heartbeat. - rpc_metadata: Headers used on the RPC call. Keys here override - client-level RPC metadata keys. 
+ workflow: MethodAsyncNoParam[SelfType, ReturnType], + *, + id: str, + task_queue: str, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> None: ... + + # Overload for single-param workflow, with_start + @overload + def __init__( + self, + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> None: ... 
+ + # Overload for multi-param workflow, with_start + @overload + def __init__( + self, + workflow: Callable[ + Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] + ], + *, + args: Sequence[Any], + id: str, + task_queue: str, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> None: ... 
+ + # Overload for string-name workflow, with_start + @overload + def __init__( + self, + workflow: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + result_type: Optional[Type] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> None: ... 
+ + def __init__( + self, + workflow: Union[str, Callable[..., Awaitable[Any]]], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy, + result_type: Optional[Type] = None, + execution_timeout: Optional[timedelta] = None, + run_timeout: Optional[timedelta] = None, + task_timeout: Optional[timedelta] = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + retry_policy: Optional[temporalio.common.RetryPolicy] = None, + cron_schedule: str = "", + memo: Optional[Mapping[str, Any]] = None, + search_attributes: Optional[ + Union[ + temporalio.common.TypedSearchAttributes, + temporalio.common.SearchAttributes, + ] + ] = None, + static_summary: Optional[str] = None, + static_details: Optional[str] = None, + start_delay: Optional[timedelta] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + stack_level: int = 2, + ) -> None: + """Create a WithStartWorkflowOperation. + + .. warning:: + This API is experimental + + See :py:meth:`temporalio.client.Client.start_workflow` for documentation of the + arguments. 
+ """ + temporalio.common._warn_on_deprecated_search_attributes( + search_attributes, stack_level=stack_level + ) + name, result_type_from_run_fn = ( + temporalio.workflow._Definition.get_name_and_result_type(workflow) + ) + if id_conflict_policy == temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED: + raise ValueError("WorkflowIDConflictPolicy is required") + + self._start_workflow_input = UpdateWithStartStartWorkflowInput( + workflow=name, + args=temporalio.common._arg_or_args(arg, args), + id=id, + task_queue=task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + headers={}, + ret_type=result_type or result_type_from_run_fn, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + self._workflow_handle: Future[WorkflowHandle[SelfType, ReturnType]] = Future() + + async def workflow_handle(self) -> WorkflowHandle[SelfType, ReturnType]: + """Wait until workflow is running and return a WorkflowHandle. + + .. 
warning:: + This API is experimental + """ + return await self._workflow_handle + + +@dataclass(frozen=True) +class AsyncActivityIDReference: + """Reference to an async activity by its qualified ID.""" + + workflow_id: str + run_id: Optional[str] + activity_id: str + + +class AsyncActivityHandle: + """Handle representing an external activity for completion and heartbeat.""" + + def __init__( + self, client: Client, id_or_token: Union[AsyncActivityIDReference, bytes] + ) -> None: + """Create an async activity handle.""" + self._client = client + self._id_or_token = id_or_token + + async def heartbeat( + self, + *details: Any, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> None: + """Record a heartbeat for the activity. + + Args: + details: Details of the heartbeat. + rpc_metadata: Headers used on the RPC call. Keys here override + client-level RPC metadata keys. rpc_timeout: Optional RPC deadline to set for the RPC call. """ await self._client._impl.heartbeat_async_activity( @@ -4335,18 +4848,19 @@ def __init__( *, workflow_run_id: Optional[str] = None, result_type: Optional[Type] = None, + known_outcome: Optional[temporalio.api.update.v1.Outcome] = None, ): """Create a workflow update handle. Users should not create this directly, but rather use - :py:meth:`Client.start_workflow_update`. + :py:meth:`WorkflowHandle.start_update` or :py:meth:`WorkflowHandle.get_update_handle`. """ self._client = client self._id = id self._workflow_id = workflow_id self._workflow_run_id = workflow_run_id self._result_type = result_type - self._known_outcome: Optional[temporalio.api.update.v1.Outcome] = None + self._known_outcome = known_outcome @property def id(self) -> str: @@ -4730,6 +5244,78 @@ class StartWorkflowUpdateInput: rpc_timeout: Optional[timedelta] +@dataclass +class UpdateWithStartUpdateWorkflowInput: + """Update input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`. + + .. 
warning:: + This API is experimental + """ + + update_id: Optional[str] + update: str + args: Sequence[Any] + wait_for_stage: WorkflowUpdateStage + headers: Mapping[str, temporalio.api.common.v1.Payload] + ret_type: Optional[Type] + rpc_metadata: Mapping[str, str] + rpc_timeout: Optional[timedelta] + + +@dataclass +class UpdateWithStartStartWorkflowInput: + """StartWorkflow input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`. + + .. warning:: + This API is experimental + """ + + # Similar to StartWorkflowInput but without e.g. run_id, start_signal, + # start_signal_args, request_eager_start. + + workflow: str + args: Sequence[Any] + id: str + task_queue: str + execution_timeout: Optional[timedelta] + run_timeout: Optional[timedelta] + task_timeout: Optional[timedelta] + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy + retry_policy: Optional[temporalio.common.RetryPolicy] + cron_schedule: str + memo: Optional[Mapping[str, Any]] + search_attributes: Optional[ + Union[ + temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes + ] + ] + start_delay: Optional[timedelta] + headers: Mapping[str, temporalio.api.common.v1.Payload] + static_summary: Optional[str] + static_details: Optional[str] + # Type may be absent + ret_type: Optional[Type] + rpc_metadata: Mapping[str, str] + rpc_timeout: Optional[timedelta] + + +@dataclass +class StartWorkflowUpdateWithStartInput: + """Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`. + + .. 
warning:: + This API is experimental + """ + + start_workflow_input: UpdateWithStartStartWorkflowInput + update_workflow_input: UpdateWithStartUpdateWorkflowInput + _on_start: Callable[ + [temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None + ] + _on_start_error: Callable[[BaseException], None] + + @dataclass class HeartbeatAsyncActivityInput: """Input for :py:meth:`OutboundInterceptor.heartbeat_async_activity`.""" @@ -4988,9 +5574,19 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: async def start_workflow_update( self, input: StartWorkflowUpdateInput ) -> WorkflowUpdateHandle[Any]: - """Called for every :py:meth:`WorkflowHandle.update` and :py:meth:`WorkflowHandle.start_update` call.""" + """Called for every :py:meth:`WorkflowHandle.start_update` and :py:meth:`WorkflowHandle.execute_update` call.""" return await self.next.start_workflow_update(input) + async def start_update_with_start_workflow( + self, input: StartWorkflowUpdateWithStartInput + ) -> WorkflowUpdateHandle[Any]: + """Called for every :py:meth:`Client.start_update_with_start_workflow` and :py:meth:`Client.execute_update_with_start_workflow` call. + + .. 
warning:: + This API is experimental + """ + return await self.next.start_update_with_start_workflow(input) + ### Async activity calls async def heartbeat_async_activity( @@ -5082,23 +5678,98 @@ def __init__(self, client: Client) -> None: async def start_workflow( self, input: StartWorkflowInput ) -> WorkflowHandle[Any, Any]: - # Build request req: Union[ temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest, temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest, ] if input.start_signal is not None: - req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( - signal_name=input.start_signal - ) - if input.start_signal_args: - req.signal_input.payloads.extend( - await self._client.data_converter.encode(input.start_signal_args) - ) + req = await self._build_signal_with_start_workflow_execution_request(input) else: - req = temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest() - req.request_eager_execution = input.request_eager_start + req = await self._build_start_workflow_execution_request(input) + + resp: Union[ + temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse, + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse, + ] + first_execution_run_id = None + eagerly_started = False + try: + if isinstance( + req, + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest, + ): + resp = await self._client.workflow_service.signal_with_start_workflow_execution( + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + else: + resp = await self._client.workflow_service.start_workflow_execution( + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + first_execution_run_id = resp.run_id + eagerly_started = resp.HasField("eager_workflow_task") + except RPCError as err: + # If the status is ALREADY_EXISTS and the details can be extracted + # as already started, use a different exception + if 
err.status == RPCStatusCode.ALREADY_EXISTS and err.grpc_status.details: + details = temporalio.api.errordetails.v1.WorkflowExecutionAlreadyStartedFailure() + if err.grpc_status.details[0].Unpack(details): + raise temporalio.exceptions.WorkflowAlreadyStartedError( + input.id, input.workflow, run_id=details.run_id + ) + raise + handle: WorkflowHandle[Any, Any] = WorkflowHandle( + self._client, + req.workflow_id, + result_run_id=resp.run_id, + first_execution_run_id=first_execution_run_id, + result_type=input.ret_type, + ) + setattr(handle, "__temporal_eagerly_started", eagerly_started) + return handle + + async def _build_start_workflow_execution_request( + self, input: StartWorkflowInput + ) -> temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest: + req = temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest() + req.request_eager_execution = input.request_eager_start + await self._populate_start_workflow_execution_request(req, input) + return req + async def _build_signal_with_start_workflow_execution_request( + self, input: StartWorkflowInput + ) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest: + assert input.start_signal + req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( + signal_name=input.start_signal + ) + if input.start_signal_args: + req.signal_input.payloads.extend( + await self._client.data_converter.encode(input.start_signal_args) + ) + await self._populate_start_workflow_execution_request(req, input) + return req + + async def _build_update_with_start_start_workflow_execution_request( + self, input: UpdateWithStartStartWorkflowInput + ) -> temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest: + req = temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest() + await self._populate_start_workflow_execution_request(req, input) + return req + + async def _populate_start_workflow_execution_request( + self, + req: Union[ + 
temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest, + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest, + ], + input: Union[StartWorkflowInput, UpdateWithStartStartWorkflowInput], + ) -> None: req.namespace = self._client.namespace req.workflow_id = input.id req.workflow_type.name = input.workflow @@ -5145,53 +5816,6 @@ async def start_workflow( if input.headers is not None: temporalio.common._apply_headers(input.headers, req.header.fields) - # Start with signal or just normal start - resp: Union[ - temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse, - temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse, - ] - first_execution_run_id = None - eagerly_started = False - try: - if isinstance( - req, - temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest, - ): - resp = await self._client.workflow_service.signal_with_start_workflow_execution( - req, - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) - else: - resp = await self._client.workflow_service.start_workflow_execution( - req, - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) - first_execution_run_id = resp.run_id - eagerly_started = resp.HasField("eager_workflow_task") - except RPCError as err: - # If the status is ALREADY_EXISTS and the details can be extracted - # as already started, use a different exception - if err.status == RPCStatusCode.ALREADY_EXISTS and err.grpc_status.details: - details = temporalio.api.errordetails.v1.WorkflowExecutionAlreadyStartedFailure() - if err.grpc_status.details[0].Unpack(details): - raise temporalio.exceptions.WorkflowAlreadyStartedError( - input.id, input.workflow, run_id=details.run_id - ) - raise - handle: WorkflowHandle[Any, Any] = WorkflowHandle( - self._client, - req.workflow_id, - result_run_id=resp.run_id, - first_execution_run_id=first_execution_run_id, - result_type=input.ret_type, - ) - setattr(handle, 
"__temporal_eagerly_started", eagerly_started) - return handle - async def cancel_workflow(self, input: CancelWorkflowInput) -> None: await self._client.workflow_service.request_cancel_workflow_execution( temporalio.api.workflowservice.v1.RequestCancelWorkflowExecutionRequest( @@ -5345,14 +5969,70 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: async def start_workflow_update( self, input: StartWorkflowUpdateInput ) -> WorkflowUpdateHandle[Any]: - # Build request + workflow_id = input.id + req = await self._build_update_workflow_execution_request(input, workflow_id) + + # Repeatedly try to invoke UpdateWorkflowExecution until the update is durable. + resp: temporalio.api.workflowservice.v1.UpdateWorkflowExecutionResponse + while True: + try: + resp = await self._client.workflow_service.update_workflow_execution( + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + except RPCError as err: + if ( + err.status == RPCStatusCode.DEADLINE_EXCEEDED + or err.status == RPCStatusCode.CANCELLED + ): + raise WorkflowUpdateRPCTimeoutOrCancelledError() from err + else: + raise + except asyncio.CancelledError as err: + raise WorkflowUpdateRPCTimeoutOrCancelledError() from err + if ( + resp.stage + >= temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ACCEPTED + ): + break + + # Build the handle. If the user's wait stage is COMPLETED, make sure we + # poll for result. 
+ handle: WorkflowUpdateHandle[Any] = WorkflowUpdateHandle( + client=self._client, + id=req.request.meta.update_id, + workflow_id=workflow_id, + workflow_run_id=resp.update_ref.workflow_execution.run_id, + result_type=input.ret_type, + ) + if resp.HasField("outcome"): + handle._known_outcome = resp.outcome + if input.wait_for_stage == WorkflowUpdateStage.COMPLETED: + await handle._poll_until_outcome() + return handle + + async def _build_update_workflow_execution_request( + self, + input: Union[StartWorkflowUpdateInput, UpdateWithStartUpdateWorkflowInput], + workflow_id: str, + ) -> temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest: + run_id, first_execution_run_id = ( + ( + input.run_id, + input.first_execution_run_id, + ) + if isinstance(input, StartWorkflowUpdateInput) + else (None, None) + ) req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest( namespace=self._client.namespace, workflow_execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=input.id, - run_id=input.run_id or "", + workflow_id=workflow_id, + run_id=run_id or "", ), - first_execution_run_id=input.first_execution_run_id or "", + first_execution_run_id=first_execution_run_id or "", request=temporalio.api.update.v1.Request( meta=temporalio.api.update.v1.Meta( update_id=input.update_id or str(uuid.uuid4()), @@ -5376,49 +6056,137 @@ async def start_workflow_update( temporalio.common._apply_headers( input.headers, req.request.input.header.fields ) + return req - # Repeatedly try to invoke start until the update reaches user-provided - # wait stage or is at least ACCEPTED (as of the time of this writing, - # the user cannot specify sooner than ACCEPTED) - resp: temporalio.api.workflowservice.v1.UpdateWorkflowExecutionResponse - while True: - try: - resp = await self._client.workflow_service.update_workflow_execution( - req, - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) - except RPCError as err: - if ( - err.status == 
RPCStatusCode.DEADLINE_EXCEEDED - or err.status == RPCStatusCode.CANCELLED - ): - raise WorkflowUpdateRPCTimeoutOrCancelledError() from err - else: - raise - except asyncio.CancelledError as err: + async def start_update_with_start_workflow( + self, input: StartWorkflowUpdateWithStartInput + ) -> WorkflowUpdateHandle[Any]: + seen_start = False + + def on_start( + start_response: temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse, + ): + nonlocal seen_start + if not seen_start: + input._on_start(start_response) + seen_start = True + + err: Optional[BaseException] = None + + try: + return await self._start_workflow_update_with_start( + input.start_workflow_input, input.update_workflow_input, on_start + ) + except asyncio.CancelledError as _err: + err = _err + raise WorkflowUpdateRPCTimeoutOrCancelledError() from err + except RPCError as _err: + err = _err + if err.status in [ + RPCStatusCode.DEADLINE_EXCEEDED, + RPCStatusCode.CANCELLED, + ]: raise WorkflowUpdateRPCTimeoutOrCancelledError() from err + else: + multiop_failure = ( + temporalio.api.errordetails.v1.MultiOperationExecutionFailure() + ) + if err.grpc_status.details[0].Unpack(multiop_failure): + status = next( + ( + st + for st in multiop_failure.statuses + if ( + st.code != RPCStatusCode.OK + and not ( + st.details + and st.details[0].Is( + temporalio.api.failure.v1.MultiOperationExecutionAborted.DESCRIPTOR + ) + ) + ) + ), + None, + ) + if status and status.code in list(RPCStatusCode): + if ( + status.code == RPCStatusCode.ALREADY_EXISTS + and status.details + ): + details = temporalio.api.errordetails.v1.WorkflowExecutionAlreadyStartedFailure() + if status.details[0].Unpack(details): + err = temporalio.exceptions.WorkflowAlreadyStartedError( + input.start_workflow_input.id, + input.start_workflow_input.workflow, + run_id=details.run_id, + ) + else: + err = RPCError( + status.message, + RPCStatusCode(status.code), + err.raw_grpc_status, + ) + + raise err + finally: + if err and not 
seen_start: + input._on_start_error(err) + + async def _start_workflow_update_with_start( + self, + start_input: UpdateWithStartStartWorkflowInput, + update_input: UpdateWithStartUpdateWorkflowInput, + on_start: Callable[ + [temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None + ], + ) -> WorkflowUpdateHandle[Any]: + start_req = ( + await self._build_update_with_start_start_workflow_execution_request( + start_input + ) + ) + update_req = await self._build_update_workflow_execution_request( + update_input, workflow_id=start_input.id + ) + multiop_req = temporalio.api.workflowservice.v1.ExecuteMultiOperationRequest( + namespace=self._client.namespace, + operations=[ + temporalio.api.workflowservice.v1.ExecuteMultiOperationRequest.Operation( + start_workflow=start_req + ), + temporalio.api.workflowservice.v1.ExecuteMultiOperationRequest.Operation( + update_workflow=update_req + ), + ], + ) + + # Repeatedly try to invoke ExecuteMultiOperation until the update is durable + while True: + multiop_response = ( + await self._client.workflow_service.execute_multi_operation(multiop_req) + ) + start_response = multiop_response.responses[0].start_workflow + update_response = multiop_response.responses[1].update_workflow + on_start(start_response) + known_outcome = ( + update_response.outcome if update_response.HasField("outcome") else None + ) if ( - resp.stage >= req.wait_policy.lifecycle_stage - or resp.stage + update_response.stage >= temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ACCEPTED ): break - # Build the handle. If the user's wait stage is COMPLETED, make sure we - # poll for result. 
handle: WorkflowUpdateHandle[Any] = WorkflowUpdateHandle( client=self._client, - id=req.request.meta.update_id, - workflow_id=input.id, - workflow_run_id=resp.update_ref.workflow_execution.run_id, - result_type=input.ret_type, + id=update_req.request.meta.update_id, + workflow_id=start_input.id, + workflow_run_id=start_response.run_id, + known_outcome=known_outcome, ) - if resp.HasField("outcome"): - handle._known_outcome = resp.outcome - if input.wait_for_stage == WorkflowUpdateStage.COMPLETED: + if update_input.wait_for_stage == WorkflowUpdateStage.COMPLETED: await handle._poll_until_outcome() + return handle ### Async activity calls diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 56aaab0f..c09cfbb5 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -442,7 +442,7 @@ class _DeadlockError(Exception): """Exception class for deadlocks. Contains functionality to swap the default traceback for another.""" def __init__(self, message: str, replacement_tb: Optional[TracebackType] = None): - """Create a new DeadlockError, with message `msg` and optionally a traceback `tb` to be swapped in later. + """Create a new DeadlockError, with message `message` and optionally a traceback `replacement_tb` to be swapped in later. Args: message: Message to be presented through exception. diff --git a/temporalio/workflow.py b/temporalio/workflow.py index dd52d49f..e78a9fee 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -1190,7 +1190,7 @@ async def wait_condition( fn: Non-async callback that accepts no parameters and returns a boolean. timeout: Optional number of seconds to wait until throwing :py:class:`asyncio.TimeoutError`. - timeout_summary: Optional simple string identifying the timer (created if `timeout` is + timeout_summary: Optional simple string identifying the timer (created if ``timeout`` is present) that may be visible in UI/CLI. 
While it can be normal text, it is best to treat as a timer ID. """ @@ -1313,7 +1313,7 @@ class LoggerAdapter(logging.LoggerAdapter): Values added to ``extra`` are merged with the ``extra`` dictionary from a logging call, with values from the logging call taking precedence. I.e. the - behavior is that of `merge_extra=True` in Python >= 3.13. + behavior is that of ``merge_extra=True`` in Python >= 3.13. """ def __init__( @@ -1426,6 +1426,20 @@ def must_from_run_fn(fn: Callable[..., Awaitable[Any]]) -> _Definition: f"Function {fn_name} missing attributes, was it decorated with @workflow.run and was its class decorated with @workflow.defn?" ) + @classmethod + def get_name_and_result_type( + cls, name_or_run_fn: Union[str, Callable[..., Awaitable[Any]]] + ) -> Tuple[str, Optional[Type]]: + if isinstance(name_or_run_fn, str): + return name_or_run_fn, None + elif callable(name_or_run_fn): + defn = cls.must_from_run_fn(name_or_run_fn) + if not defn.name: + raise ValueError("Cannot invoke dynamic workflow explicitly") + return defn.name, defn.ret_type + else: + raise TypeError("Workflow must be a string or callable") + @staticmethod def _apply_to_class( cls: Type, @@ -3940,7 +3954,7 @@ async def start_child_workflow( static_details: General fixed details for this child workflow execution that may appear in UI/CLI. This can be in Temporal markdown format and can span multiple lines. This is a fixed value on the workflow that cannot be updated. For details that can be - updated, use `Workflow.CurrentDetails` within the workflow. + updated, use :py:meth:`Workflow.get_current_details` within the workflow. Returns: A workflow handle to the started/existing workflow. 
diff --git a/tests/conftest.py b/tests/conftest.py index 3f4c5411..e8f8116c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,9 +47,10 @@ def pytest_addoption(parser): parser.addoption( + "-E", "--workflow-environment", default="local", - help="Which workflow environment to use ('local', 'time-skipping', or target to existing server)", + help="Which workflow environment to use ('local', 'time-skipping', or ip:port for existing server)", ) @@ -83,6 +84,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: f"limit.historyCount.suggestContinueAsNew={CONTINUE_AS_NEW_SUGGEST_HISTORY_COUNT}", "--dynamic-config-value", "system.enableEagerWorkflowStart=true", + "--dynamic-config-value", + "frontend.enableExecuteMultiOperation=true", ] ) elif env_type == "time-skipping": diff --git a/tests/worker/test_update_with_start.py b/tests/worker/test_update_with_start.py new file mode 100644 index 00000000..160590e3 --- /dev/null +++ b/tests/worker/test_update_with_start.py @@ -0,0 +1,517 @@ +from __future__ import annotations + +import uuid +from contextlib import contextmanager +from datetime import timedelta +from enum import Enum +from typing import Any, Iterator +from unittest.mock import patch + +import pytest + +from temporalio import activity, workflow +from temporalio.client import ( + Client, + Interceptor, + OutboundInterceptor, + StartWorkflowUpdateWithStartInput, + WithStartWorkflowOperation, + WorkflowUpdateFailedError, + WorkflowUpdateHandle, + WorkflowUpdateStage, +) +from temporalio.common import ( + WorkflowIDConflictPolicy, +) +from temporalio.exceptions import ApplicationError, WorkflowAlreadyStartedError +from temporalio.testing import WorkflowEnvironment +from tests.helpers import ( + new_worker, +) + + +@activity.defn +async def activity_called_by_update() -> None: + pass + + +@workflow.defn +class WorkflowForUpdateWithStartTest: + def __init__(self) -> None: + self.update_finished = False + self.update_may_exit = False + 
self.received_done_signal = False + + @workflow.run + async def run(self, i: int) -> str: + await workflow.wait_condition(lambda: self.received_done_signal) + return f"workflow-result-{i}" + + @workflow.update + def my_non_blocking_update(self, s: str) -> str: + if s == "fail-after-acceptance": + raise ApplicationError("Workflow deliberate failed update") + return f"update-result-{s}" + + @workflow.update + async def my_blocking_update(self, s: str) -> str: + if s == "fail-after-acceptance": + raise ApplicationError("Workflow deliberate failed update") + await workflow.execute_activity( + activity_called_by_update, start_to_close_timeout=timedelta(seconds=10) + ) + return f"update-result-{s}" + + @workflow.signal + async def done(self): + self.received_done_signal = True + + +class ExpectErrorWhenWorkflowExists(Enum): + YES = "yes" + NO = "no" + + +class UpdateHandlerType(Enum): + NON_BLOCKING = "non-blocking" + BLOCKING = "blocking" + + +class TestUpdateWithStart: + client: Client + workflow_id: str + task_queue: str + update_id = "test-uws-up-id" + + @pytest.mark.parametrize( + "wait_for_stage", + [WorkflowUpdateStage.ACCEPTED, WorkflowUpdateStage.COMPLETED], + ) + async def test_non_blocking_update_with_must_create_workflow_semantics( + self, + client: Client, + env: WorkflowEnvironment, + wait_for_stage: WorkflowUpdateStage, + ): + if env.supports_time_skipping: + pytest.skip( + "TODO: make update_with_start_tests pass under Java test server" + ) + await self._do_test( + client, + f"test-uws-nb-mc-wf-id-{wait_for_stage.name}", + UpdateHandlerType.NON_BLOCKING, + wait_for_stage, + WorkflowIDConflictPolicy.FAIL, + ExpectErrorWhenWorkflowExists.YES, + ) + + @pytest.mark.parametrize( + "wait_for_stage", + [WorkflowUpdateStage.ACCEPTED, WorkflowUpdateStage.COMPLETED], + ) + async def test_non_blocking_update_with_get_or_create_workflow_semantics( + self, + client: Client, + env: WorkflowEnvironment, + wait_for_stage: WorkflowUpdateStage, + ): + if 
env.supports_time_skipping: + pytest.skip( + "TODO: make update_with_start_tests pass under Java test server" + ) + await self._do_test( + client, + f"test-uws-nb-goc-wf-id-{wait_for_stage.name}", + UpdateHandlerType.NON_BLOCKING, + wait_for_stage, + WorkflowIDConflictPolicy.USE_EXISTING, + ExpectErrorWhenWorkflowExists.NO, + ) + + @pytest.mark.parametrize( + "wait_for_stage", + [WorkflowUpdateStage.ACCEPTED, WorkflowUpdateStage.COMPLETED], + ) + async def test_blocking_update_with_get_or_create_workflow_semantics( + self, + client: Client, + env: WorkflowEnvironment, + wait_for_stage: WorkflowUpdateStage, + ): + if env.supports_time_skipping: + pytest.skip( + "TODO: make update_with_start_tests pass under Java test server" + ) + await self._do_test( + client, + f"test-uws-b-goc-wf-id-{wait_for_stage.name}", + UpdateHandlerType.BLOCKING, + wait_for_stage, + WorkflowIDConflictPolicy.USE_EXISTING, + ExpectErrorWhenWorkflowExists.NO, + ) + + async def _do_test( + self, + client: Client, + workflow_id: str, + update_handler_type: UpdateHandlerType, + wait_for_stage: WorkflowUpdateStage, + id_conflict_policy: WorkflowIDConflictPolicy, + expect_error_when_workflow_exists: ExpectErrorWhenWorkflowExists, + ): + await self._do_execute_update_test( + client, + workflow_id + "-execute-update", + update_handler_type, + id_conflict_policy, + expect_error_when_workflow_exists, + ) + await self._do_start_update_test( + client, + workflow_id + "-start-update", + update_handler_type, + wait_for_stage, + id_conflict_policy, + ) + + async def _do_execute_update_test( + self, + client: Client, + workflow_id: str, + update_handler_type: UpdateHandlerType, + id_conflict_policy: WorkflowIDConflictPolicy, + expect_error_when_workflow_exists: ExpectErrorWhenWorkflowExists, + ): + update_handler = ( + WorkflowForUpdateWithStartTest.my_blocking_update + if update_handler_type == UpdateHandlerType.BLOCKING + else WorkflowForUpdateWithStartTest.my_non_blocking_update + ) + async with 
new_worker( + client, + WorkflowForUpdateWithStartTest, + activities=[activity_called_by_update], + ) as worker: + self.client = client + self.workflow_id = workflow_id + self.task_queue = worker.task_queue + + start_op_1 = WithStartWorkflowOperation( + WorkflowForUpdateWithStartTest.run, + 1, + id=self.workflow_id, + task_queue=self.task_queue, + id_conflict_policy=id_conflict_policy, + ) + + # First UWS succeeds + assert ( + await client.execute_update_with_start_workflow( + update_handler, "1", start_workflow_operation=start_op_1 + ) + == "update-result-1" + ) + assert ( + await start_op_1.workflow_handle() + ).first_execution_run_id is not None + + # Whether a repeat UWS succeeds depends on the workflow ID conflict policy + start_op_2 = WithStartWorkflowOperation( + WorkflowForUpdateWithStartTest.run, + 2, + id=self.workflow_id, + task_queue=self.task_queue, + id_conflict_policy=id_conflict_policy, + ) + + if expect_error_when_workflow_exists == ExpectErrorWhenWorkflowExists.NO: + assert ( + await client.execute_update_with_start_workflow( + update_handler, "21", start_workflow_operation=start_op_2 + ) + == "update-result-21" + ) + assert ( + await start_op_2.workflow_handle() + ).first_execution_run_id is not None + else: + for aw in [ + client.execute_update_with_start_workflow( + update_handler, "21", start_workflow_operation=start_op_2 + ), + start_op_2.workflow_handle(), + ]: + with pytest.raises(WorkflowAlreadyStartedError): + await aw + + assert ( + await start_op_1.workflow_handle() + ).first_execution_run_id is not None + + # The workflow is still running; finish it. 
+ + wf_handle_1 = await start_op_1.workflow_handle() + await wf_handle_1.signal(WorkflowForUpdateWithStartTest.done) + assert await wf_handle_1.result() == "workflow-result-1" + + async def _do_start_update_test( + self, + client: Client, + workflow_id: str, + update_handler_type: UpdateHandlerType, + wait_for_stage: WorkflowUpdateStage, + id_conflict_policy: WorkflowIDConflictPolicy, + ): + update_handler = ( + WorkflowForUpdateWithStartTest.my_blocking_update + if update_handler_type == UpdateHandlerType.BLOCKING + else WorkflowForUpdateWithStartTest.my_non_blocking_update + ) + async with new_worker( + client, + WorkflowForUpdateWithStartTest, + activities=[activity_called_by_update], + ) as worker: + self.client = client + self.workflow_id = workflow_id + self.task_queue = worker.task_queue + + start_op = WithStartWorkflowOperation( + WorkflowForUpdateWithStartTest.run, + 1, + id=self.workflow_id, + task_queue=self.task_queue, + id_conflict_policy=id_conflict_policy, + ) + + update_handle = await client.start_update_with_start_workflow( + update_handler, + "1", + wait_for_stage=wait_for_stage, + start_workflow_operation=start_op, + ) + assert await update_handle.result() == "update-result-1" + + @contextmanager + def assert_network_call( + self, + expect_network_call: bool, + ) -> Iterator[None]: + with patch.object( + self.client.workflow_service, + "poll_workflow_execution_update", + wraps=self.client.workflow_service.poll_workflow_execution_update, + ) as _wrapped_poll: + yield + assert _wrapped_poll.called == expect_network_call + + +async def test_update_with_start_sets_first_execution_run_id( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("TODO: make update_with_start_tests pass under Java test server") + async with new_worker( + client, + WorkflowForUpdateWithStartTest, + activities=[activity_called_by_update], + ) as worker: + + def make_start_op(workflow_id: str): + return WithStartWorkflowOperation( + 
WorkflowForUpdateWithStartTest.run, + 0, + id=workflow_id, + id_conflict_policy=WorkflowIDConflictPolicy.FAIL, + task_queue=worker.task_queue, + ) + + # conflict policy is FAIL + # First UWS succeeds and sets the first execution run ID + start_op_1 = make_start_op("wid-1") + update_handle_1 = await client.start_update_with_start_workflow( + WorkflowForUpdateWithStartTest.my_non_blocking_update, + "1", + wait_for_stage=WorkflowUpdateStage.COMPLETED, + start_workflow_operation=start_op_1, + ) + assert (await start_op_1.workflow_handle()).first_execution_run_id is not None + assert await update_handle_1.result() == "update-result-1" + + # Second UWS start fails because the workflow already exists + # first execution run ID is not set on the second UWS handle + start_op_2 = make_start_op("wid-1") + + for aw in [ + client.start_update_with_start_workflow( + WorkflowForUpdateWithStartTest.my_non_blocking_update, + "2", + wait_for_stage=WorkflowUpdateStage.COMPLETED, + start_workflow_operation=start_op_2, + ), + start_op_2.workflow_handle(), + ]: + with pytest.raises(WorkflowAlreadyStartedError): + await aw + + # Third UWS start succeeds, but the update fails after acceptance + start_op_3 = make_start_op("wid-2") + update_handle_3 = await client.start_update_with_start_workflow( + WorkflowForUpdateWithStartTest.my_non_blocking_update, + "fail-after-acceptance", + wait_for_stage=WorkflowUpdateStage.COMPLETED, + start_workflow_operation=start_op_3, + ) + assert (await start_op_3.workflow_handle()).first_execution_run_id is not None + with pytest.raises(WorkflowUpdateFailedError): + await update_handle_3.result() + + # Despite the update failure, the first execution run ID is set on the with_start_request, + # and the handle can be used to obtain the workflow result. 
+        assert (await start_op_3.workflow_handle()).first_execution_run_id is not None
+        wf_handle_3 = await start_op_3.workflow_handle()
+        await wf_handle_3.signal(WorkflowForUpdateWithStartTest.done)
+        assert await wf_handle_3.result() == "workflow-result-0"
+
+        # Fourth UWS is same as third, but we use execute_update instead of start_update.
+        start_op_4 = make_start_op("wid-3")
+        with pytest.raises(WorkflowUpdateFailedError):
+            await client.execute_update_with_start_workflow(
+                WorkflowForUpdateWithStartTest.my_non_blocking_update,
+                "fail-after-acceptance",
+                start_workflow_operation=start_op_4,
+            )
+        assert (await start_op_4.workflow_handle()).first_execution_run_id is not None
+
+
+async def test_update_with_start_failure_start_workflow_error(
+    client: Client, env: WorkflowEnvironment
+):
+    """
+    When the workflow start fails, the update-with-start call should raise the
+    appropriate gRPC error, and the start_workflow_operation future should fail
+    with the same error.
+    """
+    if env.supports_time_skipping:
+        pytest.skip("TODO: make update_with_start_tests pass under Java test server")
+    async with new_worker(
+        client,
+        WorkflowForUpdateWithStartTest,
+    ) as worker:
+
+        def make_start_op(workflow_id: str):
+            return WithStartWorkflowOperation(
+                WorkflowForUpdateWithStartTest.run,
+                0,
+                id=workflow_id,
+                id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
+                task_queue=worker.task_queue,
+            )
+
+        wid = f"wf-{uuid.uuid4()}"
+        start_op_1 = make_start_op(wid)
+        await client.start_update_with_start_workflow(
+            WorkflowForUpdateWithStartTest.my_non_blocking_update,
+            "1",
+            wait_for_stage=WorkflowUpdateStage.COMPLETED,
+            start_workflow_operation=start_op_1,
+        )
+
+        start_op_2 = make_start_op(wid)
+
+        for aw in [
+            client.start_update_with_start_workflow(
+                WorkflowForUpdateWithStartTest.my_non_blocking_update,
+                "2",
+                wait_for_stage=WorkflowUpdateStage.COMPLETED,
+                start_workflow_operation=start_op_2,
+            ),
+            start_op_2.workflow_handle(),
+        ]:
+            with 
pytest.raises(WorkflowAlreadyStartedError): + await aw + + +class SimpleClientInterceptor(Interceptor): + def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: + return SimpleClientOutboundInterceptor(super().intercept_client(next)) + + +class SimpleClientOutboundInterceptor(OutboundInterceptor): + def __init__(self, next: OutboundInterceptor) -> None: + super().__init__(next) + + async def start_update_with_start_workflow( + self, input: StartWorkflowUpdateWithStartInput + ) -> WorkflowUpdateHandle[Any]: + input.start_workflow_input.args = ["intercepted-workflow-arg"] + input.update_workflow_input.args = ["intercepted-update-arg"] + return await super().start_update_with_start_workflow(input) + + +@workflow.defn +class UpdateWithStartInterceptorWorkflow: + def __init__(self) -> None: + self.received_update = False + + @workflow.run + async def run(self, arg: str) -> str: + await workflow.wait_condition(lambda: self.received_update) + return arg + + @workflow.update + async def my_update(self, arg: str) -> str: + self.received_update = True + await workflow.wait_condition(lambda: self.received_update) + return arg + + +async def test_update_with_start_client_outbound_interceptor( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("TODO: make update_with_start_tests pass under Java test server") + interceptor = SimpleClientInterceptor() + client = Client(**{**client.config(), "interceptors": [interceptor]}) # type: ignore + + async with new_worker( + client, + UpdateWithStartInterceptorWorkflow, + ) as worker: + start_op = WithStartWorkflowOperation( + UpdateWithStartInterceptorWorkflow.run, + "original-workflow-arg", + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + id_conflict_policy=WorkflowIDConflictPolicy.FAIL, + ) + update_result = await client.execute_update_with_start_workflow( + UpdateWithStartInterceptorWorkflow.my_update, + "original-update-arg", + 
start_workflow_operation=start_op, + ) + assert update_result == "intercepted-update-arg" + + wf_handle = await start_op.workflow_handle() + assert await wf_handle.result() == "intercepted-workflow-arg" + + +def test_with_start_workflow_operation_requires_conflict_policy(): + with pytest.raises(ValueError): + WithStartWorkflowOperation( + WorkflowForUpdateWithStartTest.run, + 0, + id="wid-1", + id_conflict_policy=WorkflowIDConflictPolicy.UNSPECIFIED, + task_queue="test-queue", + ) + + with pytest.raises(TypeError): + WithStartWorkflowOperation( # type: ignore + WorkflowForUpdateWithStartTest.run, + 0, + id="wid-1", + task_queue="test-queue", + )