Skip to content
Open
47 changes: 47 additions & 0 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,10 @@ def __init__( # noqa: PLR0915
if self.alerting_config is not None:
self._initialize_alerting()

# --- Lifecycle additions (minimal) ---
# Track closed state for idempotent teardown
self._closed: bool = False

self.initialize_assistants_endpoint()
self.initialize_router_endpoints()
self.apply_default_settings()
Expand Down Expand Up @@ -659,6 +663,49 @@ def discard(self):
litellm.callbacks, callback, require_self=False
)

# ------------------------------
# Public Lifecycle API
# ------------------------------
def close(self) -> None:
"""
Deterministically tear down Router hooks/callbacks.

Minimal and idempotent: marks closed and unhooks callbacks so short‑lived
scripts/tests exit cleanly without lingering Router-managed globals.
"""
if getattr(self, "_closed", False):
return
self._closed = True
# Unhook router-specific callbacks from global managers
try:
self.discard()
except Exception:
pass

async def aclose(self) -> None:
"""
Async variant of close(). Provided for symmetry.
"""
self.close()

# ------------------------------
# Context Manager Support
# ------------------------------
def __enter__(self):
return self

def __exit__(self, exc_type, exc, tb):
try:
self.close()
finally:
return False # propagate exceptions

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
await self.aclose()

@staticmethod
def _create_redis_cache(
cache_config: Dict[str, Any],
Expand Down
50 changes: 50 additions & 0 deletions tests/test_litellm/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,40 @@ def test_update_kwargs_does_not_mutate_defaults_and_merges_metadata():
assert kwargs["litellm_metadata"] == {"baz": 123}


def test_router_close_idempotent():
import litellm

router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
}
]
)

router.close()
# calling twice should be safe
router.close()
assert getattr(router, "_closed", False) is True


def test_router_sync_context_manager():
import litellm

with litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
}
]
) as r:
assert isinstance(r, litellm.Router)
# after exiting context, router should be closed
assert getattr(r, "_closed", False) is True


def test_router_with_model_info_and_model_group():
"""
Test edge case where user specifies model_group in model_info
Expand Down Expand Up @@ -122,6 +156,22 @@ async def test_arouter_with_tags_and_fallbacks():
)


@pytest.mark.asyncio
async def test_router_async_context_manager():
import litellm

async with litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
}
]
) as r:
assert isinstance(r, litellm.Router)
assert getattr(r, "_closed", False) is True


@pytest.mark.asyncio
async def test_async_router_acreate_file():
"""
Expand Down
Loading