Skip to content

Commit 8e32bac

Browse files
committed
attempt to avoid/minimize deadlocks
1 parent 13703f2 commit 8e32bac

File tree

3 files changed

+134
-13
lines changed

3 files changed

+134
-13
lines changed

litellm/proxy/db/db_spend_update_writer.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import os
1111
import time
1212
import traceback
13+
import random
1314
from datetime import datetime, timedelta
1415
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast, overload
1516

@@ -769,7 +770,14 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915
769770
proxy_logging_obj=proxy_logging_obj,
770771
)
771772
# Optionally, sleep for a bit before retrying
772-
await asyncio.sleep(2**i) # Exponential backoff
773+
await asyncio.sleep(
774+
# Sleep a random amount to avoid retrying and deadlocking again: when two transactions deadlock they are
775+
# cancelled basically at the same time, so if they wait the same time they will also retry at the same time
776+
# and thus they are more likely to deadlock again.
777+
# Instead, we sleep a random amount so that they retry at slightly different times, lowering the chance of
778+
# repeated deadlocks, and therefore of exceeding the retry limit.
779+
random.uniform(2**i, 2 ** (i + 1))
780+
)
773781
except Exception as e:
774782
_raise_failed_update_spend_exception(
775783
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
@@ -849,8 +857,27 @@ async def _update_daily_spend(
849857
try:
850858
for i in range(n_retry_times + 1):
851859
try:
860+
# Sort the transactions to minimize the probability of deadlocks by reducing the chance of concurrent
861+
# transactions locking the same rows/ranges in different orders.
852862
transactions_to_process = dict(
853-
list(daily_spend_transactions.items())[:BATCH_SIZE]
863+
sorted(
864+
daily_spend_transactions.items(),
865+
# Normally to avoid deadlocks we would sort by the index, but since we have sprinkled indexes
866+
# on our schema like we're discount Salt Bae, we just sort by all fields that have an index,
867+
# in an ad-hoc (but hopefully sensible) order of indexes. The actual ordering matters less than
868+
# ensuring that all concurrent transactions sort in the same order.
869+
# We could in theory use the dict key, as it contains basically the same fields, but this is more
870+
# robust to future changes in the key format.
871+
# If _update_daily_spend ever gets the ability to write to multiple tables at once, the sorting
872+
# should sort by the table first.
873+
key=lambda x: (
874+
x[1]["date"],
875+
x[1].get(entity_id_field),
876+
x[1]["api_key"],
877+
x[1]["model"],
878+
x[1]["custom_llm_provider"],
879+
),
880+
)[:BATCH_SIZE]
854881
)
855882

856883
if len(transactions_to_process) == 0:
@@ -893,7 +920,8 @@ async def _update_daily_spend(
893920
"model_group": transaction.get("model_group"),
894921
"mcp_namespaced_tool_name": transaction.get(
895922
"mcp_namespaced_tool_name"
896-
) or "",
923+
)
924+
or "",
897925
"custom_llm_provider": transaction.get(
898926
"custom_llm_provider"
899927
),
@@ -909,13 +937,13 @@ async def _update_daily_spend(
909937

910938
# Add cache-related fields if they exist
911939
if "cache_read_input_tokens" in transaction:
912-
common_data["cache_read_input_tokens"] = (
913-
transaction.get("cache_read_input_tokens", 0)
914-
)
940+
common_data[
941+
"cache_read_input_tokens"
942+
] = transaction.get("cache_read_input_tokens", 0)
915943
if "cache_creation_input_tokens" in transaction:
916-
common_data["cache_creation_input_tokens"] = (
917-
transaction.get("cache_creation_input_tokens", 0)
918-
)
944+
common_data[
945+
"cache_creation_input_tokens"
946+
] = transaction.get("cache_creation_input_tokens", 0)
919947

920948
# Create update data structure
921949
update_data = {
@@ -976,7 +1004,14 @@ async def _update_daily_spend(
9761004
start_time=start_time,
9771005
proxy_logging_obj=proxy_logging_obj,
9781006
)
979-
await asyncio.sleep(2**i)
1007+
await asyncio.sleep(
1008+
# Sleep a random amount to avoid retrying and deadlocking again: when two transactions deadlock they are
1009+
# cancelled basically at the same time, so if they wait the same time they will also retry at the same time
1010+
# and thus they are more likely to deadlock again.
1011+
# Instead, we sleep a random amount so that they retry at slightly different times, lowering the chance of
1012+
# repeated deadlocks, and therefore of exceeding the retry limit.
1013+
random.uniform(2**i, 2 ** (i + 1))
1014+
)
9801015

9811016
except Exception as e:
9821017
if "transactions_to_process" in locals():

tests/proxy_unit_tests/test_update_spend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ async def mock_sleep(seconds):
198198

199199
# Verify exponential backoff
200200
assert len(sleep_times) == 2 # Should have slept twice
201-
assert sleep_times[0] == 1 # First retry after 2^0 seconds
202-
assert sleep_times[1] == 2 # Second retry after 2^1 seconds
201+
assert sleep_times[0] >= 1 and sleep_times[0] <= 2 # First retry after 2^0~2^1 seconds
202+
assert sleep_times[1] >= 2 and sleep_times[1] <= 4 # Second retry after 2^1~2^2 seconds
203203

204204

205205
@pytest.mark.asyncio

tests/test_litellm/proxy/db/test_db_spend_update_writer.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
from datetime import datetime
11-
from unittest.mock import AsyncMock, MagicMock, patch
11+
from unittest.mock import AsyncMock, MagicMock, patch, call
1212

1313
import pytest
1414

@@ -135,3 +135,89 @@ async def test_update_daily_spend_with_null_entity_id():
135135
assert create_data["api_requests"] == 1
136136
assert create_data["successful_requests"] == 1
137137
assert create_data["failed_requests"] == 0
138+
139+
140+
@pytest.mark.asyncio
141+
async def test_update_daily_spend_sorting():
142+
"""
143+
Test that table.upsert is called with events sorted
144+
145+
Ensures that writes are sorted between transactions to minimize deadlocks
146+
"""
147+
# Setup
148+
mock_prisma_client = MagicMock()
149+
mock_batcher = MagicMock()
150+
mock_table = MagicMock()
151+
mock_prisma_client.db.batch_.return_value.__aenter__.return_value = mock_batcher
152+
mock_batcher.litellm_dailyuserspend = mock_table
153+
154+
# Create 50 transactions with out-of-order entity_ids
155+
# In reality we sort using multiple fields, but entity_id is sufficient to test sorting
156+
daily_spend_transactions = {}
157+
upsert_calls = []
158+
for i in range(50):
159+
daily_spend_transactions[f"test_key_{i}"] = {
160+
"user_id": f"user{60-i}", # user60 ... user11, reverse order
161+
"date": "2024-01-01",
162+
"api_key": "test-api-key",
163+
"model": "gpt-4",
164+
"custom_llm_provider": "openai",
165+
"prompt_tokens": 10,
166+
"completion_tokens": 20,
167+
"spend": 0.1,
168+
"api_requests": 1,
169+
"successful_requests": 1,
170+
"failed_requests": 0,
171+
}
172+
upsert_calls.append(call(
173+
where={
174+
"user_id_date_api_key_model_custom_llm_provider": {
175+
"user_id": f"user{i+11}", # user11 ... user60, sorted order
176+
"date": "2024-01-01",
177+
"api_key": "test-api-key",
178+
"model": "gpt-4",
179+
"custom_llm_provider": "openai",
180+
"mcp_namespaced_tool_name": "",
181+
}
182+
},
183+
data={
184+
"create": {
185+
"user_id": f"user{i+11}",
186+
"date": "2024-01-01",
187+
"api_key": "test-api-key",
188+
"model": "gpt-4",
189+
"model_group": None,
190+
"mcp_namespaced_tool_name": "",
191+
"custom_llm_provider": "openai",
192+
"prompt_tokens": 10,
193+
"completion_tokens": 20,
194+
"spend": 0.1,
195+
"api_requests": 1,
196+
"successful_requests": 1,
197+
"failed_requests": 0,
198+
},
199+
"update": {
200+
"prompt_tokens": {"increment": 10},
201+
"completion_tokens": {"increment": 20},
202+
"spend": {"increment": 0.1},
203+
"api_requests": {"increment": 1},
204+
"successful_requests": {"increment": 1},
205+
"failed_requests": {"increment": 0},
206+
},
207+
},
208+
))
209+
210+
# Call the method
211+
await DBSpendUpdateWriter._update_daily_spend(
212+
n_retry_times=1,
213+
prisma_client=mock_prisma_client,
214+
proxy_logging_obj=MagicMock(),
215+
daily_spend_transactions=daily_spend_transactions,
216+
entity_type="user",
217+
entity_id_field="user_id",
218+
table_name="litellm_dailyuserspend",
219+
unique_constraint_name="user_id_date_api_key_model_custom_llm_provider",
220+
)
221+
222+
# Verify that table.upsert was called
223+
mock_table.upsert.assert_has_calls(upsert_calls)

0 commit comments

Comments
 (0)