Skip to content

Commit

Permalink
feat: BREAKING - enable different billing schemas - pass Call to the …
Browse files Browse the repository at this point in the history
…results_usage_extractor

This makes old API descriptions incompatible: every `results_usage_extractor` must now accept `(call, result)` instead of just `(result)`.
  • Loading branch information
tpietruszka committed Aug 17, 2023
1 parent 346bbf5 commit ee06f3f
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions rate_limited/apis/openai/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List

from rate_limited.apis.common import get_requests_per_minute
from rate_limited.calls import Call
from rate_limited.calls import Call, Result
from rate_limited.resources import Resource


Expand All @@ -28,7 +28,7 @@ def get_tokens_per_minute(quota: int, model_max_len: int) -> Resource:
)


def get_used_tokens(results: dict) -> int:
def get_used_tokens(call: Call, results: Result) -> int:
total_tokens = results.get("usage", {}).get("total_tokens", None)
if total_tokens is None:
raise ValueError("Could not find total_tokens in results")
Expand Down
6 changes: 3 additions & 3 deletions rate_limited/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from logging import getLogger
from typing import Collection, Optional

from rate_limited.calls import Call
from rate_limited.calls import Call, Result
from rate_limited.resources import Resource, Unit


Expand Down Expand Up @@ -51,10 +51,10 @@ def pre_allocate(self, call: Call):
if resource.max_results_usage_estimator:
resource.reserve_amount(resource.max_results_usage_estimator(call))

def register_result(self, result):
def register_result(self, call: Call, result: Result):
for resource in self.resources:
if resource.results_usage_extractor:
resource.add_usage(resource.results_usage_extractor(result))
resource.add_usage(resource.results_usage_extractor(call, result))

def remove_pre_allocation(self, call: Call):
# Right now assuming that pre-allocation is only based on the call, this could change
Expand Down
8 changes: 4 additions & 4 deletions rate_limited/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
quota: Unit,
time_window_seconds: float,
arguments_usage_extractor: Optional[Callable[[Call], Unit]] = None,
results_usage_extractor: Optional[Callable[[Result], Unit]] = None,
results_usage_extractor: Optional[Callable[[Call, Result], Unit]] = None,
max_results_usage_estimator: Optional[Callable[[Call], Unit]] = None,
):
"""
Expand All @@ -32,9 +32,9 @@ def __init__(
quota: maximum amount of the resource that can be used in the time window
time_window_seconds: time window in seconds
arguments_usage_extractor: function that extracts the amount of resource used from
the arguments
the arguments, "billed" before the call is made
results_usage_extractor: function that extracts the amount of resource used from
the results
the results (and the arguments), "billed" after the call is made
max_results_usage_estimator: function that extracts an upper bound on the amount of
resource that might be used when results are returned, based on the arguments
(this is used to pre-allocate usage, pre-allocation is then replaced with the
Expand All @@ -50,7 +50,7 @@ def __init__(

self.arguments_usage_extractor = arguments_usage_extractor
self.results_usage_extractor = results_usage_extractor
self.max_results_usage_estimator = max_results_usage_estimator # TODO: consider renaming
self.max_results_usage_estimator = max_results_usage_estimator

if self.max_results_usage_estimator and not self.results_usage_extractor:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion rate_limited/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ async def worker(self):
self.requests_executor_pool, self.function, *call.args, **call.kwargs
)
# TODO: are there cases where we need to register result-based usage on error?
self.resource_manager.register_result(result)
self.resource_manager.register_result(call, result)
if self.validation_function is not None:
if not self.validation_function(result):
raise ValidationError(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def dummy_resources(
name="points",
quota=num_points,
time_window_seconds=time_window_seconds,
results_usage_extractor=lambda x: x["used_points"],
results_usage_extractor=lambda _, result: result["used_points"],
max_results_usage_estimator=estimator,
),
]
Expand Down

0 comments on commit ee06f3f

Please sign in to comment.