Use PyTorch instead of NumPy
rlouf committed Jul 6, 2023
1 parent c855860 commit 293826c
Showing 11 changed files with 302 additions and 282 deletions.
4 changes: 2 additions & 2 deletions outlines/models/hf_transformers.py
@@ -335,7 +335,7 @@ def create_int_constraint(
     import torch
 
     num_prompt_tokens = prompt_tokens.shape[-1]
-    mask = torch.from_numpy(create_int_mask(tokenizer.get_vocab()))
+    mask = create_int_mask(tokenizer.get_vocab())
 
     def logit_processor(input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
         """Pre-process the model's output logits before generating the next token.
@@ -373,7 +373,7 @@ def create_float_constraint(
     import torch
 
     num_prompt_tokens = prompt_tokens.shape[-1]
-    mask = torch.from_numpy(create_float_mask(tokenizer.get_vocab()))
+    mask = create_float_mask(tokenizer.get_vocab())
 
     def logit_processor(input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
         """Pre-process the model's output logits before generating the next token.
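Note: both hunks assume `create_int_mask` and `create_float_mask` now return `torch.Tensor`s directly, so the `torch.from_numpy` round-trip disappears. A minimal sketch of how such a boolean vocabulary mask is typically applied inside a logit processor (illustrative only, not the library's exact code):

    import torch

    def apply_vocab_mask(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Disallowed tokens get -inf logits, so softmax gives them zero probability.
        masked = scores.clone()
        masked[..., ~mask] = -float("inf")
        return masked

    # Toy 5-token vocabulary where only tokens 1 and 3 are allowed.
    scores = torch.randn(1, 5)
    mask = torch.tensor([False, True, False, True, False])
    print(apply_vocab_mask(scores, mask))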
38 changes: 18 additions & 20 deletions outlines/models/transformers.py
@@ -1,8 +1,7 @@
 import math
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
-import numpy as np
-from numpy.typing import NDArray
+import torch
 
 from outlines.models.tokenizer import Tokenizer
 
@@ -27,29 +26,28 @@ def __init__(
         self.tokenizer = tokenizer
 
     def __call__(
-        self, input_ids: NDArray[np.int64], attention_mask: NDArray[np.int64]
-    ) -> NDArray[np.float64]:
-        import torch
-
+        self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor
+    ) -> torch.FloatTensor:
         # `transformers` models accept `input_ids` with at most 2 dimensions. We
         # thus reshape the input array, call the model and reshape the output
         # logits.
         batch_shape = input_ids.shape[:-1]
         num_tokens = input_ids.shape[-1]
         input_ids = input_ids.reshape(math.prod(batch_shape), num_tokens)
 
-        with torch.no_grad():
-            input_ids = torch.from_numpy(input_ids).to(self.device)
-            attention_mask = torch.from_numpy(attention_mask).to(self.device)
-
-            output = self.model(input_ids, attention_mask=attention_mask)
-
-        next_token_logits = output.logits[:, -1, :]
-        probs = torch.nn.functional.softmax(next_token_logits, dim=-1).squeeze()
-        probs = torch.atleast_2d(probs)
-        numpy_probs = probs.cpu().detach().numpy()
+        output = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            return_dict=True,
+            output_attentions=False,
+            output_hidden_states=False,
+        )
+        next_token_logits = output.logits[:, -1, :]
+        probs = torch.nn.functional.softmax(next_token_logits, dim=-1).squeeze()
+        probs = torch.atleast_2d(probs)
+        probs = probs.reshape(batch_shape + (-1,))
 
-        return numpy_probs.reshape(batch_shape + (-1,))
+        return probs
 
 
 class TransformersTokenizer(Tokenizer):
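The rewritten `__call__` flattens every leading batch dimension before the forward pass and restores them afterwards, since `transformers` models only take 2D `input_ids`. A self-contained sketch of that reshape round-trip, with made-up shapes and random logits standing in for the model output:

    import math
    import torch

    token_ids = torch.randint(0, 100, (3, 2, 7))  # e.g. 3 samples x 2 prompts x 7 tokens
    batch_shape = token_ids.shape[:-1]
    num_tokens = token_ids.shape[-1]

    flat = token_ids.reshape(math.prod(batch_shape), num_tokens)  # (6, 7)
    logits = torch.randn(flat.shape[0], 100)  # stand-in for `output.logits[:, -1, :]`
    probs = torch.nn.functional.softmax(logits, dim=-1)
    probs = torch.atleast_2d(probs.squeeze())
    print(probs.reshape(batch_shape + (-1,)).shape)  # torch.Size([3, 2, 100])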
@@ -72,13 +70,13 @@ def __init__(self, model_name: str, **kwargs):
 
     def encode(
         self, prompt: Union[str, List[str]], **kwargs
-    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
         kwargs["padding"] = True
-        kwargs["return_tensors"] = "np"
+        kwargs["return_tensors"] = "pt"
         output = self.tokenizer(prompt, **kwargs)
         return output["input_ids"], output["attention_mask"]
 
-    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
+    def decode(self, token_ids: torch.LongTensor) -> List[str]:
         text = self.tokenizer.batch_decode(token_ids)
         return text
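With `return_tensors="pt"` the Hugging Face tokenizer returns `torch.Tensor`s directly, so no NumPy conversion is needed downstream. For instance (using a GPT-2 tokenizer, which needs a pad token assigned before padded batches work):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    output = tokenizer(
        ["one prompt", "another, longer prompt"], padding=True, return_tensors="pt"
    )
    print(type(output["input_ids"]), output["input_ids"].shape)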
13 changes: 6 additions & 7 deletions outlines/text/generate/continuation.py
@@ -1,7 +1,6 @@
 from typing import List, Optional
 
-import numpy as np
-from numpy.typing import NDArray
+import torch
 
 from outlines.text.generate.sequence import Sequence
 
@@ -19,8 +18,11 @@ class Continuation(Sequence):
 
     def __init__(self, model, max_tokens: Optional[int]):
         super().__init__(model, max_tokens)
+        self.eos_token_id = torch.tensor(
+            [self.model.tokenizer.eos_token_id], device=self.device
+        )
 
-    def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]:
+    def is_finished(self, token_ids: torch.LongTensor) -> torch.BoolTensor:
         """Determine whether the sequences reached maximum length or end with
         an EOS token.
@@ -35,10 +37,7 @@ def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]:
             The input sequences.
         """
-        is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_)
-        is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True
-
-        return is_finished
+        return token_ids[:, -1] == self.model.tokenizer.eos_token_id
 
     def postprocess_completions(self, completions: List[str]) -> List[str]:
         """Remove the EOS token from the completion."""
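The NumPy zeros-then-assign pattern collapses into a single vectorized comparison returning a boolean tensor. A toy check (50256 is GPT-2's EOS id, used here purely for illustration):

    import torch

    eos_token_id = 50256
    token_ids = torch.tensor([[11, 50256], [12, 13]])
    print(token_ids[:, -1] == eos_token_id)  # tensor([ True, False])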
120 changes: 65 additions & 55 deletions outlines/text/generate/sequence.py
@@ -1,8 +1,6 @@
 from typing import List, Optional, Tuple, Union
 
-import numpy as np
-from numpy.random import Generator
-from numpy.typing import NDArray
+import torch
 
 
 class Sequence:
@@ -21,9 +19,13 @@ def __init__(self, model, max_tokens: Optional[int] = None):
         """
         self.model = model
         self.device = model.device
         self.max_tokens = max_tokens
+        self.pad_token_id = torch.tensor(
+            model.tokenizer.pad_token_id, device=model.device
+        )
 
-    def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]:
+    def is_finished(self, token_ids: torch.LongTensor) -> torch.BoolTensor:
         """Determine whether we should stop the generation."""
         raise NotImplementedError(
             "`Sequence.is_finished` must be implemented by subclasses."
@@ -34,11 +36,11 @@ def postprocess_completions(self, completions: List[str]) -> List[str]:
 
     def step(
         self,
-        rng: Generator,
-        token_ids: NDArray[np.int64],
-        attention_mask: NDArray[np.int64],
+        rng: torch.Generator,
+        token_ids: torch.LongTensor,
+        attention_mask: torch.LongTensor,
         samples: int = 1,
-    ) -> Tuple[NDArray[np.int64], NDArray[float]]:
+    ) -> Tuple[torch.LongTensor, torch.FloatTensor]:
         """Generate one or several tokens that complete the input sequence.
 
         The sampling step consists in using a model to generate next-token
@@ -73,42 +75,48 @@ def step(
         next_token_ids = vectorized_random_choice(rng, probs, samples)
 
         # Add the missing `num_tokens` and `num_sample` dimensions
-        next_token_ids = np.expand_dims(next_token_ids, -1)
-        token_ids = np.expand_dims(token_ids, 0)
+        next_token_ids = torch.unsqueeze(next_token_ids, -1)
+        token_ids = torch.unsqueeze(token_ids, 0)
 
         # Expand the input `token_ids` array to be able to concatenate several
         # samples.
         if samples > 1:
             repetitions = (samples,) + (1,) * num_input_dims
-            token_ids = np.tile(token_ids, repetitions)
-            probs = np.tile(probs, repetitions)
+            token_ids = torch.tile(token_ids, repetitions)
+            probs = torch.tile(probs, repetitions)
 
-        token_ids = np.concatenate([token_ids, next_token_ids], axis=-1)
+        token_ids = torch.concatenate([token_ids, next_token_ids], axis=-1)
 
         # Merge sample and batch dimensions by removing dimensions of length
         # 1. The shape of the resulting arrays is `new_batch_shape + (num_tokens,)`
         # and `new_batch_shape + (vocab_size,)` respectively.
-        token_ids = np.atleast_2d(token_ids.squeeze())
-        probs = np.atleast_2d(probs.squeeze())
+        token_ids = torch.atleast_2d(token_ids.squeeze())
+        probs = torch.atleast_2d(probs.squeeze())
 
         return token_ids, probs
 
     def expand_attention_mask(
-        self, attention_mask: NDArray[np.int64]
-    ) -> NDArray[np.int64]:
+        self, attention_mask: torch.LongTensor
+    ) -> torch.LongTensor:
         """Expand the attention mask after the last completion."""
         batch_shape = attention_mask.shape[:-1]
-        attention_mask = np.concatenate(
-            [attention_mask, np.broadcast_to([1], batch_shape + (1,))], axis=-1
+        attention_mask = torch.concatenate(
+            [
+                attention_mask,
+                torch.broadcast_to(
+                    torch.tensor([1], device=self.device), batch_shape + (1,)
+                ),
+            ],
+            axis=-1,
         )
         return attention_mask
 
     def update_token_ids(
         self,
-        is_finished: NDArray[np.bool_],
-        token_ids: NDArray[np.int64],
-        token_ids_unfinished: NDArray[np.int64],
-    ) -> NDArray[np.int64]:
+        is_finished: torch.BoolTensor,
+        token_ids: torch.LongTensor,
+        token_ids_unfinished: torch.LongTensor,
+    ) -> torch.LongTensor:
         """Update the array of token ids after the last completion.
 
         We only generate new tokens for the sequences that are not finished. We thus
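Unlike NumPy, `torch.broadcast_to` takes a tensor (not a Python list) as its first argument, and that tensor must live on the right device; hence the explicit `torch.tensor([1], device=self.device)` in `expand_attention_mask` above. A toy CPU run of the same logic:

    import torch

    attention_mask = torch.ones((2, 4), dtype=torch.long)
    batch_shape = attention_mask.shape[:-1]
    ones_column = torch.broadcast_to(torch.tensor([1]), batch_shape + (1,))
    print(torch.concatenate([attention_mask, ones_column], axis=-1).shape)  # (2, 5)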
@@ -133,15 +141,15 @@ def update_token_ids(
         """
         batch_shape = token_ids.shape[:-1]
         num_tokens = token_ids.shape[-1]
-        new_token_ids = np.empty(batch_shape + (num_tokens + 1,), dtype=np.int64)
-
-        token_ids_finished = token_ids[is_finished]
-        batch_shape_finished = token_ids_finished.shape[:-1]
-        token_ids_finished = np.concatenate(
+        new_token_ids = torch.empty(
+            batch_shape + (num_tokens + 1,), dtype=torch.int64, device=self.device
+        )
+        token_ids_finished = torch.concatenate(
             [
-                token_ids_finished,
-                np.broadcast_to(
-                    [self.model.tokenizer.pad_token_id], batch_shape_finished + (1,)
+                token_ids[is_finished],
+                torch.broadcast_to(
+                    self.pad_token_id,
+                    token_ids[is_finished].shape[:-1] + (1,),
                 ),
             ],
             axis=-1,
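The assignments of the finished and unfinished rows into `new_token_ids` are collapsed in the hunk above; the sketch below fills them in as an assumption to show the intended behavior, with made-up values:

    import torch

    pad_token_id = torch.tensor(0)
    token_ids = torch.tensor([[5, 6], [7, 8]])
    is_finished = torch.tensor([True, False])
    token_ids_unfinished = torch.tensor([[7, 8, 9]])  # one freshly generated token

    new_token_ids = torch.empty((2, 3), dtype=torch.int64)
    # Finished rows are extended with the pad token...
    new_token_ids[is_finished] = torch.concatenate(
        [
            token_ids[is_finished],
            torch.broadcast_to(pad_token_id, token_ids[is_finished].shape[:-1] + (1,)),
        ],
        axis=-1,
    )
    # ...and unfinished rows receive their freshly generated ids (assumed from context).
    new_token_ids[~is_finished] = token_ids_unfinished
    print(new_token_ids)  # tensor([[5, 6, 0], [7, 8, 9]])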
@@ -152,11 +160,12 @@ def update_token_ids(
 
         return new_token_ids
 
+    @torch.inference_mode()
     def __call__(
         self,
         prompt: Union[str, List[str]],
         samples: int = 1,
-        rng: Generator = np.random.default_rng(),
+        rng: Optional[torch.Generator] = None,
     ) -> Union[str, List[str]]:
         """Generate a new sequence given a prompt.
@@ -173,6 +182,13 @@ def __call__(
         """
         token_ids, attention_mask = self.model.tokenizer.encode(prompt)
 
+        token_ids = token_ids.to(self.device)
+        attention_mask = attention_mask.to(self.device)
+
+        if rng is None:
+            rng = torch.Generator(device=self.device)
+
         num_prompt_tokens = token_ids.shape[-1]
 
         if samples > 1:
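Defaulting `rng` to `None` also fixes a subtle issue: the old `rng: Generator = np.random.default_rng()` default was evaluated once at definition time, so every call shared one generator. Callers who want reproducibility can now pass a seeded generator on the right device:

    import torch

    rng = torch.Generator(device="cpu")  # or the model's CUDA device
    rng.manual_seed(0)
    print(torch.rand(3, generator=rng))  # reproducible draws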
@@ -181,28 +197,23 @@
 
             num_batch_dims = token_ids.ndim - 1
             repetitions = (samples,) + (1,) * num_batch_dims
-            attention_mask = np.tile(attention_mask, repetitions)
+            attention_mask = torch.tile(attention_mask, repetitions)
             attention_mask = self.expand_attention_mask(attention_mask)
         else:
             batch_shape = token_ids.shape[:-1]
-            is_finished = np.zeros(batch_shape, dtype=np.bool_)
+            is_finished = torch.zeros(batch_shape, dtype=torch.bool, device=self.device)
 
         while True:
             num_generated_tokens = token_ids.shape[-1] - num_prompt_tokens
-            if np.all(is_finished) or num_generated_tokens == self.max_tokens:
+            if torch.all(is_finished) or num_generated_tokens == self.max_tokens:
                 break
 
-            token_ids_unfinished = token_ids[~is_finished]
-            attention_mask_unfinished = attention_mask[~is_finished]
-            token_ids_unfinished, _ = self.step(
-                rng, token_ids_unfinished, attention_mask_unfinished
-            )
-
-            token_ids = self.update_token_ids(
-                is_finished, token_ids, token_ids_unfinished
-            )
+            updated_token_ids, _ = self.step(
+                rng, token_ids[~is_finished], attention_mask[~is_finished]
+            )
+            token_ids = self.update_token_ids(is_finished, token_ids, updated_token_ids)
             attention_mask = self.expand_attention_mask(attention_mask)
-            is_finished[~is_finished] = self.is_finished(token_ids_unfinished).flatten()
+            is_finished[~is_finished] = self.is_finished(updated_token_ids).flatten()
 
         result = self.model.tokenizer.decode(token_ids)
         result = self.postprocess_completions(result)
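The in-place update `is_finished[~is_finished] = ...` only overwrites entries that were still unfinished, so rows that already hit EOS keep their `True`. A toy example of this masked assignment:

    import torch

    is_finished = torch.tensor([True, False, False])
    newly_finished = torch.tensor([True, False])  # verdicts for the two unfinished rows
    is_finished[~is_finished] = newly_finished
    print(is_finished)  # tensor([True, True, False])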
@@ -213,12 +224,9 @@ def __call__(
 
         return result
 
 
-vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")
-
-
 def vectorized_random_choice(
-    rng: Generator,
-    p: NDArray[np.float64],
+    rng: torch.Generator,
+    p: torch.FloatTensor,
     samples: int = 1,
 ):
     """Vectorized implementation of `np.random.choice`.
Expand All @@ -228,13 +236,13 @@ def vectorized_random_choice(
Note
----
`searchsorted` might be more efficient here since the number of elements
can be quite large.
`torch.searchsorted` may be more efficient, but it is not implemented for
every backend, for instance MPS.
Parameters
----------
rng
NumPy random number Generator instance
Torch random number Generator instance
p
An array of probability of shape `(num_probability_vectors, num_items)`
that must sum to 1.
@@ -247,8 +255,10 @@
     """
 
-    cumsum = np.expand_dims(p.cumsum(axis=-1), 0)
-    rand = rng.random((samples,) + p.shape[:-1])
-    idx = vsearchsorted(cumsum, rand)
+    cumsum = torch.unsqueeze(p.cumsum(axis=-1), 0)
+    rand = torch.rand(
+        (samples,) + p.shape[:-1] + (1,), generator=rng, device=rng.device
+    )
+    idx = (cumsum < rand).sum(axis=-1)
 
     return idx
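Since `torch.searchsorted` is unavailable on some backends (e.g. MPS), the sampler instead counts how many cumulative-probability entries each uniform draw exceeds; this yields the same inverse-CDF index `searchsorted` would return. A quick sanity check:

    import torch

    rng = torch.Generator()
    rng.manual_seed(0)
    p = torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]])
    cumsum = torch.unsqueeze(p.cumsum(axis=-1), 0)                # (1, 2, 3)
    rand = torch.rand((4,) + p.shape[:-1] + (1,), generator=rng)  # (4, 2, 1)
    idx = (cumsum < rand).sum(axis=-1)                            # (4, 2)
    print(idx)  # four samples drawn from each of the two probability vectors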