Skip to content

Commit

Permalink
Add support for OpenAI functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 16, 2023
1 parent 9441958 commit ca03ff2
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion outlines/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

import numpy as np
from pydantic import BaseModel

import outlines
from outlines.caching import cache
Expand All @@ -14,6 +15,18 @@
from openai import AsyncOpenAI


# Cowsay-style banner, presumably shown when OpenAI returns a response that is
# not valid JSON — NOTE(review): confirm at the call site; it is not visible in
# this chunk.
#
# Each literal backslash in the art must be written as `\\`: a lone `\` at the
# end of a line inside a string literal escapes the newline (line continuation)
# and silently glues two art lines together, corrupting the output.
JSON_ERROR_MSG = """
 _______________________________
< Damn OpenAI, missed it again! >
 -------------------------------
        \\   ^__^
         \\  (oo)\\_______
            (__)\\       )\\/\\
                ||----w |
                ||     ||
"""


class OpenAIAPI:
def __init__(
self,
Expand Down Expand Up @@ -156,10 +169,19 @@ def __call__(
prompt: str,
max_tokens: int = 500,
*,
response_schema=None,
samples=1,
stop_at: Union[List[Optional[str]], str] = [],
is_in: Optional[List[str]] = None,
):
if response_schema is not None:
if isinstance(response_schema, type(BaseModel)):
_ = response_schema.model_json_schema()
else:
raise TypeError(
"The `schema_object` passed to the JSON generating function must either be a string that represents a valid JSON Schema, a Pydantic model or a function with a type-annotated signature."
)

if is_in is not None and stop_at:
raise TypeError("You cannot set `is_in` and `stop_at` at the same time.")
elif is_in is not None:
Expand All @@ -186,7 +208,9 @@ def call(*args, **kwargs):
openai.InternalServerError,
openai.RateLimitError,
) as e:
raise OSError(f"Could not connect to the OpenAI API: {e}")
raise OSError(
f"Could not connect to the OpenAI API: {e}. Responses from OpenAI are cached, and your generation will resume from where it started at the next call."
)
except (
openai.AuthenticationError,
openai.BadRequestError,
Expand Down

0 comments on commit ca03ff2

Please sign in to comment.