diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py
index 9526c4a2f393..fcbc52ad34fa 100644
--- a/litellm/caching/caching_handler.py
+++ b/litellm/caching/caching_handler.py
@@ -534,15 +534,11 @@ def _async_log_cache_hit_on_callbacks(
         from litellm.litellm_core_utils.logging_worker import GLOBAL_LOGGING_WORKER
 
         GLOBAL_LOGGING_WORKER.ensure_initialized_and_enqueue(
-            async_coroutine=logging_obj.async_success_handler(
+            async_coroutine=logging_obj.unified_success_handler(
                 result=cached_result, start_time=start_time, end_time=end_time, cache_hit=cache_hit
             )
         )
-        logging_obj.handle_sync_success_callbacks_for_async_calls(
-            result=cached_result, start_time=start_time, end_time=end_time, cache_hit=cache_hit
-        )
-
     async def _retrieve_from_cache(
         self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
     ) -> Optional[Any]:
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 19d7c5512ba6..afb902816d44 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -1,6 +1,7 @@
 # What is this?
 ## Common Utility file for Logging handler
 # Logging function -> log the exact model details + what's being sent | Non-Blocking
+import asyncio
 import copy
 import datetime
 import json
@@ -11,7 +12,7 @@
 import time
 import traceback
 from datetime import datetime as dt_object
-from functools import lru_cache
+from functools import lru_cache, partial
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -241,6 +242,11 @@ def set_cache(self, litellm_call_id: str, service_name: str, trace_id: str) -> N
 in_memory_dynamic_logger_cache = DynamicLoggingCache()
 
 
+AssembledStreamingResponse = Optional[
+    Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
+]
+
+
 class Logging(LiteLLMLoggingBaseClass):
     global supabaseClient, promptLayerLogger, weightsBiasesLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger, logfireLogger, prometheusLogger, slack_app
     custom_pricing: bool = False
@@ -1251,24 +1257,6 @@ def _response_cost_calculator(
 
         return None
 
-    async def _response_cost_calculator_async(
-        self,
-        result: Union[
-            ModelResponse,
-            ModelResponseStream,
-            EmbeddingResponse,
-            ImageResponse,
-            TranscriptionResponse,
-            TextCompletionResponse,
-            HttpxBinaryResponseContent,
-            RerankResponse,
-            Batch,
-            FineTuningJob,
-        ],
-        cache_hit: Optional[bool] = None,
-    ) -> Optional[float]:
-        return self._response_cost_calculator(result=result, cache_hit=cache_hit)
-
     def should_run_logging(
         self,
         event_type: Literal[
@@ -1383,6 +1371,7 @@ def _success_handler_helper_fn(
         end_time=None,
         cache_hit=None,
         standard_logging_object: Optional[StandardLoggingPayload] = None,
+        complete_streaming_response: AssembledStreamingResponse = None,
     ):
         try:
             if start_time is None:
@@ -1411,74 +1400,63 @@ def _success_handler_helper_fn(
 
             logging_result = self.normalize_logging_result(result=result)
 
-            if (
+            should_do_full_processing = (
                 standard_logging_object is None
                 and result is not None
                 and self.stream is not True
-            ):
-                if self._is_recognized_call_type_for_logging(
-                    logging_result=logging_result
-                ):
-                    ## HIDDEN PARAMS ##
-                    hidden_params = getattr(logging_result, "_hidden_params", {})
-                    if hidden_params:
-                        # add to metadata for logging
-                        if self.model_call_details.get("litellm_params") is not None:
-                            self.model_call_details["litellm_params"].setdefault(
-                                "metadata", {}
-                            )
-                            if (
-                                self.model_call_details["litellm_params"]["metadata"]
-                                is None
-                            ):
self.model_call_details["litellm_params"][ - "metadata" - ] = {} - - self.model_call_details["litellm_params"]["metadata"][ # type: ignore - "hidden_params" - ] = getattr( - logging_result, "_hidden_params", {} - ) - ## RESPONSE COST - Only calculate if not in hidden_params ## - if "response_cost" in hidden_params: - self.model_call_details["response_cost"] = hidden_params[ - "response_cost" - ] - else: - self.model_call_details["response_cost"] = ( - self._response_cost_calculator(result=logging_result) - ) - ## STANDARDIZED LOGGING PAYLOAD + and self._is_recognized_call_type_for_logging(logging_result=logging_result) + ) - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=logging_result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) - ) - elif isinstance(result, dict) or isinstance(result, list): - ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, + should_create_logging_object = ( + standard_logging_object is None + and result is not None + and (self.stream is not True or complete_streaming_response is not None) + ) + + # Execute full processing if needed + if should_do_full_processing: + ## HIDDEN PARAMS ## + hidden_params = getattr(logging_result, "_hidden_params", {}) + if hidden_params: + # add to metadata for logging + if self.model_call_details.get("litellm_params") is not None: + self.model_call_details["litellm_params"].setdefault("metadata", {}) + if self.model_call_details["litellm_params"]["metadata"] is None: + self.model_call_details["litellm_params"]["metadata"] = {} + + self.model_call_details["litellm_params"]["metadata"]["hidden_params"] = getattr( + logging_result, "_hidden_params", {} ) - ) - elif standard_logging_object is not None: - self.model_call_details["standard_logging_object"] = ( - standard_logging_object + + ## RESPONSE COST - Only calculate if not in hidden_params ## + if "response_cost" in hidden_params: + self.model_call_details["response_cost"] = hidden_params["response_cost"] + else: + self.model_call_details["response_cost"] = self._response_cost_calculator(result=logging_result) + + # Handle standard_logging_object creation + if should_create_logging_object: + # Determine init_response_obj based on processing type + if should_do_full_processing: + init_response_obj = logging_result + elif complete_streaming_response is not None: + init_response_obj = complete_streaming_response + else: + init_response_obj = result + + + ## STANDARDIZED LOGGING PAYLOAD + self.model_call_details["standard_logging_object"] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=init_response_obj, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) + elif standard_logging_object is not None: + self.model_call_details["standard_logging_object"] = standard_logging_object else: # streaming chunks + image gen. 
self.model_call_details["response_cost"] = None @@ -1510,7 +1488,10 @@ def _success_handler_helper_fn( total_time=float_diff, standard_built_in_tools_params=self.standard_built_in_tools_params, ) - + result = redact_message_input_output_from_logging( + model_call_details=(self.model_call_details if hasattr(self, "model_call_details") else {}), + result=result, + ) return start_time, end_time, result except Exception as e: raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}") @@ -1523,20 +1504,25 @@ def _is_recognized_call_type_for_logging( Returns True if the call type is recognized for logging (eg. ModelResponse, ModelResponseStream, etc.) """ if ( - isinstance(logging_result, ModelResponse) - or isinstance(logging_result, ModelResponseStream) - or isinstance(logging_result, EmbeddingResponse) - or isinstance(logging_result, ImageResponse) - or isinstance(logging_result, TranscriptionResponse) - or isinstance(logging_result, TextCompletionResponse) - or isinstance(logging_result, HttpxBinaryResponseContent) # tts - or isinstance(logging_result, RerankResponse) - or isinstance(logging_result, FineTuningJob) - or isinstance(logging_result, LiteLLMBatch) - or isinstance(logging_result, ResponsesAPIResponse) - or isinstance(logging_result, OpenAIFileObject) - or isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject) - or isinstance(logging_result, OpenAIModerationResponse) + isinstance( + logging_result, + ( + ModelResponse, + ModelResponseStream, + EmbeddingResponse, + ImageResponse, + TranscriptionResponse, + TextCompletionResponse, + HttpxBinaryResponseContent, # tts + RerankResponse, + FineTuningJob, + LiteLLMBatch, + ResponsesAPIResponse, + OpenAIFileObject, + LiteLLMRealtimeStreamLoggingObject, + OpenAIModerationResponse, + ), + ) or (self.call_type == CallTypes.call_mcp_tool.value) ): return True @@ -1596,7 +1582,15 @@ async def async_flush_passthrough_collected_chunks( return def success_handler( # noqa: PLR0915 - self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs + self, + result=None, + start_time=None, + end_time=None, + cache_hit=None, + *, + unified_flow: bool = False, + complete_streaming_response: AssembledStreamingResponse = None, + **kwargs, ): verbose_logger.debug( f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}" @@ -1605,27 +1599,20 @@ def success_handler( # noqa: PLR0915 event_type="sync_success" ): # prevent double logging return - start_time, end_time, result = self._success_handler_helper_fn( - start_time=start_time, - end_time=end_time, - result=result, - cache_hit=cache_hit, - standard_logging_object=kwargs.get("standard_logging_object", None), - ) - try: - ## BUILD COMPLETE STREAMED RESPONSE - complete_streaming_response: Optional[ - Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse] - ] = None - if "complete_streaming_response" in self.model_call_details: - return # break out of this. 
-            complete_streaming_response = self._get_assembled_streaming_response(
-                result=result,
+
+        if not unified_flow:  # see unified_success_handler
+            start_time, end_time, result = self._success_handler_helper_fn(
                 start_time=start_time,
                 end_time=end_time,
-                is_async=False,
-                streaming_chunks=self.sync_streaming_chunks,
+                result=result,
+                cache_hit=cache_hit,
+                standard_logging_object=kwargs.get("standard_logging_object", None),
             )
+        try:
+            if not unified_flow:  # see unified_success_handler
+                complete_streaming_response = self._get_assembled_streaming_response(
+                    result
+                )
             if complete_streaming_response is not None:
                 verbose_logger.debug(
                     "Logging Details LiteLLM-Success Call streaming complete"
                 )
                 self.model_call_details["complete_streaming_response"] = (
                     complete_streaming_response
                 )
-                self.model_call_details["response_cost"] = (
-                    self._response_cost_calculator(result=complete_streaming_response)
-                )
-                ## STANDARDIZED LOGGING PAYLOAD
-                self.model_call_details["standard_logging_object"] = (
-                    get_standard_logging_object_payload(
-                        kwargs=self.model_call_details,
-                        init_response_obj=complete_streaming_response,
-                        start_time=start_time,
-                        end_time=end_time,
-                        logging_obj=self,
-                        status="success",
-                        standard_built_in_tools_params=self.standard_built_in_tools_params,
-                    )
-                )
             callbacks = self.get_combined_callback_list(
                 dynamic_success_callbacks=self.dynamic_success_callbacks,
                 global_callbacks=litellm.success_callback,
             )
-            ## REDACT MESSAGES ##
-            result = redact_message_input_output_from_logging(
-                model_call_details=(
-                    self.model_call_details
-                    if hasattr(self, "model_call_details")
-                    else {}
-                ),
-                result=result,
-            )
             ## LOGGING HOOK ##
 
             for callback in callbacks:
                 if isinstance(callback, CustomLogger):
@@ -2076,8 +2039,48 @@ def success_handler(  # noqa: PLR0915
                     ),
                 )
 
+    async def unified_success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None):
+        complete_streaming_response = self._get_assembled_streaming_response(result)
+
+        # don't let _success_handler_helper_fn block the loop
+        start_time, end_time, result = await asyncio.get_event_loop().run_in_executor(
+            executor,
+            partial(
+                self._success_handler_helper_fn,
+                start_time=start_time,
+                end_time=end_time,
+                result=result,
+                cache_hit=cache_hit,
+                complete_streaming_response=complete_streaming_response,
+            ),
+        )
+        self.handle_sync_success_callbacks_for_async_calls(
+            start_time=start_time,
+            end_time=end_time,
+            result=result,
+            cache_hit=cache_hit,
+            unified_flow=True,
+            complete_streaming_response=complete_streaming_response,
+        )
+        await self.async_success_handler(
+            start_time=start_time,
+            end_time=end_time,
+            result=result,
+            cache_hit=cache_hit,
+            unified_flow=True,
+            complete_streaming_response=complete_streaming_response,
+        )
+
     async def async_success_handler(  # noqa: PLR0915
-        self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
+        self,
+        result=None,
+        start_time=None,
+        end_time=None,
+        cache_hit=None,
+        *,
+        unified_flow: bool = False,
+        complete_streaming_response: AssembledStreamingResponse = None,
+        **kwargs,
     ):
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
@@ -2127,26 +2130,28 @@ async def async_success_handler(  # noqa: PLR0915
                 result._hidden_params["batch_models"] = batch_models
                 result.usage = batch_usage
 
-        start_time, end_time, result = self._success_handler_helper_fn(
-            start_time=start_time,
-            end_time=end_time,
-            result=result,
-            cache_hit=cache_hit,
-            standard_logging_object=kwargs.get("standard_logging_object", None),
-        )
+        if not unified_flow:  # see unified_success_handler
+            # don't let _success_handler_helper_fn block the loop
+            start_time, end_time, result = await asyncio.get_event_loop().run_in_executor(
+                executor,
+                partial(
+                    self._success_handler_helper_fn,
+                    start_time=start_time,
+                    end_time=end_time,
+                    result=result,
+                    cache_hit=cache_hit,
+                    standard_logging_object=kwargs.get("standard_logging_object", None),
+                ),
+            )
 
         ## BUILD COMPLETE STREAMED RESPONSE
         if "async_complete_streaming_response" in self.model_call_details:
             return  # break out of this.
-        complete_streaming_response: Optional[
-            Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
-        ] = self._get_assembled_streaming_response(
-            result=result,
-            start_time=start_time,
-            end_time=end_time,
-            is_async=True,
-            streaming_chunks=self.streaming_chunks,
-        )
+
+        if not unified_flow:  # see unified_success_handler
+            complete_streaming_response: Optional[
+                Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
+            ] = self._get_assembled_streaming_response(result)
 
         if complete_streaming_response is not None:
             print_verbose("Async success callbacks: Got a complete streaming response")
@@ -2159,10 +2164,6 @@ async def async_success_handler(  # noqa: PLR0915
             if self.model_call_details.get("cache_hit", False) is True:
                 self.model_call_details["response_cost"] = 0.0
             else:
-                # check if base_model set on azure
-                _get_base_model_from_metadata(
-                    model_call_details=self.model_call_details
-                )
                 # base_model defaults to None if not set on model_info
                 self.model_call_details["response_cost"] = (
                     self._response_cost_calculator(
@@ -2179,30 +2180,11 @@ async def async_success_handler(  # noqa: PLR0915
             )
             self.model_call_details["response_cost"] = None
 
-        ## STANDARDIZED LOGGING PAYLOAD
-        self.model_call_details["standard_logging_object"] = (
-            get_standard_logging_object_payload(
-                kwargs=self.model_call_details,
-                init_response_obj=complete_streaming_response,
-                start_time=start_time,
-                end_time=end_time,
-                logging_obj=self,
-                status="success",
-                standard_built_in_tools_params=self.standard_built_in_tools_params,
-            )
-        )
         callbacks = self.get_combined_callback_list(
             dynamic_success_callbacks=self.dynamic_async_success_callbacks,
             global_callbacks=litellm._async_success_callback,
         )
-        result = redact_message_input_output_from_logging(
-            model_call_details=(
-                self.model_call_details if hasattr(self, "model_call_details") else {}
-            ),
-            result=result,
-        )
-
         ## LOGGING HOOK ##
 
         for callback in callbacks:
@@ -2778,6 +2760,9 @@ def handle_sync_success_callbacks_for_async_calls(
         start_time: datetime.datetime,
         end_time: datetime.datetime,
         cache_hit: Optional[Any] = None,
+        *,
+        unified_flow: bool = False,
+        complete_streaming_response: AssembledStreamingResponse = None,
     ) -> None:
         """
         Handles calling success callbacks for Async calls.
@@ -2793,6 +2778,8 @@ def handle_sync_success_callbacks_for_async_calls(
                 start_time,
                 end_time,
                 cache_hit,
+                unified_flow=unified_flow,
+                complete_streaming_response=complete_streaming_response,
             )
 
     def _should_run_sync_callbacks_for_async_calls(self) -> bool:
@@ -2893,19 +2880,11 @@ def _get_assembled_streaming_response(
             ResponseCompletedEvent,
             Any,
         ],
-        start_time: datetime.datetime,
-        end_time: datetime.datetime,
-        is_async: bool,
-        streaming_chunks: List[Any],
-    ) -> Optional[Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]]:
-        if isinstance(result, ModelResponse):
-            return result
-        elif isinstance(result, TextCompletionResponse):
+    ) -> AssembledStreamingResponse:
+        if isinstance(result, (ModelResponse, TextCompletionResponse)):
             return result
-        elif isinstance(result, ResponseCompletedEvent):
+        if isinstance(result, ResponseCompletedEvent):
             return result.response
-        else:
-            return None
         return None
 
     def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse:
diff --git a/tests/test_litellm/test_logging.py b/tests/test_litellm/test_logging.py
index 7e5931d8c0f9..d1b52ed35659 100644
--- a/tests/test_litellm/test_logging.py
+++ b/tests/test_litellm/test_logging.py
@@ -30,6 +30,7 @@
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.types.utils import StandardLoggingPayload
+from litellm.litellm_core_utils.litellm_logging import Logging
 
 
 class CacheHitCustomLogger(CustomLogger):
@@ -152,3 +153,36 @@ async def test_cache_hit_includes_custom_llm_provider():
     # Clean up
     litellm.callbacks = original_callbacks
     litellm.cache = None
+
+
+@pytest.mark.asyncio
+async def test_unified_handler_calls_get_standard_logging_object_payload_once(mocker, monkeypatch):
+    """
+    Tests that for a cache hit,
+    get_standard_logging_object_payload is called exactly once.
+    """
+    monkeypatch.setattr(litellm, "cache", litellm.Cache())
+    test_message = [{"role": "user", "content": "helper function call test"}]
+
+    mock_helper_fn = mocker.spy(Logging, "_success_handler_helper_fn")
+    mock_get_standard_logging_object = mocker.spy(
+        litellm.litellm_core_utils.litellm_logging, "get_standard_logging_object_payload"
+    )
+    # First call (miss) - this will call the helper function
+    await litellm.acompletion(model="gpt-3.5-turbo", messages=test_message, mock_response="r2", caching=True)
+    await asyncio.sleep(0.1)
+    mock_helper_fn.reset_mock()
+    mock_get_standard_logging_object.reset_mock()
+
+    # Second call (hit) - this is the call we are testing
+    await litellm.acompletion(model="gpt-3.5-turbo", messages=test_message, mock_response="r2", caching=True)
+    await asyncio.sleep(0.1)  # allow logs to process
+
+    # Assert the helper function was called exactly once during the cache hit
+    mock_helper_fn.assert_called_once()
+    assert mock_helper_fn.call_args.kwargs["cache_hit"] is True
+    assert mock_helper_fn.call_args.kwargs["result"].choices[0].message.content == "r2"
+
+    mock_get_standard_logging_object.assert_called_once()
+    assert mock_get_standard_logging_object.call_args.kwargs["status"] == "success"
+    assert mock_get_standard_logging_object.call_args.kwargs["init_response_obj"].choices[0].message.content == "r2"
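Reviewer note: the core pattern of this refactor is "assemble the streaming response and logging payload once, off the event loop, then fan out to the sync and async callback paths with unified_flow=True so neither rebuilds it." A minimal, self-contained sketch of that pattern follows; every name in it is a hypothetical stand-in for illustration, not litellm's actual API.

# pattern_sketch.py - illustrates the unified fan-out; stand-in names only
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

executor = ThreadPoolExecutor(max_workers=1)

def build_payload_once(result, cache_hit):
    # stand-in for _success_handler_helper_fn: synchronous payload assembly
    # that should not run directly on the event loop
    return {"result": result, "cache_hit": cache_hit}

def run_sync_callbacks(payload):
    # stand-in for handle_sync_success_callbacks_for_async_calls
    print("sync callback:", payload)

async def run_async_callbacks(payload):
    # stand-in for async_success_handler's callback loop
    print("async callback:", payload)

async def unified_success_handler(result, cache_hit):
    # build the payload exactly once, in the executor, keeping the loop free
    payload = await asyncio.get_event_loop().run_in_executor(
        executor, partial(build_payload_once, result, cache_hit)
    )
    # both callback paths reuse the same payload; in the real code,
    # unified_flow=True tells each handler to skip rebuilding it
    run_sync_callbacks(payload)
    await run_async_callbacks(payload)

asyncio.run(unified_success_handler("r2", cache_hit=True))

Running the sketch prints one sync and one async callback line for a single payload build, which is the invariant the new test pins down: on a cache hit, _success_handler_helper_fn and get_standard_logging_object_payload each run exactly once.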