From d390e93f071a56bae1bc65f48aa785cee0060d84 Mon Sep 17 00:00:00 2001 From: Aliaksandr Kuzmik <98702584+alexkuzmik@users.noreply.github.com> Date: Thu, 23 Jan 2025 11:44:48 +0100 Subject: [PATCH] [OPIK-750] sdk add the possibility to manually sets the cost of individual spans (#1107) * Add total_cost to low-level API * Pass total_cost to requests payload as estimated_total_cost * Add total_cost to update_current_span * Add e2e test for total cost logging * Fix lint errors * Update unit tests and backend emulator * Update old test * Disable sentry if pytest env is detected. Relax configurator test * Update the docstrings to mention price units * Add missing docstring --- sdks/python/src/opik/__init__.py | 3 +- .../src/opik/api_objects/opik_client.py | 3 ++ sdks/python/src/opik/api_objects/span.py | 11 ++++++++ sdks/python/src/opik/api_objects/trace.py | 3 ++ .../message_processing/message_processors.py | 25 ++--------------- .../src/opik/message_processing/messages.py | 14 ++++++++++ sdks/python/src/opik/opik_context.py | 3 ++ sdks/python/tests/e2e/test_tracing.py | 28 +++++++++++++++++++ sdks/python/tests/e2e/verifiers.py | 8 ++++-- .../backend_emulator_message_processor.py | 2 ++ sdks/python/tests/testlib/models.py | 1 + .../tests/unit/configurator/test_configure.py | 3 +- .../unit/decorator/test_tracker_outputs.py | 5 +++- .../batching/test_batch_manager.py | 1 + .../test_message_streaming.py | 1 + 15 files changed, 83 insertions(+), 28 deletions(-) diff --git a/sdks/python/src/opik/__init__.py b/sdks/python/src/opik/__init__.py index 8582395f42..e1f98b720f 100644 --- a/sdks/python/src/opik/__init__.py +++ b/sdks/python/src/opik/__init__.py @@ -1,4 +1,4 @@ -from . import _logging, error_tracking, package_version +from . import _logging, error_tracking, package_version, environment from .api_objects.dataset import Dataset from .api_objects.experiment.experiment_item import ( ExperimentItemContent, @@ -40,6 +40,7 @@ if ( error_tracking.enabled_in_config() + and not environment.in_pytest() and error_tracking.randomized_should_enable_reporting() ): error_tracking.setup_sentry_error_tracker() diff --git a/sdks/python/src/opik/api_objects/opik_client.py b/sdks/python/src/opik/api_objects/opik_client.py index df0876bccb..f51169cdd8 100644 --- a/sdks/python/src/opik/api_objects/opik_client.py +++ b/sdks/python/src/opik/api_objects/opik_client.py @@ -226,6 +226,7 @@ def span( model: Optional[str] = None, provider: Optional[str] = None, error_info: Optional[ErrorInfoDict] = None, + total_cost: Optional[float] = None, ) -> span.Span: """ Create and log a new span. @@ -249,6 +250,7 @@ def span( model: The name of LLM (in this case `type` parameter should be == `llm`) provider: The provider of LLM. error_info: The dictionary with error information (typically used when the span function has failed). + total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage. Returns: span.Span: The created span object. @@ -304,6 +306,7 @@ def span( model=model, provider=provider, error_info=error_info, + total_cost=total_cost, ) self._streamer.put(create_span_message) diff --git a/sdks/python/src/opik/api_objects/span.py b/sdks/python/src/opik/api_objects/span.py index 60f2351f14..a136a74dc9 100644 --- a/sdks/python/src/opik/api_objects/span.py +++ b/sdks/python/src/opik/api_objects/span.py @@ -46,6 +46,7 @@ def end( tags: Optional[List[str]] = None, usage: Optional[UsageDict] = None, error_info: Optional[ErrorInfoDict] = None, + total_cost: Optional[float] = None, ) -> None: """ End the span and update its attributes. @@ -61,6 +62,7 @@ def end( tags: A list of tags to be associated with the span. usage: Usage information for the span. error_info: The dictionary with error information (typically used when the span function has failed). + total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage. Returns: None @@ -77,6 +79,7 @@ def end( tags=tags, usage=usage, error_info=error_info, + total_cost=total_cost, ) def update( @@ -90,6 +93,7 @@ def update( model: Optional[str] = None, provider: Optional[str] = None, error_info: Optional[ErrorInfoDict] = None, + total_cost: Optional[float] = None, ) -> None: """ Update the span attributes. @@ -104,6 +108,7 @@ def update( model: The name of LLM. provider: The provider of LLM. error_info: The dictionary with error information (typically used when the span function has failed). + total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage. Returns: None @@ -130,6 +135,7 @@ def update( model=model, provider=provider, error_info=error_info, + total_cost=total_cost, ) self._streamer.put(end_span_message) @@ -148,6 +154,7 @@ def span( model: Optional[str] = None, provider: Optional[str] = None, error_info: Optional[ErrorInfoDict] = None, + total_cost: Optional[float] = None, ) -> "Span": """ Create a new child span within the current span. @@ -165,6 +172,8 @@ def span( usage: Usage information for the span. model: The name of LLM (in this case `type` parameter should be == `llm`) provider: The provider of LLM. + error_info: The dictionary with error information (typically used when the span function has failed). + total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage. Returns: Span: The created child span object. @@ -198,6 +207,7 @@ def span( model=model, provider=provider, error_info=error_info, + total_cost=total_cost, ) self._streamer.put(create_span_message) @@ -284,6 +294,7 @@ class SpanData: model: Optional[str] = None provider: Optional[str] = None error_info: Optional[ErrorInfoDict] = None + total_cost: Optional[float] = None def update(self, **new_data: Any) -> "SpanData": for key, value in new_data.items(): diff --git a/sdks/python/src/opik/api_objects/trace.py b/sdks/python/src/opik/api_objects/trace.py index 6205712b6d..a01e637b8c 100644 --- a/sdks/python/src/opik/api_objects/trace.py +++ b/sdks/python/src/opik/api_objects/trace.py @@ -116,6 +116,7 @@ def span( model: Optional[str] = None, provider: Optional[str] = None, error_info: Optional[ErrorInfoDict] = None, + total_cost: Optional[float] = None, ) -> span.Span: """ Create a new span within the trace. @@ -135,6 +136,7 @@ def span( model: The name of LLM (in this case `type` parameter should be == `llm`) provider: The provider of LLM. error_info: The dictionary with error information (typically used when the span function has failed). + total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage. Returns: span.Span: The created span object. @@ -168,6 +170,7 @@ def span( model=model, provider=provider, error_info=error_info, + total_cost=total_cost, ) self._streamer.put(create_span_message) diff --git a/sdks/python/src/opik/message_processing/message_processors.py b/sdks/python/src/opik/message_processing/message_processors.py index 4346046a60..7c1c7bd135 100644 --- a/sdks/python/src/opik/message_processing/message_processors.py +++ b/sdks/python/src/opik/message_processing/message_processors.py @@ -88,20 +88,7 @@ def _process_create_trace_message( self._rest_client.traces.create_trace(**cleaned_create_trace_kwargs) def _process_update_span_message(self, message: messages.UpdateSpanMessage) -> None: - update_span_kwargs = { - "id": message.span_id, - "parent_span_id": message.parent_span_id, - "project_name": message.project_name, - "trace_id": message.trace_id, - "end_time": message.end_time, - "input": message.input, - "output": message.output, - "metadata": message.metadata, - "tags": message.tags, - "usage": message.usage, - "model": message.model, - "provider": message.provider, - } + update_span_kwargs = message.as_payload_dict() cleaned_update_span_kwargs = dict_utils.remove_none_from_dict( update_span_kwargs @@ -113,15 +100,7 @@ def _process_update_span_message(self, message: messages.UpdateSpanMessage) -> N def _process_update_trace_message( self, message: messages.UpdateTraceMessage ) -> None: - update_trace_kwargs = { - "id": message.trace_id, - "project_name": message.project_name, - "end_time": message.end_time, - "input": message.input, - "output": message.output, - "metadata": message.metadata, - "tags": message.tags, - } + update_trace_kwargs = message.as_payload_dict() cleaned_update_trace_kwargs = dict_utils.remove_none_from_dict( update_trace_kwargs diff --git a/sdks/python/src/opik/message_processing/messages.py b/sdks/python/src/opik/message_processing/messages.py index d8072a6c58..10efceb0f3 100644 --- a/sdks/python/src/opik/message_processing/messages.py +++ b/sdks/python/src/opik/message_processing/messages.py @@ -46,6 +46,11 @@ class UpdateTraceMessage(BaseMessage): tags: Optional[List[str]] error_info: Optional[ErrorInfoDict] + def as_payload_dict(self) -> Dict[str, Any]: + data = super().as_payload_dict() + data["id"] = data.pop("trace_id") + return data + @dataclasses.dataclass class CreateSpanMessage(BaseMessage): @@ -65,10 +70,12 @@ class CreateSpanMessage(BaseMessage): model: Optional[str] provider: Optional[str] error_info: Optional[ErrorInfoDict] + total_cost: Optional[float] def as_payload_dict(self) -> Dict[str, Any]: data = super().as_payload_dict() data["id"] = data.pop("span_id") + data["total_estimated_cost"] = data.pop("total_cost") return data @@ -89,6 +96,13 @@ class UpdateSpanMessage(BaseMessage): model: Optional[str] provider: Optional[str] error_info: Optional[ErrorInfoDict] + total_cost: Optional[float] + + def as_payload_dict(self) -> Dict[str, Any]: + data = super().as_payload_dict() + data["id"] = data.pop("span_id") + data["total_estimated_cost"] = data.pop("total_cost") + return data @dataclasses.dataclass diff --git a/sdks/python/src/opik/opik_context.py b/sdks/python/src/opik/opik_context.py index ab1ab27e96..45d8d18c21 100644 --- a/sdks/python/src/opik/opik_context.py +++ b/sdks/python/src/opik/opik_context.py @@ -51,6 +51,7 @@ def update_current_span( tags: Optional[List[str]] = None, usage: Optional[UsageDict] = None, feedback_scores: Optional[List[FeedbackScoreDict]] = None, + total_cost: Optional[float] = None, ) -> None: """ Update the current span with the provided parameters. This method is usually called within a tracked function. @@ -63,6 +64,7 @@ def update_current_span( tags: The tags of the span. usage: The usage data of the span. feedback_scores: The feedback scores of the span. + total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage. """ new_params = { "name": name, @@ -72,6 +74,7 @@ def update_current_span( "tags": tags, "usage": usage, "feedback_scores": feedback_scores, + "total_cost": total_cost, } current_span_data = context_storage.top_span_data() if current_span_data is None: diff --git a/sdks/python/tests/e2e/test_tracing.py b/sdks/python/tests/e2e/test_tracing.py index d6d68aecf6..1e8be094cb 100644 --- a/sdks/python/tests/e2e/test_tracing.py +++ b/sdks/python/tests/e2e/test_tracing.py @@ -514,3 +514,31 @@ def test_search_spans__happyflow(opik_client): # Verify that the matching trace is returned assert len(spans) == 1, "Expected to find 1 matching span" assert spans[0].id == matching_span.id, "Expected to find the matching span" + + +def test_tracked_function__update_current_span_used_to_update_cost__happyflow( + opik_client, +): + # Setup + ID_STORAGE = {} + + @opik.track + def f(): + opik_context.update_current_span(total_cost=0.42) + ID_STORAGE["f_span-id"] = opik_context.get_current_span_data().id + ID_STORAGE["f_trace-id"] = opik_context.get_current_trace_data().id + + # Call + f() + opik.flush_tracker() + + # Verify top level span + verifiers.verify_span( + opik_client=opik_client, + span_id=ID_STORAGE["f_span-id"], + parent_span_id=None, + trace_id=ID_STORAGE["f_trace-id"], + name="f", + project_name=OPIK_E2E_TESTS_PROJECT_NAME, + total_cost=0.42, + ) diff --git a/sdks/python/tests/e2e/verifiers.py b/sdks/python/tests/e2e/verifiers.py index 0b663da9fd..26b08d4de7 100644 --- a/sdks/python/tests/e2e/verifiers.py +++ b/sdks/python/tests/e2e/verifiers.py @@ -102,6 +102,7 @@ def verify_span( model: Optional[str] = mock.ANY, # type: ignore provider: Optional[str] = mock.ANY, # type: ignore error_info: Optional[ErrorInfoDict] = mock.ANY, # type: ignore + total_cost: Optional[float] = mock.ANY, # type: ignore ): if not synchronization.until( lambda: (opik_client.get_span_content(id=span_id) is not None), @@ -132,8 +133,11 @@ def verify_span( assert ( _try_get__dict__(span.error_info) == error_info ), testlib.prepare_difference_report(span.error_info, error_info) - assert span.model == model - assert span.provider == provider + assert span.model == model, f"{span.model} != {model}" + assert span.provider == provider, f"{span.provider} != {provider}" + assert ( + span.total_estimated_cost == total_cost + ), f"{span.total_estimated_cost} != {total_cost}" if project_name is not mock.ANY: span_project = opik_client.get_project(span.project_id) diff --git a/sdks/python/tests/testlib/backend_emulator_message_processor.py b/sdks/python/tests/testlib/backend_emulator_message_processor.py index 9240cd999e..3cd6f4ed2e 100644 --- a/sdks/python/tests/testlib/backend_emulator_message_processor.py +++ b/sdks/python/tests/testlib/backend_emulator_message_processor.py @@ -129,6 +129,7 @@ def _dispatch_message(self, message: messages.BaseMessage) -> None: model=message.model, provider=message.provider, error_info=message.error_info, + total_cost=message.total_cost, ) self._span_to_parent_span[span.id] = message.parent_span_id @@ -156,6 +157,7 @@ def _dispatch_message(self, message: messages.BaseMessage) -> None: "error_info": message.error_info, "tags": message.tags, "input": message.input, + "total_cost": message.total_cost, } cleaned_update_payload = dict_utils.remove_none_from_dict(update_payload) span.__dict__.update(cleaned_update_payload) diff --git a/sdks/python/tests/testlib/models.py b/sdks/python/tests/testlib/models.py index bfd6d19587..e214d0305e 100644 --- a/sdks/python/tests/testlib/models.py +++ b/sdks/python/tests/testlib/models.py @@ -29,6 +29,7 @@ class SpanModel: model: Optional[str] = None provider: Optional[str] = None error_info: Optional[ErrorInfoDict] = None + total_cost: Optional[float] = None @dataclasses.dataclass diff --git a/sdks/python/tests/unit/configurator/test_configure.py b/sdks/python/tests/unit/configurator/test_configure.py index 76b56ec2f8..18ef1806f4 100644 --- a/sdks/python/tests/unit/configurator/test_configure.py +++ b/sdks/python/tests/unit/configurator/test_configure.py @@ -203,11 +203,12 @@ def test_update_config_session_update_failure( OpikConfigurator(api_key, workspace, url)._update_config() # Ensure config object is created and saved - mock_opik_config.assert_called_with( + mock_opik_config.assert_any_call( api_key=api_key, url_override="http://example.com/opik/api/", workspace=workspace, ) + mock_config_instance.save_to_file.assert_called_once() diff --git a/sdks/python/tests/unit/decorator/test_tracker_outputs.py b/sdks/python/tests/unit/decorator/test_tracker_outputs.py index b8c07eceff..2cedc30640 100644 --- a/sdks/python/tests/unit/decorator/test_tracker_outputs.py +++ b/sdks/python/tests/unit/decorator/test_tracker_outputs.py @@ -1058,7 +1058,9 @@ def test_track__span_and_trace_updated_via_opik_context(fake_backend): @tracker.track def f(x): opik_context.update_current_span( - name="span-name", metadata={"span-metadata-key": "span-metadata-value"} + name="span-name", + metadata={"span-metadata-key": "span-metadata-value"}, + total_cost=0.42, ) opik_context.update_current_trace( name="trace-name", @@ -1087,6 +1089,7 @@ def f(x): output={"output": "f-output"}, start_time=ANY_BUT_NONE, end_time=ANY_BUT_NONE, + total_cost=0.42, spans=[], ) ], diff --git a/sdks/python/tests/unit/message_processing/batching/test_batch_manager.py b/sdks/python/tests/unit/message_processing/batching/test_batch_manager.py index 94b9ccfb4a..933177756d 100644 --- a/sdks/python/tests/unit/message_processing/batching/test_batch_manager.py +++ b/sdks/python/tests/unit/message_processing/batching/test_batch_manager.py @@ -100,6 +100,7 @@ def test_batch_manager__start_and_stop_were_called__accumulated_data_is_flushed( model=NOT_USED, provider=NOT_USED, error_info=NOT_USED, + total_cost=NOT_USED, ) example_span_batcher = batchers.CreateSpanMessageBatcher( diff --git a/sdks/python/tests/unit/message_processing/test_message_streaming.py b/sdks/python/tests/unit/message_processing/test_message_streaming.py index bbe1024a1d..062180f0db 100644 --- a/sdks/python/tests/unit/message_processing/test_message_streaming.py +++ b/sdks/python/tests/unit/message_processing/test_message_streaming.py @@ -24,6 +24,7 @@ def create_span_message(): model=NOT_USED, provider=NOT_USED, error_info=NOT_USED, + total_cost=NOT_USED, )