Skip to content

Commit 53e2f3d

Browse files
committed
Merge branch 'main' of github.com:cadence-workflow/cadence-python-client into WorkflowDefinition
2 parents 60b1444 + f09463a commit 53e2f3d

File tree

2 files changed

+586
-5
lines changed

2 files changed

+586
-5
lines changed

cadence/client.py

Lines changed: 160 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import os
22
import socket
3-
from typing import TypedDict, Unpack, Any, cast
3+
import uuid
4+
from datetime import timedelta
5+
from typing import TypedDict, Unpack, Any, cast, Union, Callable
46

57
from grpc import ChannelCredentials, Compression
8+
from google.protobuf.duration_pb2 import Duration
69

710
from cadence._internal.rpc.error import CadenceErrorInterceptor
811
from cadence._internal.rpc.retry import RetryInterceptor
@@ -11,10 +14,47 @@
1114
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
1215
from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel
1316
from cadence.api.v1.service_workflow_pb2_grpc import WorkflowAPIStub
17+
from cadence.api.v1.service_workflow_pb2 import (
18+
StartWorkflowExecutionRequest,
19+
StartWorkflowExecutionResponse,
20+
)
21+
from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution
22+
from cadence.api.v1.tasklist_pb2 import TaskList
1423
from cadence.data_converter import DataConverter, DefaultDataConverter
1524
from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter
1625

1726

27+
class StartWorkflowOptions(TypedDict, total=False):
28+
"""Options for starting a workflow execution."""
29+
30+
task_list: str
31+
execution_start_to_close_timeout: timedelta
32+
workflow_id: str
33+
task_start_to_close_timeout: timedelta
34+
cron_schedule: str
35+
36+
37+
def _validate_and_apply_defaults(options: StartWorkflowOptions) -> StartWorkflowOptions:
38+
"""Validate required fields and apply defaults to StartWorkflowOptions."""
39+
if not options.get("task_list"):
40+
raise ValueError("task_list is required")
41+
42+
execution_timeout = options.get("execution_start_to_close_timeout")
43+
if not execution_timeout:
44+
raise ValueError("execution_start_to_close_timeout is required")
45+
if execution_timeout <= timedelta(0):
46+
raise ValueError("execution_start_to_close_timeout must be greater than 0")
47+
48+
# Apply default for task_start_to_close_timeout if not provided (matching Go/Java clients)
49+
task_timeout = options.get("task_start_to_close_timeout")
50+
if task_timeout is None:
51+
options["task_start_to_close_timeout"] = timedelta(seconds=10)
52+
elif task_timeout <= timedelta(0):
53+
raise ValueError("task_start_to_close_timeout must be greater than 0")
54+
55+
return options
56+
57+
1858
class ClientOptions(TypedDict, total=False):
1959
domain: str
2060
target: str
@@ -28,6 +68,7 @@ class ClientOptions(TypedDict, total=False):
2868
metrics_emitter: MetricsEmitter
2969
interceptors: list[ClientInterceptor]
3070

71+
3172
_DEFAULT_OPTIONS: ClientOptions = {
3273
"data_converter": DefaultDataConverter(),
3374
"identity": f"{os.getpid()}@{socket.gethostname()}",
@@ -40,6 +81,7 @@ class ClientOptions(TypedDict, total=False):
4081
"interceptors": [],
4182
}
4283

84+
4385
class Client:
4486
def __init__(self, **kwargs: Unpack[ClientOptions]) -> None:
4587
self._options = _validate_and_copy_defaults(ClientOptions(**kwargs))
@@ -82,12 +124,112 @@ async def ready(self) -> None:
82124
async def close(self) -> None:
83125
await self._channel.close()
84126

85-
async def __aenter__(self) -> 'Client':
127+
async def __aenter__(self) -> "Client":
86128
return self
87129

88130
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
89131
await self.close()
90132

