Skip to content

Commit 19a574f

Browse files
committed
fix(runtime): ensure stop flag is set for policy violations in parallel rails
1 parent a2f17bc commit 19a574f

File tree

2 files changed

+146
-1
lines changed

2 files changed

+146
-1
lines changed

nemoguardrails/colang/v1_0/runtime/runtime.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,13 @@ async def _run_flows_in_parallel(
311311
# Wrapper function to help reverse map the task result to the flow ID
312312
async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
313313
result = await func(*args, **kwargs)
314-
if post_event:
314+
315+
has_stop = any(
316+
event["type"] == "BotIntent" and event["intent"] == "stop"
317+
for event in result
318+
)
319+
320+
if post_event and not has_stop:
315321
result.append(post_event)
316322
args[1].append(
317323
{"type": "event", "timestamp": time(), "data": post_event}
@@ -361,6 +367,7 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
361367
unique_flow_ids[flow_uid] = task
362368

363369
stopped_task_results: List[dict] = []
370+
stopped_task_processing_log: List[dict] = []
364371

365372
# Process tasks as they complete using as_completed
366373
try:
@@ -377,6 +384,9 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
377384
# If this flow had a stop event
378385
if has_stop:
379386
stopped_task_results = task_results[flow_id] + result
387+
stopped_task_processing_log = task_processing_logs[
388+
flow_id
389+
].copy()
380390

381391
# Cancel all remaining tasks
382392
for pending_task in tasks:
@@ -433,6 +443,15 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
433443
finished_task_processing_logs.extend(task_processing_logs[flow_id])
434444

435445
if processing_log:
446+
for plog in stopped_task_processing_log:
447+
# Filter out "Listen" and "start_flow" events from task processing log
448+
if plog["type"] == "event" and (
449+
plog["data"]["type"] == "Listen"
450+
or plog["data"]["type"] == "start_flow"
451+
):
452+
continue
453+
processing_log.append(plog)
454+
436455
for plog in finished_task_processing_logs:
437456
# Filter out "Listen" and "start_flow" events from task processing log
438457
if plog["type"] == "event" and (

tests/test_parallel_rails.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,129 @@ async def test_parallel_rails_output_fail_2():
152152
and result.response[0]["content"]
153153
== "I cannot express a term in the bot answer."
154154
)
155+
156+
157+
@pytest.mark.asyncio
158+
async def test_parallel_rails_input_stop_flag():
159+
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
160+
chat = TestChat(
161+
config,
162+
llm_completions=[
163+
"No",
164+
"Hi there! How can I assist you with questions about the ABC Company today?",
165+
"No",
166+
],
167+
)
168+
169+
chat >> "hi, this is a blocked term."
170+
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
171+
172+
stopped_rails = [rail for rail in result.log.activated_rails if rail.stop]
173+
assert len(stopped_rails) == 1, "Expected exactly one stopped rail"
174+
assert (
175+
"check blocked input terms" in stopped_rails[0].name
176+
), f"Expected 'check blocked input terms' rail to be stopped, got {stopped_rails[0].name}"
177+
178+
179+
@pytest.mark.asyncio
180+
async def test_parallel_rails_output_stop_flag():
181+
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
182+
chat = TestChat(
183+
config,
184+
llm_completions=[
185+
"No",
186+
"Hi there! This is a blocked term!",
187+
"No",
188+
],
189+
)
190+
191+
chat >> "hi!"
192+
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
193+
194+
stopped_rails = [rail for rail in result.log.activated_rails if rail.stop]
195+
assert len(stopped_rails) == 1, "Expected exactly one stopped rail"
196+
assert (
197+
"check blocked output terms" in stopped_rails[0].name
198+
), f"Expected 'check blocked output terms' rail to be stopped, got {stopped_rails[0].name}"
199+
200+
201+
@pytest.mark.asyncio
202+
async def test_parallel_rails_client_code_pattern():
203+
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
204+
chat = TestChat(
205+
config,
206+
llm_completions=[
207+
"No",
208+
"Hi there! This is a blocked term!",
209+
"No",
210+
],
211+
)
212+
213+
chat >> "hi!"
214+
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
215+
216+
activated_rails = result.log.activated_rails if result.log else None
217+
assert activated_rails is not None, "Expected activated_rails to be present"
218+
219+
rails_to_check = [
220+
"self check output",
221+
"check blocked output terms $duration=1.0",
222+
]
223+
rails_set = set(rails_to_check)
224+
225+
stopping_rails = [rail for rail in activated_rails if rail.stop]
226+
227+
assert len(stopping_rails) > 0, "Expected at least one stopping rail"
228+
229+
blocked_rails = []
230+
for rail in stopping_rails:
231+
if rail.name in rails_set:
232+
blocked_rails.append(rail.name)
233+
234+
assert (
235+
len(blocked_rails) == 1
236+
), f"Expected exactly one blocked rail from our check list, got {len(blocked_rails)}: {blocked_rails}"
237+
assert (
238+
"check blocked output terms $duration=1.0" in blocked_rails
239+
), f"Expected 'check blocked output terms $duration=1.0' to be blocked, got {blocked_rails}"
240+
241+
non_stopped_rails = [rail for rail in activated_rails if not rail.stop]
242+
for rail in non_stopped_rails:
243+
assert (
244+
rail.stop is False or rail.stop is None
245+
), f"Non-stopped rail {rail.name} should not have stop=True"
246+
247+
248+
@pytest.mark.asyncio
249+
async def test_parallel_rails_multiple_activated_rails():
250+
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "parallel_rails"))
251+
chat = TestChat(
252+
config,
253+
llm_completions=[
254+
"No",
255+
"Hi there! This is a blocked term!",
256+
"No",
257+
],
258+
)
259+
260+
chat >> "hi!"
261+
result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
262+
263+
activated_rails = result.log.activated_rails if result.log else None
264+
assert activated_rails is not None, "Expected activated_rails to be present"
265+
assert len(activated_rails) > 1, (
266+
f"Expected multiple activated_rails, got {len(activated_rails)}: "
267+
f"{[rail.name for rail in activated_rails]}"
268+
)
269+
270+
stopped_rails = [rail for rail in activated_rails if rail.stop]
271+
assert len(stopped_rails) == 1, (
272+
f"Expected exactly one stopped rail, got {len(stopped_rails)}: "
273+
f"{[rail.name for rail in stopped_rails]}"
274+
)
275+
276+
rails_with_stop_true = [rail for rail in activated_rails if rail.stop is True]
277+
assert len(rails_with_stop_true) == 1, (
278+
f"Expected exactly one rail with stop=True, got {len(rails_with_stop_true)}: "
279+
f"{[rail.name for rail in rails_with_stop_true]}"
280+
)

0 commit comments

Comments
 (0)