Skip to content

Commit

Permalink
Upgrade OpenAI Dependency (to fix "Quota Reached" error) & some lint fixes (#21)
Browse files Browse the repository at this point in the history

* openai-version-bump

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* run fmt

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* s

* not going to bother

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update disc

* str

---------

Co-authored-by: Josh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 23, 2024
1 parent c1e911d commit 8dd8d50
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 79 deletions.
62 changes: 11 additions & 51 deletions app/adapters/openai/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
from typing import Required
from typing import TypedDict

import backoff
import openai.error
from openai.openai_object import OpenAIObject
import openai
from openai.types.chat import ChatCompletion

from app import settings


openai_client = openai.AsyncOpenAI(
api_key=settings.OPENAI_API_KEY,
)


class OpenAIModel(StrEnum):
Expand Down Expand Up @@ -55,56 +61,11 @@ class FunctionSchema(TypedDict):
parameters: Parameters


class GPTResponse(TypedDict):
    """Minimal typed view of a chat-completion API response.

    Only the ``choices`` field is modelled; each choice is a ``Message``.
    """

    # NOTE(review): callers appear to consume only choices[0]; the API may
    # return several choices -- confirm before relying on more than the first.
    choices: Sequence[Message]


MAX_BACKOFF_TIME = 16


def _is_non_retriable_error(error: Exception) -> bool:
    """\
    Determine whether an error is non-retriable.

    Used as the ``giveup`` predicate for ``backoff.on_exception``:
    returning True aborts the retry loop immediately; returning False
    allows another attempt (up to MAX_BACKOFF_TIME).

    Raises:
        NotImplementedError: for error types we have not classified.
    """
    if isinstance(error, openai.error.APIConnectionError):
        # The SDK says whether this connection error is safe to retry;
        # "give up" is the inverse of "should retry".
        return not error.should_retry
    elif isinstance(
        error,
        (
            openai.error.APIError,
            openai.error.TryAgain,
            openai.error.Timeout,
            openai.error.RateLimitError,
            openai.error.ServiceUnavailableError,
        ),
    ):
        # Transient server-side or throttling failures: a later attempt
        # may succeed, so keep retrying.
        return False
    elif isinstance(
        error,
        (
            openai.error.InvalidRequestError,
            openai.error.AuthenticationError,
            openai.error.PermissionError,
            openai.error.InvalidAPIType,
            openai.error.SignatureVerificationError,
        ),
    ):
        # Client-side misuse or misconfiguration: retrying will never
        # help, so give up immediately.
        return True
    else:
        raise NotImplementedError(f"Unknown error type: {error}")


@backoff.on_exception(
backoff.expo,
openai.error.OpenAIError,
max_time=MAX_BACKOFF_TIME,
giveup=_is_non_retriable_error,
)
async def send(
model: OpenAIModel,
messages: Sequence[Message],
functions: Sequence[FunctionSchema] | None = None,
) -> OpenAIObject:
) -> ChatCompletion:
"""\
Send a message to the OpenAI API, as a given model.
Expand All @@ -114,6 +75,5 @@ async def send(
if functions is not None:
kwargs["functions"] = functions

response = await openai.ChatCompletion.acreate(**kwargs)
assert isinstance(response, OpenAIObject)
response = await openai_client.chat.completions.create(**kwargs)
return response
59 changes: 34 additions & 25 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Bot(discord.Client):
async def start(self, *args, **kwargs) -> None:
state.read_database = database.Database(
database.dsn(
scheme="postgresql",
scheme=settings.READ_DB_SCHEME,
user=settings.READ_DB_USER,
password=settings.READ_DB_PASS,
host=settings.READ_DB_HOST,
Expand All @@ -52,7 +52,7 @@ async def start(self, *args, **kwargs) -> None:

state.write_database = database.Database(
database.dsn(
scheme="postgresql",
scheme=settings.WRITE_DB_SCHEME,
user=settings.WRITE_DB_USER,
password=settings.WRITE_DB_PASS,
host=settings.WRITE_DB_HOST,
Expand Down Expand Up @@ -117,11 +117,6 @@ async def on_ready():
# NOTE: we can't use this as a lifecycle hook because
# it may be called more than a single time.
# our lifecycle hook is in our Bot class definition

import openai

openai.api_key = settings.OPENAI_API_KEY

await command_tree.sync()


Expand Down Expand Up @@ -212,22 +207,32 @@ async def on_message(message: discord.Message):
)
if not gpt_response:
await message.channel.send(
f"Request failed after multiple retries.\n"
f"Please try again after some time.\n"
f"If this issue persists, please contact cmyui#0425 on discord."
"Request failed after multiple retries.\n"
"Please try again after some time.\n"
'If this issue persists, please contact "cmyui" on discord.'
)
return

gpt_choice = gpt_response["choices"][0]
gpt_message = gpt_choice["message"]
message_history.append(gpt_message)
gpt_choice = gpt_response.choices[0]
gpt_message = gpt_choice.message
assert gpt_message.content is not None

if gpt_choice["finish_reason"] == "stop":
gpt_response_content = gpt_message["content"]
message_history.append(
gpt.Message(
role=gpt_message.role,
content=gpt_message.content,
)
)

if gpt_choice.finish_reason == "stop":
gpt_response_content: str = gpt_message.content

elif gpt_choice["finish_reason"] == "function_call":
function_name = gpt_message["function_call"]["name"]
function_kwargs = json.loads(gpt_message["function_call"]["arguments"])
elif (
gpt_choice.finish_reason == "function_call"
and gpt_message.function_call is not None
):
function_name = gpt_message.function_call.name
function_kwargs = json.loads(gpt_message.function_call.arguments)

ai_function = openai_functions.ai_functions[function_name]
function_response = await ai_function["callback"](**function_kwargs)
Expand All @@ -243,13 +248,15 @@ async def on_message(message: discord.Message):
}
)
gpt_response = await gpt.send(tracked_thread["model"], message_history)
gpt_response_content = gpt_response["choices"][0]["message"]["content"]
assert gpt_response.choices[0].message.content is not None
gpt_response_content = gpt_response.choices[0].message.content

else:
raise NotImplementedError(
f"Unknown chatgpt finish reason: {gpt_choice['finish_reason']}"
f"Unknown chatgpt finish reason: {gpt_choice.finish_reason}"
)

assert gpt_response.usage is not None
input_tokens = gpt_response.usage.prompt_tokens
output_tokens = gpt_response.usage.completion_tokens

Expand Down Expand Up @@ -344,8 +351,8 @@ async def monthlycost(interaction: discord.Interaction):
response_cost = sum(per_user_cost.values())

message_chunks = [
f"**Monthly Cost Breakdown**",
f"**----------------------**",
"**Monthly Cost Breakdown**",
"**----------------------**",
"",
]
for user_id, cost in per_user_cost.items():
Expand Down Expand Up @@ -380,8 +387,8 @@ async def threadcost(interaction: discord.Interaction):
response_cost = sum(per_user_cost.values())

message_chunks = [
f"**Thread Cost Breakdown**",
f"**---------------------**",
"**Thread Cost Breakdown**",
"**---------------------**",
"",
]
for user_id, cost in per_user_cost.items():
Expand Down Expand Up @@ -472,7 +479,7 @@ async def context(
content="\n".join(
(
f"**Context length (messages length preserved) updated to {context_length}**",
f"NOTE: longer context costs linearly more tokens, so please take care.",
"NOTE: longer context costs linearly more tokens, so please take care.",
)
)
)
Expand Down Expand Up @@ -542,6 +549,8 @@ async def summarize(
gpt_response = await gpt.send(gpt.OpenAIModel.GPT_4_OMNI, messages)

gpt_response_content = gpt_response.choices[0].message.content
assert gpt_response_content is not None

# tokens_spent = gpt_response.usage.total_tokens

for chunk in split_message(gpt_response_content, 2000):
Expand Down
2 changes: 1 addition & 1 deletion app/openai_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_function_openai_schema(f: Callable[..., Awaitable[str]]) -> FunctionSche
or len(param_type.__metadata__) != 1
):
logging.warning(
f"Function decorated with @ai_function lacks parameter description annotation(s)",
"Function decorated with @ai_function lacks parameter description annotation(s)",
extra={
"param_name": param_name,
"param_type": param_type,
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
asyncpg
backoff
databases
discord.py
httpx
openai<1.0.0
openai
python-dotenv

0 comments on commit 8dd8d50

Please sign in to comment.