From 8d6de30c70998b27e298d77c344600c58e1c15c3 Mon Sep 17 00:00:00 2001
From: Collin Dutter
Date: Thu, 6 Jun 2024 10:27:37 -0700
Subject: [PATCH] Remove unused header extraction (#838)

---
 CHANGELOG.md                                  |  1 +
 .../prompt/openai_chat_prompt_driver.py       | 54 +--------
 poetry.lock                                   | 10 +--
 pyproject.toml                                |  1 -
 .../test_azure_openai_chat_prompt_driver.py   |  4 +-
 .../prompt/test_openai_chat_prompt_driver.py  | 67 +------------------
 6 files changed, 14 insertions(+), 123 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ad589408c..0756ef35d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
 - `Workflow.insert_task()` no longer inserts duplicate tasks when given multiple parent tasks.
+- Performance issue in `OpenAiChatPromptDriver` when extracting unused rate-limiting headers.
 
 ## [0.26.0] - 2024-06-04

diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py
index 3d19063d3..345545b6f 100644
--- a/griptape/drivers/prompt/openai_chat_prompt_driver.py
+++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -7,8 +7,6 @@ from griptape.utils import PromptStack
 from griptape.drivers import BasePromptDriver
 from griptape.tokenizers import OpenAiTokenizer, BaseTokenizer
-import dateparser
-from datetime import datetime, timedelta
 
 
 @define
@@ -25,12 +23,6 @@ class OpenAiChatPromptDriver(BasePromptDriver):
         response_format: An optional OpenAi Chat Completion response format. Currently only supports `json_object` which will enable OpenAi's JSON mode.
         seed: An optional OpenAi Chat Completion seed.
         ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
-        _ratelimit_request_limit: The maximum number of requests allowed in the current rate limit window.
-        _ratelimit_requests_remaining: The number of requests remaining in the current rate limit window.
-        _ratelimit_requests_reset_at: The time at which the current rate limit window resets.
-        _ratelimit_token_limit: The maximum number of tokens allowed in the current rate limit window.
-        _ratelimit_tokens_remaining: The number of tokens remaining in the current rate limit window.
-        _ratelimit_tokens_reset_at: The time at which the current rate limit window resets.
""" base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) @@ -64,22 +56,12 @@ class OpenAiChatPromptDriver(BasePromptDriver): ), kw_only=True, ) - _ratelimit_request_limit: Optional[int] = field(init=False, default=None) - _ratelimit_requests_remaining: Optional[int] = field(init=False, default=None) - _ratelimit_requests_reset_at: Optional[datetime] = field(init=False, default=None) - _ratelimit_token_limit: Optional[int] = field(init=False, default=None) - _ratelimit_tokens_remaining: Optional[int] = field(init=False, default=None) - _ratelimit_tokens_reset_at: Optional[datetime] = field(init=False, default=None) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - result = self.client.chat.completions.with_raw_response.create(**self._base_params(prompt_stack)) + result = self.client.chat.completions.create(**self._base_params(prompt_stack)) - self._extract_ratelimit_metadata(result) - - parsed_result = result.parse() - - if len(parsed_result.choices) == 1: - return TextArtifact(value=parsed_result.choices[0].message.content.strip()) + if len(result.choices) == 1: + return TextArtifact(value=result.choices[0].message.content.strip()) else: raise Exception("Completion with more than one choice is not supported yet.") @@ -136,33 +118,3 @@ def __to_openai_role(self, prompt_input: PromptStack.Input) -> str: return "assistant" else: return "user" - - def _extract_ratelimit_metadata(self, response): - # The OpenAI SDK's requestssession variable is global, so this hook will fire for all API requests. - # The following headers are not reliably returned in every API call, so we check for the presence of the - # headers before reading and parsing their values to prevent other SDK users from encountering KeyErrors. - reset_requests_at = response.headers.get("x-ratelimit-reset-requests") - if reset_requests_at is not None: - self._ratelimit_requests_reset_at = dateparser.parse( - reset_requests_at, settings={"PREFER_DATES_FROM": "future"} - ) - - # The dateparser utility doesn't handle sub-second durations as are sometimes returned by OpenAI's API. - # If the API returns, for example, "13ms", dateparser.parse() returns None. In this case, we will set - # the time value to the current time plus a one second buffer. 
- if self._ratelimit_requests_reset_at is None: - self._ratelimit_requests_reset_at = datetime.now() + timedelta(seconds=1) - - reset_tokens_at = response.headers.get("x-ratelimit-reset-tokens") - if reset_tokens_at is not None: - self._ratelimit_tokens_reset_at = dateparser.parse( - reset_tokens_at, settings={"PREFER_DATES_FROM": "future"} - ) - - if self._ratelimit_tokens_reset_at is None: - self._ratelimit_tokens_reset_at = datetime.now() + timedelta(seconds=1) - - self._ratelimit_request_limit = response.headers.get("x-ratelimit-limit-requests") - self._ratelimit_requests_remaining = response.headers.get("x-ratelimit-remaining-requests") - self._ratelimit_token_limit = response.headers.get("x-ratelimit-limit-tokens") - self._ratelimit_tokens_remaining = response.headers.get("x-ratelimit-remaining-tokens") diff --git a/poetry.lock b/poetry.lock index 87b88792a..ed65bc07f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1127,7 +1127,7 @@ test-randomorder = ["pytest-randomly"] name = "dateparser" version = "1.2.0" description = "Date parsing library designed to parse dates from HTML pages" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "dateparser-1.2.0-py2.py3-none-any.whl", hash = "sha256:0b21ad96534e562920a0083e97fd45fa959882d4162acc358705144520a35830"}, @@ -4248,7 +4248,7 @@ six = ">=1.5" name = "pytz" version = "2024.1" description = "World timezone definitions, modern and historical" -optional = false +optional = true python-versions = "*" files = [ {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, @@ -5698,7 +5698,7 @@ files = [ name = "tzdata" version = "2024.1" description = "Provider of IANA time zone data" -optional = false +optional = true python-versions = ">=2" files = [ {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, @@ -5709,7 +5709,7 @@ files = [ name = "tzlocal" version = "5.2" description = "tzinfo object for the local timezone" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "tzlocal-5.2-py3-none-any.whl", hash = "sha256:49816ef2fe65ea8ac19d19aa7a1ae0551c834303d5014c6d5a62e4cbda8047b8"}, @@ -6096,4 +6096,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "d0913fb4119f352710d722c55029eaa74e948d57e3fa5ffb345ca4e690f20ec2" +content-hash = "74de0d1e5ee382332635cea14dc0d39e288ab65adfec91569da0dbb06fd316e2" diff --git a/pyproject.toml b/pyproject.toml index 53c07257e..d58432b8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ numpy = ">=1" stringcase = "^1.2.0" docker = "^7.1.0" sqlalchemy = "~=1.0" -dateparser = "^1.1.8" requests = "^2" # drivers diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 9446d0520..f6bd12d80 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -7,11 +7,11 @@ class TestAzureOpenAiChatPromptDriver(TestOpenAiChatPromptDriverFixtureMixin): @pytest.fixture def mock_chat_completion_create(self, mocker): - mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.with_raw_response.create + mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create mock_choice = Mock() mock_choice.message.content = "model-output" 
mock_chat_create.return_value.headers = {} - mock_chat_create.return_value.parse.return_value.choices = [mock_choice] + mock_chat_create.return_value.choices = [mock_choice] return mock_chat_create @pytest.fixture diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index fbc939005..1f8f07cf5 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,5 +1,3 @@ -import datetime - from transformers import AutoTokenizer from griptape.drivers import OpenAiChatPromptDriver @@ -13,11 +11,11 @@ class TestOpenAiChatPromptDriverFixtureMixin: @pytest.fixture def mock_chat_completion_create(self, mocker): - mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.with_raw_response.create + mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create mock_choice = Mock() mock_choice.message.content = "model-output" mock_chat_create.return_value.headers = {} - mock_chat_create.return_value.parse.return_value.choices = [mock_choice] + mock_chat_create.return_value.choices = [mock_choice] return mock_chat_create @pytest.fixture @@ -202,7 +200,7 @@ def test_try_run_throws_when_prompt_stack_is_string(self): def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_chat_completion_create, prompt_stack): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, api_key="api-key") - mock_chat_completion_create.return_value.parse.return_value.choices = [choices] + mock_chat_completion_create.return_value.choices = [choices] # When with pytest.raises(Exception) as e: @@ -258,65 +256,6 @@ def test_max_output_tokens_with_max_tokens(self, messages): assert max_tokens == 42 - def test_extract_ratelimit_metadata(self): - response_with_headers = OpenAiApiResponseWithHeaders() - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - driver._extract_ratelimit_metadata(response_with_headers) - - assert driver._ratelimit_requests_remaining == response_with_headers.remaining_requests - assert driver._ratelimit_tokens_remaining == response_with_headers.remaining_tokens - assert driver._ratelimit_request_limit == response_with_headers.limit_requests - assert driver._ratelimit_token_limit == response_with_headers.limit_tokens - - # Assert that the reset times are within one second of the expected value. 
- expected_request_reset_time = datetime.datetime.now() + datetime.timedelta( - seconds=response_with_headers.reset_requests_in - ) - expected_token_reset_time = datetime.datetime.now() + datetime.timedelta( - seconds=response_with_headers.reset_tokens_in - ) - - assert driver._ratelimit_requests_reset_at is not None - assert abs(driver._ratelimit_requests_reset_at - expected_request_reset_time) < datetime.timedelta(seconds=1) - assert driver._ratelimit_tokens_reset_at is not None - assert abs(driver._ratelimit_tokens_reset_at - expected_token_reset_time) < datetime.timedelta(seconds=1) - - def test_extract_ratelimit_metadata_with_subsecond_reset_times(self): - response_with_headers = OpenAiApiResponseWithHeaders( - reset_requests_in=1, reset_requests_in_unit="ms", reset_tokens_in=10, reset_tokens_in_unit="ms" - ) - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, api_key="api-key") - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - driver._extract_ratelimit_metadata(response_with_headers) - - # Assert that the reset times are within one second of the expected value. With a sub-second reset time, - # this is rounded up to one second in the future. - expected_request_reset_time = datetime.datetime.now() + datetime.timedelta(seconds=1) - expected_token_reset_time = datetime.datetime.now() + datetime.timedelta(seconds=1) - - assert driver._ratelimit_requests_reset_at is not None - assert abs(driver._ratelimit_requests_reset_at - expected_request_reset_time) < datetime.timedelta(seconds=1) - assert driver._ratelimit_tokens_reset_at is not None - assert abs(driver._ratelimit_tokens_reset_at - expected_token_reset_time) < datetime.timedelta(seconds=1) - - def test_extract_ratelimit_metadata_missing_headers(self): - class OpenAiApiResponseNoHeaders: - @property - def headers(self): - return {} - - response_without_headers = OpenAiApiResponseNoHeaders() - - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - driver._extract_ratelimit_metadata(response_without_headers) - - assert driver._ratelimit_request_limit is None - assert driver._ratelimit_requests_remaining is None - assert driver._ratelimit_requests_reset_at is None - assert driver._ratelimit_token_limit is None - assert driver._ratelimit_tokens_remaining is None - assert driver._ratelimit_tokens_reset_at is None - def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
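
Note for anyone who relied on the removed `_ratelimit_*` attributes: the driver now calls `chat.completions.create` directly, so rate-limit metadata is no longer captured on the driver. The OpenAI SDK's `with_raw_response` wrapper, the same mechanism the removed code used, still exposes the raw headers to callers. Below is a minimal sketch of reading them yourself; the model name and message are placeholders, and only the `with_raw_response` calls and header names are taken from the removed code:

    from openai import OpenAI

    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

    # with_raw_response exposes the raw HTTP response alongside the parsed body.
    raw = client.chat.completions.with_raw_response.create(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "Hello"}],  # placeholder prompt
    )

    # These headers are not reliably returned on every call, so read them
    # defensively, just as the removed _extract_ratelimit_metadata helper did.
    remaining_requests = raw.headers.get("x-ratelimit-remaining-requests")
    remaining_tokens = raw.headers.get("x-ratelimit-remaining-tokens")

    # parse() yields the same ChatCompletion object that create() returns.
    completion = raw.parse()
    print(completion.choices[0].message.content, remaining_requests, remaining_tokens)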