@@ -10,6 +10,7 @@
 import os
 import time
 import traceback
+import random
 from datetime import datetime, timedelta
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast, overload
 
@@ -849,8 +850,23 @@ async def _update_daily_spend(
     try:
         for i in range(n_retry_times + 1):
             try:
+                # Sort the transactions to ensure consistent processing order and minimize the probability of deadlocks
                 transactions_to_process = dict(
-                    list(daily_spend_transactions.items())[:BATCH_SIZE]
+                    sorted(
+                        daily_spend_transactions.items(),
+                        # Normally to avoid deadlocks we would sort by the index, but since we have sprinkled indexes
+                        # on our schema like we're discount Salt Bae, we just sort by all fields that have an index,
+                        # in an ad-hoc (but hopefully sensible) order of indexes.
+                        # If _update_daily_spend ever gets the ability to write to multiple tables at once, the sorting
+                        # should sort by the table first.
+                        key=lambda x: (
+                            x[1]["date"],
+                            x[1][entity_id_field],
+                            x[1]["api_key"],
+                            x[1]["model"],
+                            x[1]["custom_llm_provider"],
+                        ),
+                    )[:BATCH_SIZE]
                 )
 
                 if len(transactions_to_process) == 0:
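Why the sort matters, in isolation: two workers that upsert overlapping rows in different orders can each hold a row lock the other needs, and the database aborts one of them. Below is a minimal, self-contained sketch of the idea, not the project's code: the composite key mirrors the lambda above, "user_id" stands in for entity_id_field, and the toy transactions are hypothetical.

import random

# Hypothetical transactions shaped like daily_spend_transactions entries.
transactions = {
    "t1": {"date": "2024-01-02", "user_id": "u2", "api_key": "k1",
           "model": "gpt-4", "custom_llm_provider": "openai"},
    "t2": {"date": "2024-01-01", "user_id": "u1", "api_key": "k2",
           "model": "claude-3", "custom_llm_provider": "anthropic"},
}

def batch_in_lock_order(txns, entity_id_field, batch_size):
    # Workers that sort by the same composite key acquire row locks in the
    # same global order, so their lock waits cannot form a cycle (deadlock).
    return dict(
        sorted(
            txns.items(),
            key=lambda kv: (
                kv[1]["date"],
                kv[1][entity_id_field],
                kv[1]["api_key"],
                kv[1]["model"],
                kv[1]["custom_llm_provider"],
            ),
        )[:batch_size]
    )

# Two workers seeing the same data in different dict orders still process
# rows identically:
shuffled = dict(random.sample(list(transactions.items()), len(transactions)))
a = batch_in_lock_order(transactions, "user_id", 10)
b = batch_in_lock_order(shuffled, "user_id", 10)
assert list(a) == list(b)  # same keys, same order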
@@ -893,7 +909,8 @@ async def _update_daily_spend(
                     "model_group": transaction.get("model_group"),
                     "mcp_namespaced_tool_name": transaction.get(
                         "mcp_namespaced_tool_name"
-                    ) or "",
+                    )
+                    or "",
                     "custom_llm_provider": transaction.get(
                         "custom_llm_provider"
                     ),
@@ -909,13 +926,13 @@ async def _update_daily_spend(
 
                 # Add cache-related fields if they exist
                 if "cache_read_input_tokens" in transaction:
-                    common_data["cache_read_input_tokens"] = (
-                        transaction.get("cache_read_input_tokens", 0)
-                    )
+                    common_data[
+                        "cache_read_input_tokens"
+                    ] = transaction.get("cache_read_input_tokens", 0)
                 if "cache_creation_input_tokens" in transaction:
-                    common_data["cache_creation_input_tokens"] = (
-                        transaction.get("cache_creation_input_tokens", 0)
-                    )
+                    common_data[
+                        "cache_creation_input_tokens"
+                    ] = transaction.get("cache_creation_input_tokens", 0)
 
                 # Create update data structure
                 update_data = {
@@ -976,7 +993,9 @@ async def _update_daily_spend(
                         start_time=start_time,
                         proxy_logging_obj=proxy_logging_obj,
                     )
-                    await asyncio.sleep(2**i)
+                    await asyncio.sleep(
+                        random.uniform(2**i, 2 ** (i + 1))
+                    )  # Sleep a random amount to avoid retrying and deadlocking again
 
             except Exception as e:
                 if "transactions_to_process" in locals():
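The backoff change in the last hunk adds jitter to the existing exponential retry: each attempt sleeps a uniform draw from [2**i, 2**(i + 1)) so workers that just deadlocked against each other do not wake at the same instant and collide again. A standalone sketch of the pattern, with a hypothetical op callable standing in for the batch write:

import asyncio
import random

async def retry_with_jittered_backoff(op, n_retry_times: int):
    """Run op(), retrying on failure with exponential backoff plus jitter."""
    for i in range(n_retry_times + 1):
        try:
            return await op()
        except Exception:
            if i == n_retry_times:
                raise  # out of retries; surface the last error
            # Sleep somewhere in [2**i, 2**(i + 1)): exponential growth keeps
            # pressure off the database, jitter de-synchronizes the retriers.
            await asyncio.sleep(random.uniform(2**i, 2 ** (i + 1)))

# Usage (op is any zero-argument coroutine function):
# asyncio.run(retry_with_jittered_backoff(write_batch, n_retry_times=3))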