Skip to content

Commit 1f0259d

Browse files
committed
feat: define workflow signal decorator in workflowDefinition
Signed-off-by: Tim Li <[email protected]>
1 parent 9410b26 commit 1f0259d

File tree

2 files changed

+135
-2
lines changed

2 files changed

+135
-2
lines changed

cadence/workflow.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,16 @@ class WorkflowDefinition:
5959
Provides type safety and metadata for workflow classes.
6060
"""
6161

62-
def __init__(self, cls: Type, name: str, run_method_name: str):
62+
def __init__(self, cls: Type, name: str, run_method_name: str, signals: dict[str, Callable[..., Any]]):
6363
self._cls = cls
6464
self._name = name
6565
self._run_method_name = run_method_name
66+
self._signals = signals
67+
68+
@property
69+
def signals(self) -> dict[str, Callable[..., Any]]:
70+
"""Get the signals."""
71+
return self._signals
6672

6773
@property
6874
def name(self) -> str:
@@ -98,6 +104,9 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> "WorkflowDefinition":
98104
name = opts["name"]
99105

100106
# Validate that the class has exactly one run method and find it
107+
# Also validate that class does not have multiple signal methods with the same name
108+
signals: dict[str, Callable[..., Any]] = {}
109+
signal_names: dict[str, str] = {} # Map signal name to method name for duplicate detection
101110
run_method_name = None
102111
for attr_name in dir(cls):
103112
if attr_name.startswith("_"):
@@ -114,11 +123,22 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> "WorkflowDefinition":
114123
f"Multiple @workflow.run methods found in class {cls.__name__}"
115124
)
116125
run_method_name = attr_name
126+
127+
if hasattr(attr, "_workflow_signal"):
128+
signal_name = getattr(attr, "_workflow_signal")
129+
if signal_name in signal_names:
130+
raise ValueError(
131+
f"Multiple @workflow.signal methods found in class {cls.__name__} "
132+
f"with signal name '{signal_name}': '{attr_name}' and '{signal_names[signal_name]}'"
133+
)
134+
signals[attr_name] = attr
135+
signal_names[signal_name] = attr_name
117136

118137
if run_method_name is None:
119138
raise ValueError(f"No @workflow.run method found in class {cls.__name__}")
120139

121-
return WorkflowDefinition(cls, name, run_method_name)
140+
return WorkflowDefinition(cls, name, run_method_name, signals)
141+
122142

123143

124144
def run(func: Optional[T] = None) -> Union[T, Callable[[T], T]]:
@@ -161,6 +181,33 @@ def decorator(f: T) -> T:
161181
# Called without parentheses: @workflow.run
162182
return decorator(func)
163183

184+
def signal(name: str | None = None) -> Callable[[T], T]:
185+
"""
186+
Decorator to mark a method as a workflow signal handler.
187+
188+
Example:
189+
@workflow.signal(name="approval_channel")
190+
async def approve(self, approved: bool):
191+
self.approved = approved
192+
193+
Args:
194+
name: The name of the signal
195+
196+
Returns:
197+
The decorated method with workflow signal metadata
198+
199+
Raises:
200+
ValueError: If name is not provided
201+
202+
"""
203+
if name is None:
204+
raise ValueError("name is required")
205+
206+
def decorator(f: T) -> T:
207+
f._workflow_signal = name # type: ignore
208+
return f
209+
# Only allow @workflow.signal(name), require name to be explicitly provided
210+
return decorator
164211

165212
@dataclass
166213
class WorkflowInfo:

tests/cadence/worker/test_registry.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,89 @@ async def run(self, input: str) -> str:
212212
workflow_def = reg.get_workflow("custom_workflow_name")
213213
assert workflow_def.name == "custom_workflow_name"
214214
assert workflow_def.cls == CustomWorkflow
215+
216+
def test_workflow_with_signal(self):
217+
"""Test workflow with signal handler."""
218+
reg = Registry()
219+
220+
@reg.workflow
221+
class WorkflowWithSignal:
222+
@workflow.run
223+
async def run(self):
224+
return "done"
225+
226+
@workflow.signal(name="approval")
227+
async def handle_approval(self, approved: bool):
228+
self.approved = approved
229+
230+
workflow_def = reg.get_workflow("WorkflowWithSignal")
231+
assert isinstance(workflow_def, WorkflowDefinition)
232+
assert len(workflow_def.signals) == 1
233+
assert "handle_approval" in workflow_def.signals
234+
assert hasattr(workflow_def.signals["handle_approval"], "_workflow_signal")
235+
assert workflow_def.signals["handle_approval"]._workflow_signal == "approval"
236+
237+
def test_workflow_with_multiple_signals(self):
238+
"""Test workflow with multiple signal handlers."""
239+
reg = Registry()
240+
241+
@reg.workflow
242+
class WorkflowWithMultipleSignals:
243+
@workflow.run
244+
async def run(self):
245+
return "done"
246+
247+
@workflow.signal(name="approval")
248+
async def handle_approval(self, approved: bool):
249+
self.approved = approved
250+
251+
@workflow.signal(name="cancel")
252+
async def handle_cancel(self):
253+
self.cancelled = True
254+
255+
workflow_def = reg.get_workflow("WorkflowWithMultipleSignals")
256+
assert len(workflow_def.signals) == 2
257+
assert "handle_approval" in workflow_def.signals
258+
assert "handle_cancel" in workflow_def.signals
259+
assert getattr(workflow_def.signals["handle_approval"], "_workflow_signal") == "approval"
260+
assert getattr(workflow_def.signals["handle_cancel"], "_workflow_signal") == "cancel"
261+
262+
def test_signal_decorator_requires_name(self):
263+
"""Test that signal decorator requires name parameter."""
264+
with pytest.raises(ValueError, match="name is required"):
265+
@workflow.signal()
266+
async def test_signal(self):
267+
pass
268+
269+
def test_workflow_without_signals(self):
270+
"""Test that workflow without signals has empty signals dict."""
271+
reg = Registry()
272+
273+
@reg.workflow
274+
class WorkflowWithoutSignals:
275+
@workflow.run
276+
async def run(self):
277+
return "done"
278+
279+
workflow_def = reg.get_workflow("WorkflowWithoutSignals")
280+
assert isinstance(workflow_def.signals, dict)
281+
assert len(workflow_def.signals) == 0
282+
283+
def test_duplicate_signal_names_error(self):
284+
"""Test that duplicate signal names raise ValueError."""
285+
reg = Registry()
286+
287+
with pytest.raises(ValueError, match="Multiple.*signal.*found.*with signal name 'approval'"):
288+
@reg.workflow
289+
class WorkflowWithDuplicateSignalNames:
290+
@workflow.run
291+
async def run(self):
292+
return "done"
293+
294+
@workflow.signal(name="approval")
295+
async def handle_approval(self, approved: bool):
296+
self.approved = approved
297+
298+
@workflow.signal(name="approval")
299+
async def handle_approval_different(self):
300+
self.also_approved = True

0 commit comments

Comments
 (0)