diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py
index 8c59992b..5f5dd155 100644
--- a/temporalio/worker/_workflow_instance.py
+++ b/temporalio/worker/_workflow_instance.py
@@ -208,6 +208,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
         self._worker_level_failure_exception_types = (
             det.worker_level_failure_exception_types
         )
+        self._primary_task_initter: Optional[Callable[[], asyncio.Task[None]]] = None
         self._primary_task: Optional[asyncio.Task[None]] = None
         self._time_ns = 0
         self._cancel_requested = False
@@ -356,39 +357,22 @@ def activate(
         self._current_thread_id = threading.get_ident()
         activation_err: Optional[Exception] = None
         try:
-            # Split into job sets with patches, then signals + updates, then
-            # non-queries, then queries
-            start_job = None
-            job_sets: List[
-                List[temporalio.bridge.proto.workflow_activation.WorkflowActivationJob]
-            ] = [[], [], [], []]
+            # Apply every job, running the loop afterward
+            is_query = False
             for job in act.jobs:
-                if job.HasField("notify_has_patch"):
-                    job_sets[0].append(job)
-                elif job.HasField("signal_workflow") or job.HasField("do_update"):
-                    job_sets[1].append(job)
-                elif not job.HasField("query_workflow"):
-                    if job.HasField("initialize_workflow"):
-                        start_job = job.initialize_workflow
-                    job_sets[2].append(job)
-                else:
-                    job_sets[3].append(job)
-
-            if start_job:
-                self._workflow_input = self._make_workflow_input(start_job)
-
-            # Apply every job set, running after each set
-            for index, job_set in enumerate(job_sets):
-                if not job_set:
-                    continue
-                for job in job_set:
-                    # Let errors bubble out of these to the caller to fail the task
-                    self._apply(job)
-
-                # Run one iteration of the loop. We do not allow conditions to
-                # be checked in patch jobs (first index) or query jobs (last
-                # index).
-                self._run_once(check_conditions=index == 1 or index == 2)
+                if job.HasField("initialize_workflow"):
+                    self._workflow_input = self._make_workflow_input(job.initialize_workflow)
+                # Let errors bubble out of these to the caller to fail the task
+                self._apply(job)
+                if job.HasField("query_workflow"):
+                    is_query = True
+
+            # Ensure the main loop is called, and called last, if needed
+            if self._primary_task_initter is not None and self._primary_task is None:
+                self._primary_task_initter()
+            # Conditions are not checked on query activations. Query activations always come without
+            # any other jobs.
+            self._run_once(check_conditions=not is_query)
         except Exception as err:
             # We want some errors during activation, like those that can happen
             # during payload conversion, to be able to fail the workflow not the
@@ -508,6 +492,15 @@ def _apply_cancel_workflow(
             # workflow the ability to receive the cancellation, so we must defer
             # this cancellation to the next iteration of the event loop.
             self.call_soon(self._primary_task.cancel)
+        elif self._primary_task_initter:
+            # If we're being cancelled before ever being started, we need to run the cancel
+            # after initialization
+            old_initter = self._primary_task_initter
+            def init_then_cancel():
+                old_initter()
+                self.call_soon(self._primary_task.cancel)
+            self._primary_task_initter = init_then_cancel
+

     def _apply_do_update(
         self, job: temporalio.bridge.proto.workflow_activation.DoUpdate
@@ -889,10 +882,12 @@ async def run_workflow(input: ExecuteWorkflowInput) -> None:
            raise RuntimeError(
                "Expected workflow input to be set. This is an SDK Python bug."
            )
-        self._primary_task = self.create_task(
-            self._run_top_level_workflow_function(run_workflow(self._workflow_input)),
-            name="run",
-        )
+        def primary_initter():
+            self._primary_task = self.create_task(
+                self._run_top_level_workflow_function(run_workflow(self._workflow_input)),
+                name="run",
+            )
+        self._primary_task_initter = primary_initter

     def _apply_update_random_seed(
         self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed
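
The pattern in this diff: instead of creating the primary workflow task the moment the initialize job is applied, the handler now only records a callback (`primary_initter`), and `activate` invokes it after every job in the activation has been applied, so the workflow function always starts last. A cancel that arrives before the task exists is handled by wrapping the recorded callback so the freshly created task is cancelled right after it is started. Below is a minimal, self-contained sketch of that deferral, not SDK code; `FakeScheduler`, `Job`, and `_workflow_main` are hypothetical names invented for the illustration.

```python
# Sketch only: defer primary-task creation until all activation jobs are
# applied, and wrap the deferred "initter" when a cancel arrives first.
import asyncio
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Job:
    kind: str  # e.g. "initialize", "signal", "cancel", "query"


class FakeScheduler:
    def __init__(self) -> None:
        self._primary_task: Optional[asyncio.Task] = None
        self._primary_task_initter: Optional[Callable[[], None]] = None

    async def _workflow_main(self) -> None:
        try:
            print("workflow started")
            await asyncio.sleep(3600)
        except asyncio.CancelledError:
            print("workflow cancelled")
            raise

    def _apply(self, job: Job) -> None:
        if job.kind == "initialize":
            # Defer task creation: only remember how to create the task.
            def primary_initter() -> None:
                self._primary_task = asyncio.create_task(self._workflow_main())

            self._primary_task_initter = primary_initter
        elif job.kind == "cancel":
            loop = asyncio.get_running_loop()
            if self._primary_task is not None:
                # Defer so the workflow gets a chance to observe the cancel.
                loop.call_soon(self._primary_task.cancel)
            elif self._primary_task_initter is not None:
                # Cancelled before ever being started: cancel right after the
                # task is finally created.
                old_initter = self._primary_task_initter

                def init_then_cancel() -> None:
                    old_initter()
                    assert self._primary_task is not None
                    loop.call_soon(self._primary_task.cancel)

                self._primary_task_initter = init_then_cancel

    async def activate(self, jobs: List[Job]) -> None:
        is_query = any(j.kind == "query" for j in jobs)
        for job in jobs:
            self._apply(job)
        # Start the workflow function last, after every job has been applied.
        if self._primary_task_initter is not None and self._primary_task is None:
            self._primary_task_initter()
        # Stand-in for running one iteration of the workflow event loop.
        if not is_query:
            await asyncio.sleep(0)


async def main() -> None:
    sched = FakeScheduler()
    # Initialize and cancel arrive in the same activation: the workflow task
    # is created last, then cancelled on the next loop iteration.
    await sched.activate([Job("initialize"), Job("cancel")])
    if sched._primary_task is not None:
        try:
            await sched._primary_task
        except asyncio.CancelledError:
            pass


asyncio.run(main())
```

Deferring the cancel via `call_soon` in the sketch mirrors the comment retained in `_apply_cancel_workflow`: the task gets one iteration to run so the workflow body can receive `asyncio.CancelledError` rather than being dropped before it ever starts.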