Manage samples outside of sequence_generator
rlouf committed Feb 6, 2024
1 parent 238c11a commit e00c53f
Showing 3 changed files with 93 additions and 223 deletions.
97 changes: 48 additions & 49 deletions outlines/generate/api.py
@@ -1,15 +1,10 @@
import json as pyjson
import warnings
from typing import Iterator, List, Optional, Tuple, Union
from typing import Iterator, List, Optional, Union

import torch

from outlines.generate.generator import (
GenerationState,
init_generator_state,
sequence_generator,
token_generator,
)
from outlines.fsm.fsm import FSMState
from outlines.generate.generator import sequence_generator, token_generator


class SequenceGenerator:
@@ -51,9 +46,9 @@ def __init__(

def get_generated_token_ids(
self,
init_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
prompt_token_ids: torch.Tensor,
prompts: List[str],
last_state: GenerationState,
token_ids: torch.Tensor,
num_samples: int,
) -> List[torch.Tensor]:
"""Get the tokens generated so far.
@@ -64,8 +59,8 @@ def get_generated_token_ids(
The initial state of the generation.
prompts
The prompts passed to the generator.
last_state
The current state of the generation
token_ids
The generated token ids.
num_samples
The number of samples taken for each sequence
@@ -74,17 +69,12 @@ def __call__(
A tensor that contains the token ids that have been generated so far.
"""
prompt_token_ids = init_state[0]
prompt_lengths = [
len(prompt_token_ids[i])
for _ in range(num_samples)
for i in range(len(prompts))
]

# We flatten the obtained token_ids since the tokenizer's decoder
# only accepts tensor with two dimensions
token_ids = last_state.token_ids.reshape((-1, last_state.token_ids.shape[-1]))

token_ids = [
cur_token_ids[length:]
for cur_token_ids, length in zip(token_ids, prompt_lengths)
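Note: the hunk above strips each prompt from the front of its generated row before decoding. A minimal sketch of that slicing with made-up tensors (the real method also repeats the prompt lengths once per sample):

import torch

# Hypothetical prompts already encoded to token ids (toy values).
prompt_token_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])
# Full sequences after generation: prompt tokens followed by generated tokens.
token_ids = torch.tensor([[5, 6, 7, 11, 12], [8, 9, 10, 13, 14]])

prompt_lengths = [len(ids) for ids in prompt_token_ids]
generated_token_ids = [
    cur_token_ids[length:]
    for cur_token_ids, length in zip(token_ids, prompt_lengths)
]
# -> [tensor([11, 12]), tensor([13, 14])]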
@@ -208,37 +198,43 @@ def __call__(

stop_sequences = stop_at or self.stop_sequences
max_tokens = max_tokens or self.max_tokens
<<<<<<< HEAD
=======
num_samples = self.num_particles
num_sequences = len(prompts)
>>>>>>> c8bf540 (Pass number of samples via sampler)
fsms = [self.fsm.copy() for _ in prompts]

if rng is None:
rng = torch.Generator(device=self.device)
rng.seed()

init_state = init_generator_state(
self.tokenizer, self.device, prompts, kv_cache
)
init_fsm_states = [self.fsm.first_state for _ in prompts]
prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
prompt_token_ids = prompt_token_ids.to(self.device)
attention_masks = attention_masks.to(self.device)

# To draw multiple samples we repeat the prompt as many times
# as there are samples. We copy the FSMs and initialize the
# FSM states.
num_samples = self.num_particles
batch_size = len(prompts)

prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]

states = sequence_generator(
self.generate_token,
fsms,
init_state,
init_fsm_states,
prompt_token_ids,
attention_masks,
fsm_states,
rng=rng,
num_samples=num_samples,
)

while True:
try:
last_state = next(states)
if max_tokens or stop_sequences:
token_ids = last_state.token_ids
generated_token_ids = self.get_generated_token_ids(
init_state, prompts, last_state, num_samples
prompt_token_ids, prompts, token_ids, num_samples
)
if max_tokens and len(generated_token_ids[0]) >= max_tokens:
break
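Note: the duplication step introduced above is the core of this commit: prompt and attention-mask rows are repeated once per sample, and one FSM copy plus one initial FSM state is created per (prompt, sample) pair, so a single batched forward pass serves every sample. A minimal sketch of the tensor side, with toy values:

import torch

num_samples = 3
prompt_token_ids = torch.tensor([[1, 2], [3, 4]])      # batch_size = 2 prompts
attention_masks = torch.ones_like(prompt_token_ids)

# Each row is repeated consecutively: p0, p0, p0, p1, p1, p1.
prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
print(prompt_token_ids.shape)  # torch.Size([6, 2])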
@@ -249,8 +245,9 @@ def __call__(
except StopIteration:
break

token_ids = last_state.token_ids
generated_token_ids = self.get_generated_token_ids(
init_state, prompts, last_state, num_samples
prompt_token_ids, prompts, token_ids, num_samples
)

generated = self.tokenizer.decode(generated_token_ids)
@@ -320,37 +317,42 @@ def stream(

stop_sequences = stop_at or self.stop_sequences
max_tokens = max_tokens or self.max_tokens
<<<<<<< HEAD
=======
num_samples = self.num_particles
num_sequences = len(prompts)
>>>>>>> c8bf540 (Pass number of samples via sampler)
fsms = [self.fsm.copy() for _ in prompts]

if rng is None:
rng = torch.Generator(device=self.device)
rng.seed()

init_state = init_generator_state(
self.tokenizer, self.device, prompts, kv_cache
)
init_fsm_states = [self.fsm.first_state for _ in prompts]
prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
prompt_token_ids = prompt_token_ids.to(self.device)
attention_masks = attention_masks.to(self.device)

# To draw multiple samples we repeat the prompt as many times
# as there are samples. We copy the FSMs and initialize the
# FSM states.
num_samples = self.num_particles
batch_size = len(prompts)

prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]

states = sequence_generator(
self.generate_token,
fsms,
init_state,
init_fsm_states,
num_samples=num_samples,
prompt_token_ids,
attention_masks,
fsm_states,
rng=rng,
)

def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
previously_generated_sequences = [
"" for _ in range(num_sequences)
"" for _ in range(batch_size)
] * num_samples
num_generated = 0
is_stop_at_reached = [False for _ in range(num_sequences)] * num_samples
is_stop_at_reached = [False for _ in range(batch_size)] * num_samples
while True:
if (max_tokens and num_generated >= max_tokens) or all(
is_stop_at_reached
@@ -361,10 +363,7 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
num_generated += 1
except StopIteration:
return
generated_token_ids = sequence.token_ids[:, :, -num_generated:]
generated_token_ids = generated_token_ids.reshape(
-1, generated_token_ids.shape[-1]
)
generated_token_ids = sequence.token_ids[:, -num_generated:]
generated_sequences = self.tokenizer.decode(generated_token_ids)
next_tokens = [
token[len(sequence) :] if not stop else ""
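Note: in the streaming path above, each iteration decodes everything generated so far and yields only the suffix that has not been streamed yet. A small sketch of that diffing step with assumed string values; the iteration variables match the hunk, while the zip arguments are inferred from the truncated listing:

generated_sequences = ["Hello wor", "42 is"]            # decoded so far (toy values)
previously_generated_sequences = ["Hello ", "42 "]
is_stop_at_reached = [False, False]

next_tokens = [
    token[len(sequence):] if not stop else ""
    for token, sequence, stop in zip(
        generated_sequences, previously_generated_sequences, is_stop_at_reached
    )
]
# -> ["wor", "is"]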
55 changes: 6 additions & 49 deletions outlines/generate/generator.py
@@ -1,14 +1,13 @@
import dataclasses
import math
from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Iterator, List, Union

import torch

from outlines.fsm.fsm import FSMState

if TYPE_CHECKING:
from outlines.fsm.fsm import FSM
from outlines.models.tokenizer import Tokenizer
from outlines.samplers import Sampler


@@ -20,45 +19,12 @@ class GenerationState:
fsm_states: List[FSMState]
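Note: only the last field of GenerationState is visible in this hunk. A rough reconstruction, inferred from the positional arguments passed to it at the bottom of this file (token_ids, kv_cache, logits, fsm_states); field names other than fsm_states are inferred, not copied from the source:

import dataclasses
from typing import List

import torch

from outlines.fsm.fsm import FSMState

@dataclasses.dataclass
class GenerationState:
    token_ids: torch.Tensor
    kv_cache: torch.Tensor
    logits: torch.Tensor
    fsm_states: List[FSMState]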


def init_generator_state(
tokenizer: "Tokenizer",
device: str,
prompt: Union[str, List[str]],
kv_cache: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Initialize the generation state.
This method is responsible for encoding the prompt, moving token ids
to the device and initializing the random number generator.
Parameters
----------
tokenizer:
The model's tokenizer.
device:
The name of the device on which to load the token ids and attention
masks.
prompt
The prompt on which the generation is conditioned.
Returns
-------
A `GenerationState` object.
"""
token_ids, attention_masks = tokenizer.encode(prompt)
token_ids = token_ids.to(device)
attention_masks = attention_masks.to(device)

return token_ids, attention_masks, kv_cache


def sequence_generator(
token_generator: Callable,
fsms: List["FSM"],
init_state: Tuple,
token_ids: torch.Tensor,
attention_masks: torch.Tensor,
fsm_states: List[FSMState],
num_samples: int = 1,
rng: torch.Generator = torch.Generator(),
) -> Iterator[GenerationState]:
"""Generates sequences of tokens.
@@ -81,16 +47,7 @@ def sequence_generator(
A new sequence.
"""
token_ids, attention_masks, kv_cache = init_state
batch_shape = token_ids.shape[:-1]

# To take several samples we duplicate `token_ids`, `attention_masks`
# and `fsm_states` as many times as the number of samples requested.
# The resulting tensors are of shape (num_samples * num_batches, num_tokens)
token_ids = torch.repeat_interleave(token_ids, num_samples, dim=0)
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
fsm_states = [state for state in fsm_states for _ in range(num_samples)]
fsms = [fsm.copy() for fsm in fsms for _ in range(num_samples)]
kv_cache = None

while True:
allowed_tokens = get_allowed_tokens(fsms, fsm_states)
@@ -110,15 +67,15 @@ def sequence_generator(

if is_finished:
yield GenerationState(
token_ids.reshape((num_samples,) + batch_shape + token_ids.shape[-1:]),
token_ids,
kv_cache,
logits,
fsm_states,
)
return

yield GenerationState(
token_ids.reshape((num_samples,) + batch_shape + token_ids.shape[-1:]),
token_ids,
kv_cache,
logits,
fsm_states,
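Note: the last two hunks show the other half of the change: sequence_generator no longer reshapes token_ids to (num_samples, batch_size, num_tokens) before yielding; it yields the flat (num_samples * batch_size, num_tokens) tensor and leaves any regrouping to the caller. The shape difference, illustrated on a toy tensor:

import torch

num_samples, batch_size, num_tokens = 3, 2, 5
flat = torch.zeros((num_samples * batch_size, num_tokens))      # what is yielded now
grouped = flat.reshape((num_samples, batch_size, num_tokens))   # the previously yielded shape
print(flat.shape, grouped.shape)  # torch.Size([6, 5]) torch.Size([3, 2, 5])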