Skip to content

Commit

Permalink
[OPIK-750] sdk add the possibility to manually sets the cost of indiv…
Browse files Browse the repository at this point in the history
…idual spans (#1107)

* Add total_cost to low-level API

* Pass total_cost to requests payload as estimated_total_cost

* Add total_cost to update_current_span

* Add e2e test for total cost logging

* Fix lint errors

* Update unit tests and backend emulator

* Update old test

* Disable sentry if pytest env is detected. Relax configurator test

* Update the docstrings to mention price units

* Add missing docstring
  • Loading branch information
alexkuzmik authored Jan 23, 2025
1 parent 9a22472 commit d390e93
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 28 deletions.
3 changes: 2 additions & 1 deletion sdks/python/src/opik/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import _logging, error_tracking, package_version
from . import _logging, error_tracking, package_version, environment
from .api_objects.dataset import Dataset
from .api_objects.experiment.experiment_item import (
ExperimentItemContent,
Expand Down Expand Up @@ -40,6 +40,7 @@

if (
error_tracking.enabled_in_config()
and not environment.in_pytest()
and error_tracking.randomized_should_enable_reporting()
):
error_tracking.setup_sentry_error_tracker()
3 changes: 3 additions & 0 deletions sdks/python/src/opik/api_objects/opik_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def span(
model: Optional[str] = None,
provider: Optional[str] = None,
error_info: Optional[ErrorInfoDict] = None,
total_cost: Optional[float] = None,
) -> span.Span:
"""
Create and log a new span.
Expand All @@ -249,6 +250,7 @@ def span(
model: The name of LLM (in this case `type` parameter should be == `llm`)
provider: The provider of LLM.
error_info: The dictionary with error information (typically used when the span function has failed).
total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage.
Returns:
span.Span: The created span object.
Expand Down Expand Up @@ -304,6 +306,7 @@ def span(
model=model,
provider=provider,
error_info=error_info,
total_cost=total_cost,
)
self._streamer.put(create_span_message)

Expand Down
11 changes: 11 additions & 0 deletions sdks/python/src/opik/api_objects/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def end(
tags: Optional[List[str]] = None,
usage: Optional[UsageDict] = None,
error_info: Optional[ErrorInfoDict] = None,
total_cost: Optional[float] = None,
) -> None:
"""
End the span and update its attributes.
Expand All @@ -61,6 +62,7 @@ def end(
tags: A list of tags to be associated with the span.
usage: Usage information for the span.
error_info: The dictionary with error information (typically used when the span function has failed).
total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage.
Returns:
None
Expand All @@ -77,6 +79,7 @@ def end(
tags=tags,
usage=usage,
error_info=error_info,
total_cost=total_cost,
)

def update(
Expand All @@ -90,6 +93,7 @@ def update(
model: Optional[str] = None,
provider: Optional[str] = None,
error_info: Optional[ErrorInfoDict] = None,
total_cost: Optional[float] = None,
) -> None:
"""
Update the span attributes.
Expand All @@ -104,6 +108,7 @@ def update(
model: The name of LLM.
provider: The provider of LLM.
error_info: The dictionary with error information (typically used when the span function has failed).
total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage.
Returns:
None
Expand All @@ -130,6 +135,7 @@ def update(
model=model,
provider=provider,
error_info=error_info,
total_cost=total_cost,
)
self._streamer.put(end_span_message)

Expand All @@ -148,6 +154,7 @@ def span(
model: Optional[str] = None,
provider: Optional[str] = None,
error_info: Optional[ErrorInfoDict] = None,
total_cost: Optional[float] = None,
) -> "Span":
"""
Create a new child span within the current span.
Expand All @@ -165,6 +172,8 @@ def span(
usage: Usage information for the span.
model: The name of LLM (in this case `type` parameter should be == `llm`)
provider: The provider of LLM.
error_info: The dictionary with error information (typically used when the span function has failed).
total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage.
Returns:
Span: The created child span object.
Expand Down Expand Up @@ -198,6 +207,7 @@ def span(
model=model,
provider=provider,
error_info=error_info,
total_cost=total_cost,
)
self._streamer.put(create_span_message)

Expand Down Expand Up @@ -284,6 +294,7 @@ class SpanData:
model: Optional[str] = None
provider: Optional[str] = None
error_info: Optional[ErrorInfoDict] = None
total_cost: Optional[float] = None

def update(self, **new_data: Any) -> "SpanData":
for key, value in new_data.items():
Expand Down
3 changes: 3 additions & 0 deletions sdks/python/src/opik/api_objects/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def span(
model: Optional[str] = None,
provider: Optional[str] = None,
error_info: Optional[ErrorInfoDict] = None,
total_cost: Optional[float] = None,
) -> span.Span:
"""
Create a new span within the trace.
Expand All @@ -135,6 +136,7 @@ def span(
model: The name of LLM (in this case `type` parameter should be == `llm`)
provider: The provider of LLM.
error_info: The dictionary with error information (typically used when the span function has failed).
total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage.
Returns:
span.Span: The created span object.
Expand Down Expand Up @@ -168,6 +170,7 @@ def span(
model=model,
provider=provider,
error_info=error_info,
total_cost=total_cost,
)
self._streamer.put(create_span_message)

Expand Down
25 changes: 2 additions & 23 deletions sdks/python/src/opik/message_processing/message_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,7 @@ def _process_create_trace_message(
self._rest_client.traces.create_trace(**cleaned_create_trace_kwargs)

def _process_update_span_message(self, message: messages.UpdateSpanMessage) -> None:
update_span_kwargs = {
"id": message.span_id,
"parent_span_id": message.parent_span_id,
"project_name": message.project_name,
"trace_id": message.trace_id,
"end_time": message.end_time,
"input": message.input,
"output": message.output,
"metadata": message.metadata,
"tags": message.tags,
"usage": message.usage,
"model": message.model,
"provider": message.provider,
}
update_span_kwargs = message.as_payload_dict()

cleaned_update_span_kwargs = dict_utils.remove_none_from_dict(
update_span_kwargs
Expand All @@ -113,15 +100,7 @@ def _process_update_span_message(self, message: messages.UpdateSpanMessage) -> N
def _process_update_trace_message(
self, message: messages.UpdateTraceMessage
) -> None:
update_trace_kwargs = {
"id": message.trace_id,
"project_name": message.project_name,
"end_time": message.end_time,
"input": message.input,
"output": message.output,
"metadata": message.metadata,
"tags": message.tags,
}
update_trace_kwargs = message.as_payload_dict()

cleaned_update_trace_kwargs = dict_utils.remove_none_from_dict(
update_trace_kwargs
Expand Down
14 changes: 14 additions & 0 deletions sdks/python/src/opik/message_processing/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class UpdateTraceMessage(BaseMessage):
tags: Optional[List[str]]
error_info: Optional[ErrorInfoDict]

def as_payload_dict(self) -> Dict[str, Any]:
data = super().as_payload_dict()
data["id"] = data.pop("trace_id")
return data


@dataclasses.dataclass
class CreateSpanMessage(BaseMessage):
Expand All @@ -65,10 +70,12 @@ class CreateSpanMessage(BaseMessage):
model: Optional[str]
provider: Optional[str]
error_info: Optional[ErrorInfoDict]
total_cost: Optional[float]

def as_payload_dict(self) -> Dict[str, Any]:
data = super().as_payload_dict()
data["id"] = data.pop("span_id")
data["total_estimated_cost"] = data.pop("total_cost")
return data


Expand All @@ -89,6 +96,13 @@ class UpdateSpanMessage(BaseMessage):
model: Optional[str]
provider: Optional[str]
error_info: Optional[ErrorInfoDict]
total_cost: Optional[float]

def as_payload_dict(self) -> Dict[str, Any]:
data = super().as_payload_dict()
data["id"] = data.pop("span_id")
data["total_estimated_cost"] = data.pop("total_cost")
return data


@dataclasses.dataclass
Expand Down
3 changes: 3 additions & 0 deletions sdks/python/src/opik/opik_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def update_current_span(
tags: Optional[List[str]] = None,
usage: Optional[UsageDict] = None,
feedback_scores: Optional[List[FeedbackScoreDict]] = None,
total_cost: Optional[float] = None,
) -> None:
"""
Update the current span with the provided parameters. This method is usually called within a tracked function.
Expand All @@ -63,6 +64,7 @@ def update_current_span(
tags: The tags of the span.
usage: The usage data of the span.
feedback_scores: The feedback scores of the span.
total_cost: The cost of the span in USD. This value takes priority over the cost calculated by Opik from the usage.
"""
new_params = {
"name": name,
Expand All @@ -72,6 +74,7 @@ def update_current_span(
"tags": tags,
"usage": usage,
"feedback_scores": feedback_scores,
"total_cost": total_cost,
}
current_span_data = context_storage.top_span_data()
if current_span_data is None:
Expand Down
28 changes: 28 additions & 0 deletions sdks/python/tests/e2e/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,31 @@ def test_search_spans__happyflow(opik_client):
# Verify that the matching trace is returned
assert len(spans) == 1, "Expected to find 1 matching span"
assert spans[0].id == matching_span.id, "Expected to find the matching span"


def test_tracked_function__update_current_span_used_to_update_cost__happyflow(
opik_client,
):
# Setup
ID_STORAGE = {}

@opik.track
def f():
opik_context.update_current_span(total_cost=0.42)
ID_STORAGE["f_span-id"] = opik_context.get_current_span_data().id
ID_STORAGE["f_trace-id"] = opik_context.get_current_trace_data().id

# Call
f()
opik.flush_tracker()

# Verify top level span
verifiers.verify_span(
opik_client=opik_client,
span_id=ID_STORAGE["f_span-id"],
parent_span_id=None,
trace_id=ID_STORAGE["f_trace-id"],
name="f",
project_name=OPIK_E2E_TESTS_PROJECT_NAME,
total_cost=0.42,
)
8 changes: 6 additions & 2 deletions sdks/python/tests/e2e/verifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def verify_span(
model: Optional[str] = mock.ANY, # type: ignore
provider: Optional[str] = mock.ANY, # type: ignore
error_info: Optional[ErrorInfoDict] = mock.ANY, # type: ignore
total_cost: Optional[float] = mock.ANY, # type: ignore
):
if not synchronization.until(
lambda: (opik_client.get_span_content(id=span_id) is not None),
Expand Down Expand Up @@ -132,8 +133,11 @@ def verify_span(
assert (
_try_get__dict__(span.error_info) == error_info
), testlib.prepare_difference_report(span.error_info, error_info)
assert span.model == model
assert span.provider == provider
assert span.model == model, f"{span.model} != {model}"
assert span.provider == provider, f"{span.provider} != {provider}"
assert (
span.total_estimated_cost == total_cost
), f"{span.total_estimated_cost} != {total_cost}"

if project_name is not mock.ANY:
span_project = opik_client.get_project(span.project_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def _dispatch_message(self, message: messages.BaseMessage) -> None:
model=message.model,
provider=message.provider,
error_info=message.error_info,
total_cost=message.total_cost,
)

self._span_to_parent_span[span.id] = message.parent_span_id
Expand Down Expand Up @@ -156,6 +157,7 @@ def _dispatch_message(self, message: messages.BaseMessage) -> None:
"error_info": message.error_info,
"tags": message.tags,
"input": message.input,
"total_cost": message.total_cost,
}
cleaned_update_payload = dict_utils.remove_none_from_dict(update_payload)
span.__dict__.update(cleaned_update_payload)
Expand Down
1 change: 1 addition & 0 deletions sdks/python/tests/testlib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class SpanModel:
model: Optional[str] = None
provider: Optional[str] = None
error_info: Optional[ErrorInfoDict] = None
total_cost: Optional[float] = None


@dataclasses.dataclass
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/tests/unit/configurator/test_configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,12 @@ def test_update_config_session_update_failure(
OpikConfigurator(api_key, workspace, url)._update_config()

# Ensure config object is created and saved
mock_opik_config.assert_called_with(
mock_opik_config.assert_any_call(
api_key=api_key,
url_override="http://example.com/opik/api/",
workspace=workspace,
)

mock_config_instance.save_to_file.assert_called_once()


Expand Down
5 changes: 4 additions & 1 deletion sdks/python/tests/unit/decorator/test_tracker_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,7 +1058,9 @@ def test_track__span_and_trace_updated_via_opik_context(fake_backend):
@tracker.track
def f(x):
opik_context.update_current_span(
name="span-name", metadata={"span-metadata-key": "span-metadata-value"}
name="span-name",
metadata={"span-metadata-key": "span-metadata-value"},
total_cost=0.42,
)
opik_context.update_current_trace(
name="trace-name",
Expand Down Expand Up @@ -1087,6 +1089,7 @@ def f(x):
output={"output": "f-output"},
start_time=ANY_BUT_NONE,
end_time=ANY_BUT_NONE,
total_cost=0.42,
spans=[],
)
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_batch_manager__start_and_stop_were_called__accumulated_data_is_flushed(
model=NOT_USED,
provider=NOT_USED,
error_info=NOT_USED,
total_cost=NOT_USED,
)

example_span_batcher = batchers.CreateSpanMessageBatcher(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def create_span_message():
model=NOT_USED,
provider=NOT_USED,
error_info=NOT_USED,
total_cost=NOT_USED,
)


Expand Down

0 comments on commit d390e93

Please sign in to comment.