diff --git a/outlines/models/openai.py b/outlines/models/openai.py
index e845b940e..4d5e35535 100644
--- a/outlines/models/openai.py
+++ b/outlines/models/openai.py
@@ -1,7 +1,9 @@
 """Integration with OpenAI's API."""
 import functools
 import os
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
+from collections import deque
+from itertools import zip_longest
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union

 import numpy as np

@@ -85,6 +87,39 @@ async def generate_base(

     return results

+def longest_common_prefix(tokens1: List[int], tokens2: List[int]) -> List[int]:
+    i = 0
+    while i < len(tokens1) and i < len(tokens2) and tokens1[i] == tokens2[i]:
+        i += 1
+    return tokens1[:i]
+
+def get_choices_with_longest_common_prefix(
+    response: List[int], is_in: List[List[int]]
+) -> Tuple[List[int], List[List[int]]]:
+    max_len_prefix = 0
+    is_in_left = []
+    prefix = []
+    for i in range(len(is_in)):
+        len_prefix = len(longest_common_prefix(response, is_in[i]))
+
+        if len_prefix > max_len_prefix:
+            max_len_prefix = len_prefix
+            is_in_left = [is_in[i][len_prefix:]]
+            prefix = is_in[i][:len_prefix]
+
+        elif len_prefix == max_len_prefix:
+            is_in_left.append(is_in[i][len_prefix:])
+
+    return prefix, is_in_left
+
+def build_optimistic_mask(transposed: deque[Set]) -> Dict:
+    # build the biggest mask possible, adding tokens left to right
+    to_mask: Set[int] = set()
+    while len(transposed) > 0 and len(to_mask | transposed[0]) <= 300:
+        to_mask = to_mask | transposed.popleft()
+
+    return {token: 100 for token in to_mask}
+
 @functools.partial(outlines.vectorize, signature="(),(m),()->(s)")
 async def generate_choice(
     prompt: str,
@@ -95,12 +130,11 @@ async def generate_choice(

     .. warning::

-        This function will call the API once for every token generated.
+        In the worst case, this function may call the API once for every token in the response.

-        We tokenize every choice, iterate over the token lists, create a mask
-        with the current tokens and generate one token. We progressively
-        eliminate the choices that don't start with the currently decoded
-        sequence.
+        With the optimistic approach, we activate every token that could appear in any of the answers. If
+        the returned completion does not match one of the answers, we call the API again with only the
+        tokens that can be accepted as the next token. On average, this approach requires fewer API calls.

""" try: @@ -111,20 +145,33 @@ async def generate_choice( ) tokenizer = tiktoken.encoding_for_model(model_name) - encoded: List[List[int]] = [tokenizer.encode(word) for word in is_in] decoded_samples = [] for _ in range(samples): + is_in_left = is_in.copy() decoded: List[str] = [] - for i in range(max([len(word) for word in encoded])): - mask = {} - for word, tokenized_word in zip(is_in, encoded): - if not word.startswith("".join(decoded)): - continue - try: - mask[tokenized_word[i]] = 100 - except IndexError: - pass + + greedy = False # we try to generate the full response at each iteration + + while len(is_in_left) > 0: + encoded: List[List[int]] = [ + tokenizer.encode(word) for word in is_in_left + ] + + max_tokens_left = max([len(tokens) for tokens in encoded]) + transposed: deque[Set] = deque( + [ + {item for item in subset if item is not None} + for subset in zip_longest(*encoded) + ] + ) + + if not greedy: + mask = build_optimistic_mask(transposed) + else: + mask = {} + for token in transposed.popleft(): # build greedy mask + mask[token] = 100 if len(mask) == 0: break @@ -132,15 +179,46 @@ async def generate_choice( response = await call_api( model_name, format_prompt(prompt), - 1, + max_tokens_left if not greedy else 1, temperature, [], mask, 1, ) - decoded.append(extract_choice(response["choices"][0])) - prompt = prompt + "".join(decoded) + current_resp = extract_choice(response["choices"][0]) + + if current_resp in is_in_left: + decoded.append(current_resp) + break + else: + # map response to tokens + tokenized_resp = tokenizer.encode(current_resp) + ( + tokenized_resp, + encoded, + ) = get_choices_with_longest_common_prefix( + tokenized_resp, encoded + ) + + if len(tokenized_resp) == 0: + greedy = True # next iteration will be "greedy" + continue + else: + decoded.append("".join(tokenizer.decode(tokenized_resp))) + + # map back to words + is_in_left = [ + "".join(tokenizer.decode(tokens)) for tokens in encoded + ] + + if len(is_in_left) == 1: # only one choice left + decoded.append(is_in_left[0]) + break + + greedy = False # after each success, stay with (or switch to) "optimistic" approach + + prompt = prompt + "".join(decoded) decoded_samples.append("".join(decoded))