Remove un-used header extraction (#838)
collindutter authored Jun 6, 2024
1 parent e167885 commit 8d6de30
Showing 6 changed files with 14 additions and 123 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed
 - `Workflow.insert_task()` no longer inserts duplicate tasks when given multiple parent tasks.
+- Performance issue in `OpenAiChatPromptDriver` when extracting unused rate-limiting headers.

 ## [0.26.0] - 2024-06-04
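
Context for the fix above: every completion call previously went through `with_raw_response` and handed the `x-ratelimit-*` headers to `dateparser`, even though nothing consumed the stored values. `dateparser` is a general-purpose natural-language date parser and is comparatively expensive to call on every request. A rough, hypothetical way to gauge that per-call cost (assumes `dateparser` is installed; the duration string mimics an OpenAI reset header):

    import timeit

    import dateparser

    # The removed code ran two parses like this on every API call.
    cost = timeit.timeit(
        lambda: dateparser.parse("6m0s", settings={"PREFER_DATES_FROM": "future"}),
        number=100,
    )
    print(f"avg per parse: {cost / 100:.4f}s")
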
54 changes: 3 additions & 51 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -7,8 +7,6 @@
 from griptape.utils import PromptStack
 from griptape.drivers import BasePromptDriver
 from griptape.tokenizers import OpenAiTokenizer, BaseTokenizer
-import dateparser
-from datetime import datetime, timedelta


 @define
@@ -25,12 +23,6 @@ class OpenAiChatPromptDriver(BasePromptDriver):
         response_format: An optional OpenAi Chat Completion response format. Currently only supports `json_object` which will enable OpenAi's JSON mode.
         seed: An optional OpenAi Chat Completion seed.
         ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
-        _ratelimit_request_limit: The maximum number of requests allowed in the current rate limit window.
-        _ratelimit_requests_remaining: The number of requests remaining in the current rate limit window.
-        _ratelimit_requests_reset_at: The time at which the current rate limit window resets.
-        _ratelimit_token_limit: The maximum number of tokens allowed in the current rate limit window.
-        _ratelimit_tokens_remaining: The number of tokens remaining in the current rate limit window.
-        _ratelimit_tokens_reset_at: The time at which the current rate limit window resets.
     """

     base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
@@ -64,22 +56,12 @@ class OpenAiChatPromptDriver(BasePromptDriver):
         ),
         kw_only=True,
     )
-    _ratelimit_request_limit: Optional[int] = field(init=False, default=None)
-    _ratelimit_requests_remaining: Optional[int] = field(init=False, default=None)
-    _ratelimit_requests_reset_at: Optional[datetime] = field(init=False, default=None)
-    _ratelimit_token_limit: Optional[int] = field(init=False, default=None)
-    _ratelimit_tokens_remaining: Optional[int] = field(init=False, default=None)
-    _ratelimit_tokens_reset_at: Optional[datetime] = field(init=False, default=None)

     def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
-        result = self.client.chat.completions.with_raw_response.create(**self._base_params(prompt_stack))
+        result = self.client.chat.completions.create(**self._base_params(prompt_stack))

-        self._extract_ratelimit_metadata(result)
-
-        parsed_result = result.parse()
-
-        if len(parsed_result.choices) == 1:
-            return TextArtifact(value=parsed_result.choices[0].message.content.strip())
+        if len(result.choices) == 1:
+            return TextArtifact(value=result.choices[0].message.content.strip())
         else:
             raise Exception("Completion with more than one choice is not supported yet.")
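
For context on the change above: in the OpenAI Python SDK (v1.x), `with_raw_response.create(...)` wraps the HTTP response so its headers are reachable, and the typed `ChatCompletion` must then be unwrapped with `.parse()`. Dropping the wrapper removes both the wrapper object and the parse step. A minimal sketch of the two call shapes (client setup and model name are illustrative):

    from openai import OpenAI

    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

    # Old shape: the raw wrapper exposes headers but requires an explicit parse.
    raw = client.chat.completions.with_raw_response.create(
        model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}]
    )
    headers = raw.headers     # e.g. the x-ratelimit-* values
    completion = raw.parse()  # typed ChatCompletion

    # New shape: create() returns the typed ChatCompletion directly.
    completion = client.chat.completions.create(
        model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}]
    )
    print(completion.choices[0].message.content)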

@@ -136,33 +118,3 @@ def __to_openai_role(self, prompt_input: PromptStack.Input) -> str:
             return "assistant"
         else:
             return "user"
-
-    def _extract_ratelimit_metadata(self, response):
-        # The OpenAI SDK's requestssession variable is global, so this hook will fire for all API requests.
-        # The following headers are not reliably returned in every API call, so we check for the presence of the
-        # headers before reading and parsing their values to prevent other SDK users from encountering KeyErrors.
-        reset_requests_at = response.headers.get("x-ratelimit-reset-requests")
-        if reset_requests_at is not None:
-            self._ratelimit_requests_reset_at = dateparser.parse(
-                reset_requests_at, settings={"PREFER_DATES_FROM": "future"}
-            )
-
-            # The dateparser utility doesn't handle sub-second durations as are sometimes returned by OpenAI's API.
-            # If the API returns, for example, "13ms", dateparser.parse() returns None. In this case, we will set
-            # the time value to the current time plus a one second buffer.
-            if self._ratelimit_requests_reset_at is None:
-                self._ratelimit_requests_reset_at = datetime.now() + timedelta(seconds=1)
-
-        reset_tokens_at = response.headers.get("x-ratelimit-reset-tokens")
-        if reset_tokens_at is not None:
-            self._ratelimit_tokens_reset_at = dateparser.parse(
-                reset_tokens_at, settings={"PREFER_DATES_FROM": "future"}
-            )
-
-            if self._ratelimit_tokens_reset_at is None:
-                self._ratelimit_tokens_reset_at = datetime.now() + timedelta(seconds=1)
-
-        self._ratelimit_request_limit = response.headers.get("x-ratelimit-limit-requests")
-        self._ratelimit_requests_remaining = response.headers.get("x-ratelimit-remaining-requests")
-        self._ratelimit_token_limit = response.headers.get("x-ratelimit-limit-tokens")
-        self._ratelimit_tokens_remaining = response.headers.get("x-ratelimit-remaining-tokens")
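
Callers that still need rate-limit visibility can read the headers on demand rather than having the driver cache them; a minimal sketch, assuming an OpenAI v1.x client (the header names are the ones the removed method read, and they are not guaranteed on every response):

    from openai import OpenAI

    client = OpenAI()
    raw = client.chat.completions.with_raw_response.create(
        model="gpt-4o-mini", messages=[{"role": "user", "content": "ping"}]
    )
    # Values are plain strings; reset durations look like "6m0s" or "13ms".
    print(raw.headers.get("x-ratelimit-remaining-requests"))
    print(raw.headers.get("x-ratelimit-reset-tokens"))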
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 0 additions & 1 deletion pyproject.toml
@@ -27,7 +27,6 @@ numpy = ">=1"
 stringcase = "^1.2.0"
 docker = "^7.1.0"
 sqlalchemy = "~=1.0"
-dateparser = "^1.1.8"
 requests = "^2"

 # drivers
tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py
@@ -7,11 +7,11 @@
 class TestAzureOpenAiChatPromptDriver(TestOpenAiChatPromptDriverFixtureMixin):
     @pytest.fixture
     def mock_chat_completion_create(self, mocker):
-        mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.with_raw_response.create
+        mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create
         mock_choice = Mock()
         mock_choice.message.content = "model-output"
-        mock_chat_create.return_value.headers = {}
-        mock_chat_create.return_value.parse.return_value.choices = [mock_choice]
+        mock_chat_create.return_value.choices = [mock_choice]
         return mock_chat_create

     @pytest.fixture
67 changes: 3 additions & 64 deletions tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
@@ -1,5 +1,3 @@
-import datetime
-
 from transformers import AutoTokenizer

 from griptape.drivers import OpenAiChatPromptDriver
@@ -13,11 +11,11 @@
 class TestOpenAiChatPromptDriverFixtureMixin:
     @pytest.fixture
     def mock_chat_completion_create(self, mocker):
-        mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.with_raw_response.create
+        mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create
         mock_choice = Mock()
         mock_choice.message.content = "model-output"
-        mock_chat_create.return_value.headers = {}
-        mock_chat_create.return_value.parse.return_value.choices = [mock_choice]
+        mock_chat_create.return_value.choices = [mock_choice]
         return mock_chat_create

     @pytest.fixture
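
A note on the fixture pattern above: `mocker.patch("openai.OpenAI")` replaces the class itself, so attributes chained off `return_value` configure the client instance the code under test will construct; with `.parse()` gone, `choices` is stubbed directly on the `create()` result. A self-contained sketch of the same mechanics, assuming `pytest-mock` (test name is illustrative):

    import openai

    def test_mock_chain_resolves(mocker):
        mock_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create
        mock_choice = mocker.Mock()
        mock_choice.message.content = "model-output"
        mock_create.return_value.choices = [mock_choice]

        client = openai.OpenAI(api_key="api-key")  # returns the patched instance
        result = client.chat.completions.create(model="m", messages=[])
        assert result.choices[0].message.content == "model-output"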
@@ -202,7 +200,7 @@ def test_try_run_throws_when_prompt_stack_is_string(self):
     def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_chat_completion_create, prompt_stack):
         # Given
         driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, api_key="api-key")
-        mock_chat_completion_create.return_value.parse.return_value.choices = [choices]
+        mock_chat_completion_create.return_value.choices = [choices]

         # When
         with pytest.raises(Exception) as e:
@@ -258,65 +256,6 @@ def test_max_output_tokens_with_max_tokens(self, messages):

         assert max_tokens == 42

-    def test_extract_ratelimit_metadata(self):
-        response_with_headers = OpenAiApiResponseWithHeaders()
-        driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)
-        driver._extract_ratelimit_metadata(response_with_headers)
-
-        assert driver._ratelimit_requests_remaining == response_with_headers.remaining_requests
-        assert driver._ratelimit_tokens_remaining == response_with_headers.remaining_tokens
-        assert driver._ratelimit_request_limit == response_with_headers.limit_requests
-        assert driver._ratelimit_token_limit == response_with_headers.limit_tokens
-
-        # Assert that the reset times are within one second of the expected value.
-        expected_request_reset_time = datetime.datetime.now() + datetime.timedelta(
-            seconds=response_with_headers.reset_requests_in
-        )
-        expected_token_reset_time = datetime.datetime.now() + datetime.timedelta(
-            seconds=response_with_headers.reset_tokens_in
-        )
-
-        assert driver._ratelimit_requests_reset_at is not None
-        assert abs(driver._ratelimit_requests_reset_at - expected_request_reset_time) < datetime.timedelta(seconds=1)
-        assert driver._ratelimit_tokens_reset_at is not None
-        assert abs(driver._ratelimit_tokens_reset_at - expected_token_reset_time) < datetime.timedelta(seconds=1)
-
-    def test_extract_ratelimit_metadata_with_subsecond_reset_times(self):
-        response_with_headers = OpenAiApiResponseWithHeaders(
-            reset_requests_in=1, reset_requests_in_unit="ms", reset_tokens_in=10, reset_tokens_in_unit="ms"
-        )
-        driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, api_key="api-key")
-        driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)
-        driver._extract_ratelimit_metadata(response_with_headers)
-
-        # Assert that the reset times are within one second of the expected value. With a sub-second reset time,
-        # this is rounded up to one second in the future.
-        expected_request_reset_time = datetime.datetime.now() + datetime.timedelta(seconds=1)
-        expected_token_reset_time = datetime.datetime.now() + datetime.timedelta(seconds=1)
-
-        assert driver._ratelimit_requests_reset_at is not None
-        assert abs(driver._ratelimit_requests_reset_at - expected_request_reset_time) < datetime.timedelta(seconds=1)
-        assert driver._ratelimit_tokens_reset_at is not None
-        assert abs(driver._ratelimit_tokens_reset_at - expected_token_reset_time) < datetime.timedelta(seconds=1)
-
-    def test_extract_ratelimit_metadata_missing_headers(self):
-        class OpenAiApiResponseNoHeaders:
-            @property
-            def headers(self):
-                return {}
-
-        response_without_headers = OpenAiApiResponseNoHeaders()
-
-        driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)
-        driver._extract_ratelimit_metadata(response_without_headers)
-
-        assert driver._ratelimit_request_limit is None
-        assert driver._ratelimit_requests_remaining is None
-        assert driver._ratelimit_requests_reset_at is None
-        assert driver._ratelimit_token_limit is None
-        assert driver._ratelimit_tokens_remaining is None
-        assert driver._ratelimit_tokens_reset_at is None

     def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages):
         driver = OpenAiChatPromptDriver(
             model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
