Skip to content

Commit

Permalink
Prevent re-use of update-with-start WithStartWorkflowOperation (#714)
Browse files Browse the repository at this point in the history
* Add test that WithStartWorkflowOperation cannot be reused
* RuntimeError if start_op is reused
  • Loading branch information
dandavison authored Dec 20, 2024
1 parent 702c868 commit 66e7650
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 31 deletions.
30 changes: 14 additions & 16 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,30 +1073,27 @@ async def _start_update_with_start(
) -> WorkflowUpdateHandle[Any]:
if wait_for_stage == WorkflowUpdateStage.ADMITTED:
raise ValueError("ADMITTED wait stage not supported")
update_name: str
ret_type = result_type
if isinstance(update, temporalio.workflow.UpdateMethodMultiParam):
defn = update._defn
if not defn.name:
raise RuntimeError("Cannot invoke dynamic update definition")
# TODO(cretz): Check count/type of args at runtime?
update_name = defn.name
ret_type = defn.ret_type
else:
update_name = str(update)

if start_workflow_operation._used:
raise RuntimeError("WithStartWorkflowOperation cannot be reused")
start_workflow_operation._used = True

update_name, result_type_from_type_hint = (
temporalio.workflow._UpdateDefinition.get_name_and_result_type(update)
)

update_input = UpdateWithStartUpdateWorkflowInput(
update_id=id,
update=update_name,
args=temporalio.common._arg_or_args(arg, args),
headers={},
ret_type=ret_type,
ret_type=result_type or result_type_from_type_hint,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
wait_for_stage=wait_for_stage,
)

def on_start_success(
def on_start(
start_response: temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse,
):
start_workflow_operation._workflow_handle.set_result(
Expand All @@ -1109,16 +1106,16 @@ def on_start_success(
)
)

def on_start_failure(
def on_start_error(
error: BaseException,
):
start_workflow_operation._workflow_handle.set_exception(error)

input = StartWorkflowUpdateWithStartInput(
start_workflow_input=start_workflow_operation._start_workflow_input,
update_workflow_input=update_input,
_on_start=on_start_success,
_on_start_error=on_start_failure,
_on_start=on_start,
_on_start_error=on_start_error,
)

return await self._impl.start_update_with_start_workflow(input)
Expand Down Expand Up @@ -2621,6 +2618,7 @@ def __init__(
rpc_timeout=rpc_timeout,
)
self._workflow_handle: Future[WorkflowHandle[SelfType, ReturnType]] = Future()
self._used = False

async def workflow_handle(self) -> WorkflowHandle[SelfType, ReturnType]:
"""Wait until workflow is running and return a WorkflowHandle.
Expand Down
270 changes: 255 additions & 15 deletions tests/worker/test_update_with_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,30 @@ async def done(self):
self.received_done_signal = True


async def test_with_start_workflow_operation_cannot_be_reused(client: Client):
async with new_worker(client, WorkflowForUpdateWithStartTest) as worker:
start_op = WithStartWorkflowOperation(
WorkflowForUpdateWithStartTest.run,
0,
id=f"wid-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)

async def start_update_with_start(start_op: WithStartWorkflowOperation):
return await client.start_update_with_start_workflow(
WorkflowForUpdateWithStartTest.my_non_blocking_update,
"1",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=start_op,
)

await start_update_with_start(start_op)
with pytest.raises(RuntimeError) as exc_info:
await start_update_with_start(start_op)
assert "WithStartWorkflowOperation cannot be reused" in str(exc_info.value)


class ExpectErrorWhenWorkflowExists(Enum):
YES = "yes"
NO = "no"
Expand Down Expand Up @@ -387,7 +411,7 @@ def make_start_op(workflow_id: str):
assert (await start_op_4.workflow_handle()).first_execution_run_id is not None


async def test_update_with_start_failure_start_workflow_error(
async def test_update_with_start_workflow_already_started_error(
client: Client, env: WorkflowEnvironment
):
"""
Expand Down Expand Up @@ -520,13 +544,13 @@ def test_with_start_workflow_operation_requires_conflict_policy():

@dataclass
class DataClass1:
a: int
a: str
b: str


@dataclass
class DataClass2:
a: int
a: str
b: str


Expand All @@ -536,32 +560,248 @@ def __init__(self) -> None:
self.received_update = False

@workflow.run
async def run(self) -> DataClass1:
async def run(self, arg: str) -> DataClass1:
await workflow.wait_condition(lambda: self.received_update)
return DataClass1(a=1, b="workflow-result")
return DataClass1(a=arg, b="workflow-result")

@workflow.update
async def update(self) -> DataClass2:
async def my_update(self, arg: str) -> DataClass2:
self.received_update = True
return DataClass2(a=2, b="update-result")
return DataClass2(a=arg, b="update-result")


async def test_workflow_and_update_can_return_dataclass(client: Client):
async with new_worker(client, WorkflowCanReturnDataClass) as worker:
start_op = WithStartWorkflowOperation(
WorkflowCanReturnDataClass.run,
id=f"workflow-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,

def make_start_op(workflow_id: str):
return WithStartWorkflowOperation(
WorkflowCanReturnDataClass.run,
"workflow-arg",
id=workflow_id,
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)

# no-param update-function overload
start_op = make_start_op(f"wf-{uuid.uuid4()}")

update_handle = await client.start_update_with_start_workflow(
WorkflowCanReturnDataClass.my_update,
"update-arg",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=start_op,
)

assert await update_handle.result() == DataClass2(
a="update-arg", b="update-result"
)

wf_handle = await start_op.workflow_handle()
assert await wf_handle.result() == DataClass1(
a="workflow-arg", b="workflow-result"
)

# no-param update-string-name overload
start_op = make_start_op(f"wf-{uuid.uuid4()}")

update_handle = await client.start_update_with_start_workflow(
WorkflowCanReturnDataClass.update,
"my_update",
"update-arg",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=start_op,
result_type=DataClass2,
)

assert await update_handle.result() == DataClass2(a=2, b="update-result")
assert await update_handle.result() == DataClass2(
a="update-arg", b="update-result"
)

wf_handle = await start_op.workflow_handle()
assert await wf_handle.result() == DataClass1(a=1, b="workflow-result")
assert await wf_handle.result() == DataClass1(
a="workflow-arg", b="workflow-result"
)


@dataclass
class WorkflowResult:
result: str


@dataclass
class UpdateResult:
result: str


@workflow.defn
class NoParamWorkflow:
def __init__(self) -> None:
self.received_update = False

@workflow.run
async def my_workflow_run(self) -> WorkflowResult:
await workflow.wait_condition(lambda: self.received_update)
return WorkflowResult(result="workflow-result")

@workflow.update(name="my_update")
async def update(self) -> UpdateResult:
self.received_update = True
return UpdateResult(result="update-result")


@workflow.defn
class OneParamWorkflow:
def __init__(self) -> None:
self.received_update = False

@workflow.run
async def my_workflow_run(self, arg: str) -> WorkflowResult:
await workflow.wait_condition(lambda: self.received_update)
return WorkflowResult(result=arg)

@workflow.update(name="my_update")
async def update(self, arg: str) -> UpdateResult:
self.received_update = True
return UpdateResult(result=arg)


@workflow.defn
class TwoParamWorkflow:
def __init__(self) -> None:
self.received_update = False

@workflow.run
async def my_workflow_run(self, arg1: str, arg2: str) -> WorkflowResult:
await workflow.wait_condition(lambda: self.received_update)
return WorkflowResult(result=arg1 + "-" + arg2)

@workflow.update(name="my_update")
async def update(self, arg1: str, arg2: str) -> UpdateResult:
self.received_update = True
return UpdateResult(result=arg1 + "-" + arg2)


async def test_update_with_start_no_param(client: Client):
async with new_worker(client, NoParamWorkflow) as worker:
# No-params typed
no_param_start_op = WithStartWorkflowOperation(
NoParamWorkflow.my_workflow_run,
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)
update_handle = await client.start_update_with_start_workflow(
NoParamWorkflow.update,
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=no_param_start_op,
)
assert await update_handle.result() == UpdateResult(result="update-result")
wf_handle = await no_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(result="workflow-result")

# No-params string name
no_param_start_op = WithStartWorkflowOperation(
"NoParamWorkflow",
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
result_type=WorkflowResult,
)
update_handle = await client.start_update_with_start_workflow(
"my_update",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=no_param_start_op,
result_type=UpdateResult,
)
assert await update_handle.result() == UpdateResult(result="update-result")
wf_handle = await no_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(result="workflow-result")


async def test_update_with_start_one_param(client: Client):
async with new_worker(client, OneParamWorkflow) as worker:
# One-param typed
one_param_start_op = WithStartWorkflowOperation(
OneParamWorkflow.my_workflow_run,
"workflow-arg",
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)
update_handle = await client.start_update_with_start_workflow(
OneParamWorkflow.update,
"update-arg",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=one_param_start_op,
)
assert await update_handle.result() == UpdateResult(result="update-arg")
wf_handle = await one_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(result="workflow-arg")

# One-param string name
one_param_start_op = WithStartWorkflowOperation(
"OneParamWorkflow",
"workflow-arg",
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
result_type=WorkflowResult,
)
update_handle = await client.start_update_with_start_workflow(
"my_update",
"update-arg",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=one_param_start_op,
result_type=UpdateResult,
)
assert await update_handle.result() == UpdateResult(result="update-arg")
wf_handle = await one_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(result="workflow-arg")


async def test_update_with_start_two_param(client: Client):
async with new_worker(client, TwoParamWorkflow) as worker:
# Two-params typed
two_param_start_op = WithStartWorkflowOperation(
TwoParamWorkflow.my_workflow_run,
args=("workflow-arg1", "workflow-arg2"),
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)
update_handle = await client.start_update_with_start_workflow(
TwoParamWorkflow.update,
args=("update-arg1", "update-arg2"),
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=two_param_start_op,
)
assert await update_handle.result() == UpdateResult(
result="update-arg1-update-arg2"
)
wf_handle = await two_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(
result="workflow-arg1-workflow-arg2"
)

# Two-params string name
two_param_start_op = WithStartWorkflowOperation(
"TwoParamWorkflow",
args=("workflow-arg1", "workflow-arg2"),
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
result_type=WorkflowResult,
)
update_handle = await client.start_update_with_start_workflow(
"my_update",
args=("update-arg1", "update-arg2"),
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=two_param_start_op,
result_type=UpdateResult,
)
assert await update_handle.result() == UpdateResult(
result="update-arg1-update-arg2"
)
wf_handle = await two_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(
result="workflow-arg1-workflow-arg2"
)

0 comments on commit 66e7650

Please sign in to comment.