Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent re-use of update-with-start WithStartWorkflowOperation #714

Merged
merged 10 commits into from
Dec 20, 2024
Merged
14 changes: 10 additions & 4 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,11 @@ async def _start_update_with_start(
) -> WorkflowUpdateHandle[Any]:
if wait_for_stage == WorkflowUpdateStage.ADMITTED:
raise ValueError("ADMITTED wait stage not supported")

if start_workflow_operation._used:
raise RuntimeError("WithStartWorkflowOperation cannot be reused")
start_workflow_operation._used = True

update_name: str
ret_type = result_type
if isinstance(update, temporalio.workflow.UpdateMethodMultiParam):
Expand All @@ -1096,7 +1101,7 @@ async def _start_update_with_start(
wait_for_stage=wait_for_stage,
)

def on_start_success(
def on_start(
start_response: temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse,
):
start_workflow_operation._workflow_handle.set_result(
Expand All @@ -1109,16 +1114,16 @@ def on_start_success(
)
)

def on_start_failure(
def on_start_error(
error: BaseException,
):
start_workflow_operation._workflow_handle.set_exception(error)

input = StartWorkflowUpdateWithStartInput(
start_workflow_input=start_workflow_operation._start_workflow_input,
update_workflow_input=update_input,
_on_start=on_start_success,
_on_start_error=on_start_failure,
_on_start=on_start,
_on_start_error=on_start_error,
)

return await self._impl.start_update_with_start_workflow(input)
Expand Down Expand Up @@ -2621,6 +2626,7 @@ def __init__(
rpc_timeout=rpc_timeout,
)
self._workflow_handle: Future[WorkflowHandle[SelfType, ReturnType]] = Future()
self._used = False

async def workflow_handle(self) -> WorkflowHandle[SelfType, ReturnType]:
"""Wait until workflow is running and return a WorkflowHandle.
Expand Down
269 changes: 254 additions & 15 deletions tests/worker/test_update_with_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,30 @@ async def done(self):
self.received_done_signal = True


async def test_with_start_workflow_operation_cannot_be_reused(client: Client):
async with new_worker(client, WorkflowForUpdateWithStartTest) as worker:
start_op = WithStartWorkflowOperation(
WorkflowForUpdateWithStartTest.run,
0,
id=f"wid-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)

async def start_update_with_start(start_op: WithStartWorkflowOperation):
return await client.start_update_with_start_workflow(
WorkflowForUpdateWithStartTest.my_non_blocking_update,
"1",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=start_op,
)

await start_update_with_start(start_op)
with pytest.raises(RuntimeError) as exc_info:
await start_update_with_start(start_op)
assert "WithStartWorkflowOperation cannot be reused" in str(exc_info.value)


class ExpectErrorWhenWorkflowExists(Enum):
YES = "yes"
NO = "no"
Expand Down Expand Up @@ -387,7 +411,7 @@ def make_start_op(workflow_id: str):
assert (await start_op_4.workflow_handle()).first_execution_run_id is not None


async def test_update_with_start_failure_start_workflow_error(
async def test_update_with_start_workflow_already_started_error(
client: Client, env: WorkflowEnvironment
):
"""
Expand Down Expand Up @@ -520,13 +544,13 @@ def test_with_start_workflow_operation_requires_conflict_policy():

@dataclass
class DataClass1:
a: int
a: str
b: str


@dataclass
class DataClass2:
a: int
a: str
b: str


Expand All @@ -536,32 +560,247 @@ def __init__(self) -> None:
self.received_update = False

@workflow.run
async def run(self) -> DataClass1:
async def run(self, arg: str) -> DataClass1:
await workflow.wait_condition(lambda: self.received_update)
return DataClass1(a=1, b="workflow-result")
return DataClass1(a=arg, b="workflow-result")

@workflow.update
async def update(self) -> DataClass2:
async def my_update(self, arg: str) -> DataClass2:
self.received_update = True
return DataClass2(a=2, b="update-result")
return DataClass2(a=arg, b="update-result")


async def test_workflow_and_update_can_return_dataclass(client: Client):
async with new_worker(client, WorkflowCanReturnDataClass) as worker:
start_op = WithStartWorkflowOperation(
WorkflowCanReturnDataClass.run,
id=f"workflow-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,

def make_start_op(workflow_id: str):
return WithStartWorkflowOperation(
WorkflowCanReturnDataClass.run,
"workflow-arg",
id=workflow_id,
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)

# no-param update-function overload
start_op = make_start_op(f"wf-{uuid.uuid4()}")

update_handle = await client.start_update_with_start_workflow(
WorkflowCanReturnDataClass.my_update,
"update-arg",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=start_op,
)

assert await update_handle.result() == DataClass2(
a="update-arg", b="update-result"
)

wf_handle = await start_op.workflow_handle()
assert await wf_handle.result() == DataClass1(
a="workflow-arg", b="workflow-result"
)

# no-param update-string-name overload
start_op = make_start_op(f"wf-{uuid.uuid4()}")

update_handle = await client.start_update_with_start_workflow(
WorkflowCanReturnDataClass.update,
"my_update",
"update-arg",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=start_op,
result_type=DataClass2,
)

assert await update_handle.result() == DataClass2(a=2, b="update-result")
assert await update_handle.result() == DataClass2(
a="update-arg", b="update-result"
)

wf_handle = await start_op.workflow_handle()
assert await wf_handle.result() == DataClass1(a=1, b="workflow-result")
assert await wf_handle.result() == DataClass1(
a="workflow-arg", b="workflow-result"
)


@dataclass
class WorkflowResult:
result: str


@dataclass
class UpdateResult:
result: str


@workflow.defn
class NoParamWorkflow:
def __init__(self) -> None:
self.received_update = False

@workflow.run
async def my_workflow_run(self) -> WorkflowResult:
await workflow.wait_condition(lambda: self.received_update)
return WorkflowResult(result="workflow-result")

@workflow.update(name="my_update")
async def update(self) -> UpdateResult:
self.received_update = True
return UpdateResult(result="update-result")


@workflow.defn
class OneParamWorkflow:
def __init__(self) -> None:
self.received_update = False

@workflow.run
async def my_workflow_run(self, arg: str) -> WorkflowResult:
await workflow.wait_condition(lambda: self.received_update)
return WorkflowResult(result=arg)

@workflow.update(name="my_update")
async def update(self, arg: str) -> UpdateResult:
self.received_update = True
return UpdateResult(result=arg)


@workflow.defn
class TwoParamWorkflow:
def __init__(self) -> None:
self.received_update = False

@workflow.run
async def my_workflow_run(self, arg1: str, arg2: str) -> WorkflowResult:
await workflow.wait_condition(lambda: self.received_update)
return WorkflowResult(result=arg1 + "-" + arg2)

@workflow.update(name="my_update")
async def update(self, arg1: str, arg2: str) -> UpdateResult:
self.received_update = True
return UpdateResult(result=arg1 + "-" + arg2)


async def test_update_with_start_overloads(client: Client):
async with new_worker(
client,
NoParamWorkflow,
OneParamWorkflow,
TwoParamWorkflow,
) as worker:
# No-params typed
no_param_start_op = WithStartWorkflowOperation(
NoParamWorkflow.my_workflow_run,
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)
update_handle = await client.start_update_with_start_workflow(
NoParamWorkflow.update,
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=no_param_start_op,
)
assert await update_handle.result() == UpdateResult(result="update-result")
wf_handle = await no_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(result="workflow-result")

# No-params string name
no_param_start_op = WithStartWorkflowOperation(
"NoParamWorkflow",
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
result_type=WorkflowResult,
)
update_handle = await client.start_update_with_start_workflow(
"my_update",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=no_param_start_op,
result_type=UpdateResult,
)
assert await update_handle.result() == UpdateResult(result="update-result")
wf_handle = await no_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(result="workflow-result")

# One-param typed
one_param_start_op = WithStartWorkflowOperation(
OneParamWorkflow.my_workflow_run,
"workflow-arg",
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)
update_handle = await client.start_update_with_start_workflow(
OneParamWorkflow.update,
"update-arg",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=one_param_start_op,
)
assert await update_handle.result() == UpdateResult(result="update-arg")
wf_handle = await one_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(result="workflow-arg")

# One-param string name
one_param_start_op = WithStartWorkflowOperation(
"OneParamWorkflow",
"workflow-arg",
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
result_type=WorkflowResult,
)
update_handle = await client.start_update_with_start_workflow(
"my_update",
"update-arg",
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=one_param_start_op,
result_type=UpdateResult,
)
assert await update_handle.result() == UpdateResult(result="update-arg")
wf_handle = await one_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(result="workflow-arg")

# Two-params typed
two_param_start_op = WithStartWorkflowOperation(
TwoParamWorkflow.my_workflow_run,
args=("workflow-arg1", "workflow-arg2"),
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)
update_handle = await client.start_update_with_start_workflow(
TwoParamWorkflow.update,
args=("update-arg1", "update-arg2"),
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=two_param_start_op,
)
assert await update_handle.result() == UpdateResult(
result="update-arg1-update-arg2"
)
wf_handle = await two_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(
result="workflow-arg1-workflow-arg2"
)

# Two-params string name
two_param_start_op = WithStartWorkflowOperation(
"TwoParamWorkflow",
args=("workflow-arg1", "workflow-arg2"),
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
result_type=WorkflowResult,
)
update_handle = await client.start_update_with_start_workflow(
"my_update",
args=("update-arg1", "update-arg2"),
wait_for_stage=WorkflowUpdateStage.COMPLETED,
start_workflow_operation=two_param_start_op,
result_type=UpdateResult,
)
assert await update_handle.result() == UpdateResult(
result="update-arg1-update-arg2"
)
wf_handle = await two_param_start_op.workflow_handle()
assert await wf_handle.result() == WorkflowResult(
result="workflow-arg1-workflow-arg2"
)
Loading