diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 71165926..02a56a1c 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -136,6 +136,7 @@ def event_loop_cycle( metrics: Metrics # Retry loop for handling throttling exceptions + current_delay = INITIAL_DELAY for attempt in range(MAX_ATTEMPTS): model_id = model.config.get("model_id") if hasattr(model, "config") else None model_invoke_span = tracer.start_model_invoke_span( @@ -168,7 +169,7 @@ def event_loop_cycle( # Handle throttling errors with exponential backoff should_retry, current_delay = handle_throttling_error( - e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs + e, attempt, MAX_ATTEMPTS, current_delay, MAX_DELAY, callback_handler, kwargs ) if should_retry: continue diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 8c46e009..734457aa 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -11,6 +11,13 @@ from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +@pytest.fixture +def mock_time(): + """Fixture to mock the time module in the error_handler.""" + with unittest.mock.patch.object(strands.event_loop.error_handler, "time") as mock: + yield mock + + @pytest.fixture def model(): return unittest.mock.Mock() @@ -157,8 +164,8 @@ def test_event_loop_cycle_text_response( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -@unittest.mock.patch.object(strands.event_loop.error_handler, "time") def test_event_loop_cycle_text_response_throttling( + mock_time, model, model_id, system_prompt, @@ -191,6 +198,53 @@ def test_event_loop_cycle_text_response_throttling( exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify that sleep was called once with the initial delay + mock_time.sleep.assert_called_once() + + +def test_event_loop_cycle_exponential_backoff( + mock_time, + model, + model_id, + system_prompt, + messages, + tool_config, + callback_handler, + tool_handler, + tool_execution_handler, +): + """Test that the exponential backoff works correctly with multiple retries.""" + # Set up the model to raise throttling exceptions multiple times before succeeding + model.converse.side_effect = [ + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ], + ] + + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( + model=model, + model_id=model_id, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + ) + + # Verify the final response + assert tru_stop_reason == "end_turn" + assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} + assert tru_request_state == {} + + # Verify that sleep was called with increasing delays + # Initial delay is 4, then 8, then 16 + assert mock_time.sleep.call_count == 3 + assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)] def test_event_loop_cycle_text_response_error(