Commit
remove unnecessary tokenizer requirement in models.openai, remove dead code
lapp0 committed Sep 15, 2024
1 parent 91bddb1 commit 3893746
Showing 2 changed files with 4 additions and 241 deletions.
198 changes: 4 additions & 194 deletions outlines/models/openai.py
@@ -1,9 +1,7 @@
"""Integration with OpenAI's API."""
import copy
import functools
import warnings
from dataclasses import asdict, dataclass, field, replace
from itertools import zip_longest
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np
@@ -75,7 +73,6 @@ def __init__(
self,
client,
config,
tokenizer=None,
system_prompt: Optional[str] = None,
):
"""Create an `OpenAI` instance.
@@ -90,12 +87,9 @@ def __init__(
config
An instance of `OpenAIConfig`. Can be useful to specify some
parameters that cannot be set by calling this class' methods.
tokenizer
The tokenizer associated with the model the client connects to.
"""

self.client = client
self.tokenizer = tokenizer
self.config = config

# We count the total number of prompt and generated tokens as returned
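
With the tokenizer parameter dropped, constructing the wrapper directly now takes only a client and a config, plus an optional system prompt. An editor's sketch of the new call shape, assuming `OpenAIConfig` exposes a `model` field (the model name is illustrative):

from openai import AsyncOpenAI
from outlines.models.openai import OpenAI, OpenAIConfig

client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
config = OpenAIConfig(model="gpt-4o-mini")  # model name is illustrative
model = OpenAI(client, config, system_prompt="Answer tersely.")
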
@@ -161,97 +155,6 @@ def stream(self, *args, **kwargs):
"Streaming is currently not supported for the OpenAI API"
)

def generate_choice(
self,
prompt: str,
choices: List[str],
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None,
) -> str:
"""Call the OpenAI API to generate one of several choices.
Parameters
----------
prompt
A string or list of strings that will be used to prompt the model
choices
The list of strings between which we ask the model to choose
max_tokens
The maximum number of tokens to generate
system_prompt
The content of the system message that precedes the user's prompt.
"""
if self.tokenizer is None:
raise ValueError(
"You must initialize the `OpenAI` class with a tokenizer to use `outlines.generate.choice`"
)

config = replace(self.config, max_tokens=max_tokens)

greedy = False
decoded: List[str] = []
encoded_choices_left: List[List[int]] = [
self.tokenizer.encode(word) for word in choices
]

while len(encoded_choices_left) > 0:
max_tokens_left = max([len(tokens) for tokens in encoded_choices_left])
transposed_choices_left: List[Set] = [
{item for item in subset if item is not None}
for subset in zip_longest(*encoded_choices_left)
]

if not greedy:
mask = build_optimistic_mask(transposed_choices_left)
else:
mask = {}
for token in transposed_choices_left[0]: # build greedy mask
mask[token] = 100

if len(mask) == 0:
break

config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)

response, prompt_tokens, completion_tokens = generate_chat(
prompt, system_prompt, self.client, config
)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens

encoded_response = self.tokenizer.encode(response)

if encoded_response in encoded_choices_left:
decoded.append(response)
break
else:
(
encoded_response,
encoded_choices_left,
) = find_response_choices_intersection(
encoded_response, encoded_choices_left
)

if len(encoded_response) == 0:
greedy = True # next iteration will be "greedy"
continue
else:
decoded.append("".join(self.tokenizer.decode(encoded_response)))

if len(encoded_choices_left) == 1: # only one choice left
choice_left = self.tokenizer.decode(encoded_choices_left[0])
decoded.append(choice_left)
break

greedy = False # after each success, stay with (or switch to) "optimistic" approach

prompt = prompt + "".join(decoded)

choice = "".join(decoded)

return choice
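
The method deleted above steered generation by biasing the logits of the choices' tokens (`logit_bias=mask`) and decoding choice by choice. For reference, a bare-bones sketch of the same trick against the raw chat API; the model name and prompt are illustrative, and `tiktoken` is assumed for encoding:

import tiktoken
from openai import OpenAI

client = OpenAI()
enc = tiktoken.get_encoding("cl100k_base")

# Bias the first token of each choice with weight 100, as the greedy branch
# above did; the API expects token ids as string keys.
mask = {str(enc.encode(choice)[0]): 100 for choice in ["foo", "bar"]}

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Answer foo or bar:"}],
    logit_bias=mask,
    max_tokens=1,
)
print(response.choices[0].message.content)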

def new_with_replacements(self, **kwargs):
new_instance = copy.copy(self)
new_instance.config = replace(new_instance.config, **kwargs)
@@ -316,81 +219,6 @@ async def call_api(prompt, system_prompt, config):
return results, usage["prompt_tokens"], usage["completion_tokens"]


def find_longest_intersection(response: List[int], choice: List[int]) -> List[int]:
"""Find the longest intersection between the response and the choice."""
for i, (token_r, token_c) in enumerate(zip_longest(response, choice)):
if token_r != token_c:
return response[:i]

return response
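
Its behaviour on two of the parametrized cases from the test file removed below:

find_longest_intersection([1, 2, 3], [1, 2, 3, 4])  # -> [1, 2, 3]
find_longest_intersection([4, 5], [1, 2, 3])        # -> []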


def find_response_choices_intersection(
response: List[int], choices: List[List[int]]
) -> Tuple[List[int], List[List[int]]]:
"""Find the longest intersection between the response and the different
choices.
Say the response is of the form `[1, 2, 3, 4, 5]` and we have the choices
`[[1, 2], [1, 2, 3], [6, 7, 8]]` then the function will return `[1, 2, 3]` as the
intersection, and `[[]]` as the list of choices left.
Parameters
----------
response
The model's response
choices
The remaining possible choices
Returns
-------
A tuple that contains the longest intersection between the response and the
different choices, and the choices which start with this intersection, with the
intersection removed.
"""
max_len_prefix = 0
choices_left = []
longest_prefix = []
for i, choice in enumerate(choices):
# Find the longest intersection between the response and the choice.
prefix = find_longest_intersection(response, choice)

if len(prefix) > max_len_prefix:
max_len_prefix = len(prefix)
choices_left = [choice[len(prefix) :]]
longest_prefix = prefix

elif len(prefix) == max_len_prefix:
choices_left.append(choice[len(prefix) :])

return longest_prefix, choices_left
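
Running the docstring's own example shows the contract:

intersection, choices_left = find_response_choices_intersection(
    [1, 2, 3, 4, 5], [[1, 2], [1, 2, 3], [6, 7, 8]]
)
# intersection -> [1, 2, 3]; choices_left -> [[]], i.e. one choice fully consumed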


def build_optimistic_mask(
transposed: List[Set[int]], max_mask_size: int = 300
) -> Dict[int, int]:
"""We build the largest mask possible.
Tokens are added from left to right, so if the encoded choices are e.g.
`[[1,2], [3,4]]`, `1` and `3` will be added before `2` and `4`.
Parameters
----------
transposed
A list of sets that contain the nth token of each choice.
"""
mask: Dict[int, int] = {}
for tokens in transposed:
for token in tokens:
if len(mask) == max_mask_size:
return mask
mask[token] = 100

return mask
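
And the mask builder on one of the removed test cases below — the size cap stops it after three tokens:

build_optimistic_mask([{1, 2}, {3, 4}], max_mask_size=3)
# -> {1: 100, 2: 100, 3: 100}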


def error_handler(api_call_fn: Callable) -> Callable:
"""Handle OpenAI API errors and missing API key."""

@@ -430,11 +258,10 @@ def openai_model(
**openai_client_params,
):
try:
import tiktoken
from openai import AsyncOpenAI
except ImportError:
raise ImportError(
"The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' OpenAI integration."
"The `openai` library needs to be installed in order to use Outlines' OpenAI integration."
)

if config is not None:
@@ -444,15 +271,7 @@

client = AsyncOpenAI(**openai_client_params)

try:
tokenizer = tiktoken.encoding_for_model(model_name)
except KeyError:
warnings.warn(
f"Could not find a tokenizer for model {model_name}. Using default cl100k_base."
)
tokenizer = tiktoken.get_encoding("cl100k_base")

return OpenAI(client, config, tokenizer)
return OpenAI(client, config)
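
The factory now goes straight from client to wrapper, with no tiktoken probe in between. A hedged usage sketch, assuming the factory is importable under the name shown in this diff and an API key is set in the environment (the model name is illustrative):

model = openai_model("gpt-4o-mini")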


def azure_openai(
@@ -462,11 +281,10 @@
**azure_openai_client_params,
):
try:
import tiktoken
from openai import AsyncAzureOpenAI
except ImportError:
raise ImportError(
"The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' Azure OpenAI integration."
"The `openai` library needs to be installed in order to use Outlines' Azure OpenAI integration."
)

if config is not None:
@@ -476,12 +294,4 @@

client = AsyncAzureOpenAI(**azure_openai_client_params)

try:
tokenizer = tiktoken.encoding_for_model(model_name or deployment_name)
except KeyError:
warnings.warn(
f"Could not find a tokenizer for model {model_name or deployment_name}. Using default cl100k_base."
)
tokenizer = tiktoken.get_encoding("cl100k_base")

return OpenAI(client, config, tokenizer)
return OpenAI(client, config)
47 changes: 0 additions & 47 deletions tests/models/test_openai.py
@@ -2,16 +2,12 @@
from unittest import mock
from unittest.mock import MagicMock

import pytest
from openai import AsyncOpenAI

from outlines import generate
from outlines.models.openai import (
OpenAI,
OpenAIConfig,
build_optimistic_mask,
find_longest_intersection,
find_response_choices_intersection,
)


@@ -77,46 +73,3 @@ def test_openai_choice_call():
# just integration between generate.choice and models.openai
generator = generate.choice(model, ["foo", "bar"])
assert generator("hi") == "foo"


@pytest.mark.parametrize(
"response,choice,expected_intersection,expected_choices_left",
(
([1, 2, 3, 4], [[5, 6]], [], [[5, 6]]),
([1, 2, 3, 4], [[5, 6], [7, 8]], [], [[5, 6], [7, 8]]),
([1, 2, 3, 4], [[1, 2], [7, 8]], [1, 2], [[]]),
([1, 2], [[1, 2, 3, 4], [1, 2]], [1, 2], [[3, 4], []]),
([1, 2, 3], [[1, 2, 3, 4], [1, 2]], [1, 2, 3], [[4]]),
),
)
def test_find_response_choices_intersection(
response, choice, expected_intersection, expected_choices_left
):
intersection, choices_left = find_response_choices_intersection(response, choice)
assert intersection == expected_intersection
assert choices_left == expected_choices_left


@pytest.mark.parametrize(
"response,choice,expected_prefix",
(
([1, 2, 3], [1, 2, 3, 4], [1, 2, 3]),
([1, 2, 3], [1, 2, 3], [1, 2, 3]),
([4, 5], [1, 2, 3], []),
),
)
def test_find_longest_common_prefix(response, choice, expected_prefix):
prefix = find_longest_intersection(response, choice)
assert prefix == expected_prefix


@pytest.mark.parametrize(
"transposed,mask_size,expected_mask",
(
([{1, 2}, {3, 4}], 3, {1: 100, 2: 100, 3: 100}),
([{1, 2}, {3, 4}], 4, {1: 100, 2: 100, 3: 100, 4: 100}),
),
)
def test_build_optimistic_mask(transposed, mask_size, expected_mask):
mask = build_optimistic_mask(transposed, mask_size)
assert mask == expected_mask
