Track token count and add system prompt
HerrIvan authored and rlouf committed Dec 6, 2023
1 parent 742ebb6 commit ebde682
Showing 2 changed files with 70 additions and 18 deletions.
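The new behaviour can be exercised roughly as follows. This is a minimal usage sketch, not code from the commit: it assumes the module-level alias `openai = OpenAI` is exposed as `outlines.models.openai` and that an OpenAI API key is available in the environment.

from outlines import models

# Construct a model with the new `system_prompt` and `timeout` arguments.
model = models.openai(
    "gpt-3.5-turbo",
    system_prompt="You answer with a single short sentence.",
    timeout=30.0,
)

answer = model("What is the capital of France?")
print(answer)

# The counters introduced in this commit accumulate usage over every
# request made through this instance, as reported by the OpenAI API.
print(model.prompt_tokens, model.completion_tokens)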
outlines/models/openai.py (86 changes: 69 additions & 17 deletions)
@@ -79,6 +79,8 @@ def __init__(
model_name: str,
api_key: Optional[str] = None,
max_retries: int = 6,
timeout: Optional[float] = None,
system_prompt: Optional[str] = None,
config: Optional[OpenAIConfig] = None,
):
"""Create an `OpenAI` instance.
@@ -93,6 +95,10 @@ def __init__(
`openai.api_key`.
max_retries
The maximum number of retries when calls to the API fail.
timeout
Duration after which the request times out.
system_prompt
The content of the system message that precedes the user's prompt.
config
An instance of `OpenAIConfig`. Can be useful to specify some
parameters that cannot be set by calling this class' methods.
@@ -120,7 +126,16 @@ def __init__(
else:
self.config = OpenAIConfig(model=model_name)

self.client = openai.AsyncOpenAI(api_key=api_key, max_retries=max_retries)
self.client = openai.AsyncOpenAI(
api_key=api_key, max_retries=max_retries, timeout=timeout
)
self.system_prompt = system_prompt

# We count the total number of prompt and generated tokens as returned
# by the OpenAI API, summed over all the requests performed with this
# model instance.
self.prompt_tokens = 0
self.completion_tokens = 0

def __call__(
self,
@@ -158,7 +173,13 @@ def __call__(
)
)
if "gpt-" in self.config.model:
return generate_chat(prompt, self.client, config)
response, usage = generate_chat(
prompt, self.system_prompt, self.client, config
)
self.prompt_tokens += usage["prompt_tokens"]
self.completion_tokens += usage["completion_tokens"]

return response

def generate_choice(
self, prompt: str, choices: List[str], max_tokens: Optional[int] = None
Expand Down Expand Up @@ -210,7 +231,13 @@ def generate_choice(
break

config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)
response = generate_chat(prompt, self.client, config)

response, usage = generate_chat(
prompt, self.system_prompt, self.client, config
)
self.completion_tokens += usage["completion_tokens"]
self.prompt_tokens += usage["prompt_tokens"]

encoded_response = tokenizer.encode(response)

if encoded_response in encoded_choices_left:
@@ -255,22 +282,46 @@ def __repr__(self):


@cache(ignore="client")
@functools.partial(outlines.vectorize, signature="(),(),()->(s)")
@functools.partial(outlines.vectorize, signature="(),(),(),()->(s),()")
async def generate_chat(
prompt: str, client: "AsyncOpenAI", config: OpenAIConfig
) -> np.ndarray:
prompt: str,
system_prompt: Union[str, None],
client: "AsyncOpenAI",
config: OpenAIConfig,
) -> Tuple[np.ndarray, Dict]:
"""Call OpenAI's Chat Completion API.
Parameters
----------
prompt
The prompt we use to start the generation. Passed to the model
with the "user" role.
system_prompt
The system prompt, passed to the model with the "system" role
before the prompt.
client
The API client
config
An `OpenAIConfig` instance.
Returns
-------
A tuple that contains the model's response(s) and usage statistics.
"""
system_message = (
[{"role": "system", "content": system_prompt}] if system_prompt else []
)
user_message = [{"role": "user", "content": prompt}]

responses = await client.chat.completions.create(
messages=[{"role": "user", "content": prompt}], **asdict(config) # type: ignore
messages=system_message + user_message,
**asdict(config), # type: ignore
)

if config.n == 1:
results = np.array([responses.choices[0].message.content])
else:
results = np.array(
[responses.choices[i].message.content for i in range(config.n)]
)
results = np.array([responses.choices[i].message.content for i in range(config.n)])

return results
return results, responses.usage.model_dump()


openai = OpenAI
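For orientation, here is a rough, self-contained sketch of what the new request body and return value amount to. The message construction mirrors the code above; the usage figures are made up and only illustrate the field names of OpenAI's usual usage schema.

# Sketch only: the messages sent when a system prompt is configured.
system_prompt = "You are a terse assistant."
prompt = "Name a prime number."

system_message = (
    [{"role": "system", "content": system_prompt}] if system_prompt else []
)
user_message = [{"role": "user", "content": prompt}]
print(system_message + user_message)
# [{'role': 'system', 'content': 'You are a terse assistant.'},
#  {'role': 'user', 'content': 'Name a prime number.'}]

# generate_chat now also returns a usage dictionary of this shape,
# which the model instance sums into prompt_tokens and completion_tokens.
usage = {"prompt_tokens": 21, "completion_tokens": 5, "total_tokens": 26}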
@@ -292,8 +343,8 @@ def find_response_choices_intersection(
choices.
Say the response is of the form `[1, 2, 3, 4, 5]` and we have the choices
`[[1, 2], [1, 2, 3], [6, 7, 8]` then the function will return `[1, 2]` as the
intersection, and `[1, 2, 3]` as the choice that is left.
`[[1, 2], [1, 2, 3], [6, 7, 8]` then the function will return `[1, 2, 3]` as the
intersection, and `[[]]` as the list of choices left.
Parameters
----------
@@ -305,7 +356,8 @@
Returns
-------
A tuple that contains the longest intersection between the response and the
different choices, and the choices which start with this intersection.
different choices, and the choices which start with this intersection, with the
intersection removed.
"""
max_len_prefix = 0
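The corrected docstring above is easier to follow with concrete values. Below is a small stand-in sketch of the behaviour it describes; the function body itself is not shown in this diff, so this is an illustrative reimplementation rather than the library's own code.

from typing import List, Tuple

def find_response_choices_intersection(
    response: List[int], choices: List[List[int]]
) -> Tuple[List[int], List[List[int]]]:
    # Find the longest prefix shared between the response and any choice,
    # then return the matching choices with that prefix stripped.
    max_len_prefix = 0
    intersection: List[int] = []
    choices_left: List[List[int]] = []
    for choice in choices:
        n = 0
        while n < min(len(response), len(choice)) and response[n] == choice[n]:
            n += 1
        if n > max_len_prefix:
            max_len_prefix = n
            intersection = choice[:n]
            choices_left = [choice[n:]]
        elif n == max_len_prefix and n > 0:
            choices_left.append(choice[n:])
    return intersection, choices_left

result = find_response_choices_intersection(
    [1, 2, 3, 4, 5], [[1, 2], [1, 2, 3], [6, 7, 8]]
)
print(result)  # ([1, 2, 3], [[]]), as in the updated docstring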
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -92,7 +92,7 @@ module = [
"jinja2",
"joblib.*",
"jsonschema.*",
"openai",
"openai.*",
"nest_asyncio",
"numpy.*",
"perscache.*",
