55 changes: 45 additions & 10 deletions litellm/proxy/db/db_spend_update_writer.py
@@ -10,6 +10,7 @@
import os
import time
import traceback
import random
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast, overload

@@ -769,7 +770,14 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915
proxy_logging_obj=proxy_logging_obj,
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
await asyncio.sleep(
# Sleep a random amount to avoid retrying and deadlocking again: when two transactions deadlock they are
# cancelled at essentially the same time, so if they wait the same fixed interval they will also retry at
# the same time and are likely to deadlock again.
# Sleeping a random amount instead makes them retry at slightly different times, lowering the chance of
# repeated deadlocks and therefore of exceeding the retry limit.
random.uniform(2**i, 2 ** (i + 1))
)
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
@@ -849,8 +857,27 @@ async def _update_daily_spend(
try:
for i in range(n_retry_times + 1):
try:
# Sort the transactions to minimize the probability of deadlocks by reducing the chance of concurrent
# transactions locking the same rows/ranges in different orders.
transactions_to_process = dict(
list(daily_spend_transactions.items())[:BATCH_SIZE]
sorted(
daily_spend_transactions.items(),
# Normally to avoid deadlocks we would sort by the index, but since we have sprinkled indexes
# on our schema like we're discount Salt Bae, we just sort by all fields that have an index,
# in an ad-hoc (but hopefully sensible) order of indexes. The actual ordering matters less than
# ensuring that all concurrent transactions sort in the same order.
# We could in theory use the dict key, as it contains basically the same fields, but this is more
# robust to future changes in the key format.
# If _update_daily_spend ever gets the ability to write to multiple tables at once, the sorting
# should sort by the table first.
key=lambda x: (
x[1]["date"],
x[1].get(entity_id_field),
x[1]["api_key"],
x[1]["model"],
x[1]["custom_llm_provider"],
),
)[:BATCH_SIZE]
)

if len(transactions_to_process) == 0:
Expand Down Expand Up @@ -893,7 +920,8 @@ async def _update_daily_spend(
"model_group": transaction.get("model_group"),
"mcp_namespaced_tool_name": transaction.get(
"mcp_namespaced_tool_name"
) or "",
)
or "",
"custom_llm_provider": transaction.get(
"custom_llm_provider"
),
Expand All @@ -909,13 +937,13 @@ async def _update_daily_spend(

# Add cache-related fields if they exist
if "cache_read_input_tokens" in transaction:
common_data["cache_read_input_tokens"] = (
transaction.get("cache_read_input_tokens", 0)
)
common_data[
"cache_read_input_tokens"
] = transaction.get("cache_read_input_tokens", 0)
if "cache_creation_input_tokens" in transaction:
common_data["cache_creation_input_tokens"] = (
transaction.get("cache_creation_input_tokens", 0)
)
common_data[
"cache_creation_input_tokens"
] = transaction.get("cache_creation_input_tokens", 0)

# Create update data structure
update_data = {
@@ -976,7 +1004,14 @@ async def _update_daily_spend(
start_time=start_time,
proxy_logging_obj=proxy_logging_obj,
)
await asyncio.sleep(2**i)
await asyncio.sleep(
# Sleep a random amount to avoid retrying and deadlocking again: when two transactions deadlock they are
# cancelled at essentially the same time, so if they wait the same fixed interval they will also retry at
# the same time and are likely to deadlock again.
# Sleeping a random amount instead makes them retry at slightly different times, lowering the chance of
# repeated deadlocks and therefore of exceeding the retry limit.
random.uniform(2**i, 2 ** (i + 1))
)

except Exception as e:
if "transactions_to_process" in locals():
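For context, the retry change above replaces a fixed exponential backoff with a jittered one. Below is a minimal standalone sketch of that pattern, not part of the change itself; the `commit_batch` callable and `MAX_RETRIES` constant are hypothetical stand-ins for the proxy's actual DB write and retry limit:

```python
import asyncio
import random

MAX_RETRIES = 3  # hypothetical retry limit


async def commit_with_jittered_backoff(commit_batch) -> None:
    """Retry a DB write, sleeping a randomized, exponentially growing interval between attempts."""
    for i in range(MAX_RETRIES + 1):
        try:
            await commit_batch()
            return
        except Exception:
            if i == MAX_RETRIES:
                raise
            # Attempt i waits somewhere in [2**i, 2**(i + 1)], so two workers that
            # deadlocked together are unlikely to retry at exactly the same instant.
            await asyncio.sleep(random.uniform(2**i, 2 ** (i + 1)))
```

With a small retry count the worst-case total wait stays bounded (at most 2 + 4 + 8 seconds for three retries here), while the randomization keeps two colliding writers from retrying in lockstep.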
4 changes: 2 additions & 2 deletions tests/proxy_unit_tests/test_update_spend.py
@@ -198,8 +198,8 @@ async def mock_sleep(seconds):

# Verify exponential backoff
assert len(sleep_times) == 2 # Should have slept twice
assert sleep_times[0] == 1 # First retry after 2^0 seconds
assert sleep_times[1] == 2 # Second retry after 2^1 seconds
assert 1 <= sleep_times[0] <= 2  # First retry after 2^0~2^1 seconds
assert 2 <= sleep_times[1] <= 4  # Second retry after 2^1~2^2 seconds


@pytest.mark.asyncio
88 changes: 87 additions & 1 deletion tests/test_litellm/proxy/db/test_db_spend_update_writer.py
@@ -8,7 +8,7 @@


from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch, call

import pytest

@@ -135,3 +135,89 @@ async def test_update_daily_spend_with_null_entity_id():
assert create_data["api_requests"] == 1
assert create_data["successful_requests"] == 1
assert create_data["failed_requests"] == 0


@pytest.mark.asyncio
async def test_update_daily_spend_sorting():
"""
Test that table.upsert is called with transactions in sorted order

Ensures that writes are ordered consistently across concurrent batches to minimize deadlocks
"""
# Setup
mock_prisma_client = MagicMock()
mock_batcher = MagicMock()
mock_table = MagicMock()
mock_prisma_client.db.batch_.return_value.__aenter__.return_value = mock_batcher
mock_batcher.litellm_dailyuserspend = mock_table

# Create 50 transactions with out-of-order entity_ids
# In reality we sort using multiple fields, but entity_id is sufficient to test sorting
daily_spend_transactions = {}
upsert_calls = []
for i in range(50):
daily_spend_transactions[f"test_key_{i}"] = {
"user_id": f"user{60-i}", # user60 ... user11, reverse order
"date": "2024-01-01",
"api_key": "test-api-key",
"model": "gpt-4",
"custom_llm_provider": "openai",
"prompt_tokens": 10,
"completion_tokens": 20,
"spend": 0.1,
"api_requests": 1,
"successful_requests": 1,
"failed_requests": 0,
}
upsert_calls.append(call(
where={
"user_id_date_api_key_model_custom_llm_provider": {
"user_id": f"user{i+11}", # user11 ... user60, sorted order
"date": "2024-01-01",
"api_key": "test-api-key",
"model": "gpt-4",
"custom_llm_provider": "openai",
"mcp_namespaced_tool_name": "",
}
},
data={
"create": {
"user_id": f"user{i+11}",
"date": "2024-01-01",
"api_key": "test-api-key",
"model": "gpt-4",
"model_group": None,
"mcp_namespaced_tool_name": "",
"custom_llm_provider": "openai",
"prompt_tokens": 10,
"completion_tokens": 20,
"spend": 0.1,
"api_requests": 1,
"successful_requests": 1,
"failed_requests": 0,
},
"update": {
"prompt_tokens": {"increment": 10},
"completion_tokens": {"increment": 20},
"spend": {"increment": 0.1},
"api_requests": {"increment": 1},
"successful_requests": {"increment": 1},
"failed_requests": {"increment": 0},
},
},
))

# Call the method
await DBSpendUpdateWriter._update_daily_spend(
n_retry_times=1,
prisma_client=mock_prisma_client,
proxy_logging_obj=MagicMock(),
daily_spend_transactions=daily_spend_transactions,
entity_type="user",
entity_id_field="user_id",
table_name="litellm_dailyuserspend",
unique_constraint_name="user_id_date_api_key_model_custom_llm_provider",
)

# Verify that table.upsert was called
mock_table.upsert.assert_has_calls(upsert_calls)
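The test above pins down the exact upsert order; why a shared order helps can be shown in isolation. Below is a minimal sketch (not part of the change, all data hypothetical): two writers receive the same rows in different arrival orders, but after sorting by the same composite key they both touch the rows in the same sequence, so neither can hold a row the other is waiting on while itself waiting on a row the other holds.

```python
# Two hypothetical daily-spend batches containing the same rows in different arrival order.
batch_a = [
    {"date": "2024-01-01", "user_id": "user2", "api_key": "k1", "model": "gpt-4", "custom_llm_provider": "openai"},
    {"date": "2024-01-01", "user_id": "user1", "api_key": "k1", "model": "gpt-4", "custom_llm_provider": "openai"},
]
batch_b = list(reversed(batch_a))


def sort_key(txn: dict) -> tuple:
    # Mirrors the ordering used in _update_daily_spend: date first, then the entity id,
    # then the remaining indexed fields.
    return (txn["date"], txn["user_id"], txn["api_key"], txn["model"], txn["custom_llm_provider"])


# Both writers now touch user1 before user2, so their lock acquisition order agrees
# and a circular wait (the precondition for a deadlock) cannot form between them.
assert [t["user_id"] for t in sorted(batch_a, key=sort_key)] == ["user1", "user2"]
assert sorted(batch_a, key=sort_key) == sorted(batch_b, key=sort_key)
```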