133+
async def _build_start_workflow_request(
134+
self,
135+
workflow: Union[str, Callable],
136+
args: tuple[Any, ...],
137+
options: StartWorkflowOptions,
138+
) -> StartWorkflowExecutionRequest:
139+
"""Build a StartWorkflowExecutionRequest from parameters."""
140+
# Generate workflow ID if not provided
141+
workflow_id = options.get("workflow_id") or str(uuid.uuid4())
142+
143+
# Determine workflow type name
144+
if isinstance(workflow, str):
145+
workflow_type_name = workflow
146+
else:
147+
# For callable, use function name or __name__ attribute
148+
workflow_type_name = getattr(workflow, "__name__", str(workflow))
149+
150+
# Encode input arguments
151+
input_payload = None
152+
if args:
153+
try:
154+
input_payload = await self.data_converter.to_data(list(args))
155+
except Exception as e:
156+
raise ValueError(f"Failed to encode workflow arguments: {e}")
157+
158+
# Convert timedelta to protobuf Duration
159+
execution_timeout = Duration()
160+
execution_timeout.FromTimedelta(options["execution_start_to_close_timeout"])
161+
162+
task_timeout = Duration()
163+
task_timeout.FromTimedelta(options["task_start_to_close_timeout"])
164+
165+
# Build the request
166+
request = StartWorkflowExecutionRequest(
167+
domain=self.domain,
168+
workflow_id=workflow_id,
169+
workflow_type=WorkflowType(name=workflow_type_name),
170+
task_list=TaskList(name=options["task_list"]),
171+
identity=self.identity,
172+
request_id=str(uuid.uuid4()),
173+
)
174+
175+
# Set required timeout fields
176+
request.execution_start_to_close_timeout.CopyFrom(execution_timeout)
177+
request.task_start_to_close_timeout.CopyFrom(task_timeout)
178+
179+
# Set optional fields
180+
if input_payload:
181+
request.input.CopyFrom(input_payload)
182+
if options.get("cron_schedule"):
183+
request.cron_schedule = options["cron_schedule"]
184+
185+
return request
186+
187+
async def start_workflow(
188+
self,
189+
workflow: Union[str, Callable],
190+
*args,
191+
**options_kwargs: Unpack[StartWorkflowOptions],
192+
) -> WorkflowExecution:
193+
"""
194+
Start a workflow execution asynchronously.
195+
196+
Args:
197+
workflow: Workflow function or workflow type name string
198+
*args: Arguments to pass to the workflow
199+
**options_kwargs: StartWorkflowOptions as keyword arguments
200+
201+
Returns:
202+
WorkflowExecution with workflow_id and run_id
203+
204+
Raises:
205+
ValueError: If required parameters are missing or invalid
206+
Exception: If the gRPC call fails
207+
"""
208+
# Convert kwargs to StartWorkflowOptions and validate
209+
options = _validate_and_apply_defaults(StartWorkflowOptions(**options_kwargs))
210+
211+
# Build the gRPC request
212+
request = await self._build_start_workflow_request(workflow, args, options)
213+
214+
# Execute the gRPC call
215+
try:
216+
response: StartWorkflowExecutionResponse = (
217+
await self.workflow_stub.StartWorkflowExecution(request)
218+
)
219+
220+
# Emit metrics if available
221+
if self.metrics_emitter:
222+
# TODO: Add workflow start metrics similar to Go client
223+
pass
224+
225+
execution = WorkflowExecution()
226+
execution.workflow_id = request.workflow_id
227+
execution.run_id = response.run_id
228+
return execution
229+
except Exception:
230+
raise
231+
232+
91233
def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
92234
if "target" not in options:
93235
raise ValueError("target must be specified")
@@ -105,11 +247,24 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
105247

106248
def _create_channel(options: ClientOptions) -> Channel:
107249
interceptors = list(options["interceptors"])
108-
interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"]))
250+
interceptors.append(
251+
YarpcMetadataInterceptor(options["service_name"], options["caller_name"])
252+
)
109253
interceptors.append(RetryInterceptor())
110254
interceptors.append(CadenceErrorInterceptor())
111255

112256
if options["credentials"]:
113-
return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors)
257+
return secure_channel(
258+
options["target"],
259+
options["credentials"],
260+
options["channel_arguments"],
261+
options["compression"],
262+
interceptors,
263+
)
114264
else:
115-
return insecure_channel(options["target"], options["channel_arguments"], options["compression"], interceptors)
265+
return insecure_channel(
266+
options["target"],
267+
options["channel_arguments"],
268+
options["compression"],
269+
interceptors,
270+
)

0 commit comments

Comments
 (0)