Skip to content

Commit d8b3cfd

Browse files
pgrayy and jsamuel1
authored and committed
async model stream interface (strands-agents#306)
1 parent 20cee8b commit d8b3cfd

27 files changed

+878
-541
lines changed

src/strands/agent/agent.py

Lines changed: 76 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
1010
"""
1111

12+
import asyncio
1213
import json
1314
import logging
1415
import os
1516
import random
1617
from concurrent.futures import ThreadPoolExecutor
17-
from typing import Any, AsyncIterator, Callable, Generator, List, Mapping, Optional, Type, TypeVar, Union, cast
18+
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
1819

1920
from opentelemetry import trace
2021
from pydantic import BaseModel
@@ -418,33 +419,43 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
418419
- metrics: Performance metrics from the event loop
419420
- state: The final state of the event loop
420421
"""
421-
callback_handler = kwargs.get("callback_handler", self.callback_handler)
422422

423-
self._start_agent_trace_span(prompt)
423+
def execute() -> AgentResult:
424+
return asyncio.run(self.invoke_async(prompt, **kwargs))
424425

425-
try:
426-
events = self._run_loop(prompt, kwargs)
427-
for event in events:
428-
if "callback" in event:
429-
callback_handler(**event["callback"])
426+
with ThreadPoolExecutor() as executor:
427+
future = executor.submit(execute)
428+
return future.result()
430429

431-
stop_reason, message, metrics, state = event["stop"]
432-
result = AgentResult(stop_reason, message, metrics, state)
430+
async def invoke_async(self, prompt: str, **kwargs: Any) -> AgentResult:
431+
"""Process a natural language prompt through the agent's event loop.
433432
434-
self._end_agent_trace_span(response=result)
433+
This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
434+
the conversation history, processes it through the model, executes any tool calls, and returns the final result.
435435
436-
return result
436+
Args:
437+
prompt: The natural language prompt from the user.
438+
**kwargs: Additional parameters to pass through the event loop.
437439
438-
except Exception as e:
439-
self._end_agent_trace_span(error=e)
440-
raise
440+
Returns:
441+
Result object containing:
442+
443+
- stop_reason: Why the event loop stopped (e.g., "end_turn", "max_tokens")
444+
- message: The final message from the model
445+
- metrics: Performance metrics from the event loop
446+
- state: The final state of the event loop
447+
"""
448+
events = self.stream_async(prompt, **kwargs)
449+
async for event in events:
450+
_ = event
451+
452+
return cast(AgentResult, event["result"])
441453

442454
def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
443455
"""This method allows you to get structured output from the agent.
444456
445457
If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
446458
If you don't pass in a prompt, it will use only the conversation history to respond.
447-
If no conversation history exists and no prompt is provided, an error will be raised.
448459
449460
For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly
450461
instruct the model to output the structured data.
@@ -453,25 +464,52 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
453464
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
454465
that the agent will use when responding.
455466
prompt: The prompt to use for the agent.
467+
468+
Raises:
469+
ValueError: If no conversation history or prompt is provided.
470+
"""
471+
472+
def execute() -> T:
473+
return asyncio.run(self.structured_output_async(output_model, prompt))
474+
475+
with ThreadPoolExecutor() as executor:
476+
future = executor.submit(execute)
477+
return future.result()
478+
479+
async def structured_output_async(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
480+
"""This method allows you to get structured output from the agent.
481+
482+
If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
483+
If you don't pass in a prompt, it will use only the conversation history to respond.
484+
485+
For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly
486+
instruct the model to output the structured data.
487+
488+
Args:
489+
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
490+
that the agent will use when responding.
491+
prompt: The prompt to use for the agent.
492+
493+
Raises:
494+
ValueError: If no conversation history or prompt is provided.
456495
"""
457496
self._hooks.invoke_callbacks(StartRequestEvent(agent=self))
458497

459498
try:
460-
messages = self.messages
461-
if not messages and not prompt:
499+
if not self.messages and not prompt:
462500
raise ValueError("No conversation history or prompt provided")
463501

464502
# add the prompt as the last message
465503
if prompt:
466-
messages.append({"role": "user", "content": [{"text": prompt}]})
504+
self.messages.append({"role": "user", "content": [{"text": prompt}]})
467505

468-
# get the structured output from the model
469-
events = self.model.structured_output(output_model, messages)
470-
for event in events:
506+
events = self.model.structured_output(output_model, self.messages)
507+
async for event in events:
471508
if "callback" in event:
472509
self.callback_handler(**cast(dict, event["callback"]))
473510

474511
return event["output"]
512+
475513
finally:
476514
self._hooks.invoke_callbacks(EndRequestEvent(agent=self))
477515

@@ -511,21 +549,22 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
511549

512550
try:
513551
events = self._run_loop(prompt, kwargs)
514-
for event in events:
552+
async for event in events:
515553
if "callback" in event:
516554
callback_handler(**event["callback"])
517555
yield event["callback"]
518556

519-
stop_reason, message, metrics, state = event["stop"]
520-
result = AgentResult(stop_reason, message, metrics, state)
557+
result = AgentResult(*event["stop"])
558+
callback_handler(result=result)
559+
yield {"result": result}
521560

522561
self._end_agent_trace_span(response=result)
523562

524563
except Exception as e:
525564
self._end_agent_trace_span(error=e)
526565
raise
527566

528-
def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
567+
async def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
529568
"""Execute the agent's event loop with the given prompt and parameters."""
530569
self._hooks.invoke_callbacks(StartRequestEvent(agent=self))
531570

@@ -539,13 +578,15 @@ def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str,
539578
self.messages.append(new_message)
540579

541580
# Execute the event loop cycle with retry logic for context limits
542-
yield from self._execute_event_loop_cycle(kwargs)
581+
events = self._execute_event_loop_cycle(kwargs)
582+
async for event in events:
583+
yield event
543584

544585
finally:
545586
self.conversation_manager.apply_management(self)
546587
self._hooks.invoke_callbacks(EndRequestEvent(agent=self))
547588

548-
def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
589+
async def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
549590
"""Execute the event loop cycle with retry logic for context window limits.
550591
551592
This internal method handles the execution of the event loop cycle and implements
@@ -583,7 +624,7 @@ def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[st
583624

584625
try:
585626
# Execute the main event loop cycle
586-
yield from event_loop_cycle(
627+
events = event_loop_cycle(
587628
model=self.model,
588629
system_prompt=self.system_prompt,
589630
messages=self.messages, # will be modified by event_loop_cycle
@@ -594,11 +635,15 @@ def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[st
594635
event_loop_parent_span=self.trace_span,
595636
kwargs=kwargs,
596637
)
638+
async for event in events:
639+
yield event
597640

598641
except ContextWindowOverflowException as e:
599642
# Try reducing the context size and retrying
600643
self.conversation_manager.reduce_context(self, e=e)
601-
yield from self._execute_event_loop_cycle(kwargs)
644+
events = self._execute_event_loop_cycle(kwargs)
645+
async for event in events:
646+
yield event
602647

603648
def _record_tool_execution(
604649
self,
@@ -623,7 +668,7 @@ def _record_tool_execution(
623668
messages: The message history to append to.
624669
"""
625670
# Create user message describing the tool call
626-
user_msg_content: List[ContentBlock] = [
671+
user_msg_content: list[ContentBlock] = [
627672
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")}
628673
]
629674

src/strands/event_loop/event_loop.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import uuid
1414
from concurrent.futures import ThreadPoolExecutor
1515
from functools import partial
16-
from typing import Any, Generator, Optional
16+
from typing import Any, AsyncGenerator, Optional
1717

1818
from opentelemetry import trace
1919

@@ -35,7 +35,7 @@
3535
MAX_DELAY = 240 # 4 minutes
3636

3737

38-
def event_loop_cycle(
38+
async def event_loop_cycle(
3939
model: Model,
4040
system_prompt: Optional[str],
4141
messages: Messages,
@@ -45,7 +45,7 @@ def event_loop_cycle(
4545
event_loop_metrics: EventLoopMetrics,
4646
event_loop_parent_span: Optional[trace.Span],
4747
kwargs: dict[str, Any],
48-
) -> Generator[dict[str, Any], None, None]:
48+
) -> AsyncGenerator[dict[str, Any], None]:
4949
"""Execute a single cycle of the event loop.
5050
5151
This core function processes a single conversation turn, handling model inference, tool execution, and error
@@ -132,7 +132,7 @@ def event_loop_cycle(
132132
try:
133133
# TODO: To maintain backwards compatability, we need to combine the stream event with kwargs before yielding
134134
# to the callback handler. This will be revisited when migrating to strongly typed events.
135-
for event in stream_messages(model, system_prompt, messages, tool_config):
135+
async for event in stream_messages(model, system_prompt, messages, tool_config):
136136
if "callback" in event:
137137
yield {"callback": {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}}
138138

@@ -202,7 +202,7 @@ def event_loop_cycle(
202202
)
203203

204204
# Handle tool execution
205-
yield from _handle_tool_execution(
205+
events = _handle_tool_execution(
206206
stop_reason,
207207
message,
208208
model,
@@ -218,6 +218,9 @@ def event_loop_cycle(
218218
cycle_start_time,
219219
kwargs,
220220
)
221+
async for event in events:
222+
yield event
223+
221224
return
222225

223226
# End the cycle and return results
@@ -250,7 +253,7 @@ def event_loop_cycle(
250253
yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])}
251254

252255

253-
def recurse_event_loop(
256+
async def recurse_event_loop(
254257
model: Model,
255258
system_prompt: Optional[str],
256259
messages: Messages,
@@ -260,7 +263,7 @@ def recurse_event_loop(
260263
event_loop_metrics: EventLoopMetrics,
261264
event_loop_parent_span: Optional[trace.Span],
262265
kwargs: dict[str, Any],
263-
) -> Generator[dict[str, Any], None, None]:
266+
) -> AsyncGenerator[dict[str, Any], None]:
264267
"""Make a recursive call to event_loop_cycle with the current state.
265268
266269
This function is used when the event loop needs to continue processing after tool execution.
@@ -292,7 +295,8 @@ def recurse_event_loop(
292295
cycle_trace.add_child(recursive_trace)
293296

294297
yield {"callback": {"start": True}}
295-
yield from event_loop_cycle(
298+
299+
events = event_loop_cycle(
296300
model=model,
297301
system_prompt=system_prompt,
298302
messages=messages,
@@ -303,11 +307,13 @@ def recurse_event_loop(
303307
event_loop_parent_span=event_loop_parent_span,
304308
kwargs=kwargs,
305309
)
310+
async for event in events:
311+
yield event
306312

307313
recursive_trace.end()
308314

309315

310-
def _handle_tool_execution(
316+
async def _handle_tool_execution(
311317
stop_reason: StopReason,
312318
message: Message,
313319
model: Model,
@@ -322,7 +328,7 @@ def _handle_tool_execution(
322328
cycle_span: Any,
323329
cycle_start_time: float,
324330
kwargs: dict[str, Any],
325-
) -> Generator[dict[str, Any], None, None]:
331+
) -> AsyncGenerator[dict[str, Any], None]:
326332
tool_uses: list[ToolUse] = []
327333
tool_results: list[ToolResult] = []
328334
invalid_tool_use_ids: list[str] = []
@@ -369,7 +375,7 @@ def _handle_tool_execution(
369375
kwargs=kwargs,
370376
)
371377

372-
yield from run_tools(
378+
tool_events = run_tools(
373379
handler=tool_handler_process,
374380
tool_uses=tool_uses,
375381
event_loop_metrics=event_loop_metrics,
@@ -379,6 +385,8 @@ def _handle_tool_execution(
379385
parent_span=cycle_span,
380386
thread_pool=thread_pool,
381387
)
388+
for tool_event in tool_events:
389+
yield tool_event
382390

383391
# Store parent cycle ID for the next cycle
384392
kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"]
@@ -400,7 +408,7 @@ def _handle_tool_execution(
400408
yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])}
401409
return
402410

403-
yield from recurse_event_loop(
411+
events = recurse_event_loop(
404412
model=model,
405413
system_prompt=system_prompt,
406414
messages=messages,
@@ -411,3 +419,5 @@ def _handle_tool_execution(
411419
event_loop_parent_span=event_loop_parent_span,
412420
kwargs=kwargs,
413421
)
422+
async for event in events:
423+
yield event

0 commit comments

Comments (0)