Commit

Merge branch 'main' into main

patricebechard authored Jul 12, 2024
2 parents fc906f7 + b54a964 commit df53cc4
Showing 8 changed files with 171 additions and 88 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -37,7 +37,7 @@ First time here? Go to our [setup guide](https://outlines-dev.github.io/outlines
- [x] 💾 Caching of generations
- [x] 🗂️ Batch inference
- [x] 🎲 Sample with the greedy, multinomial and beam search algorithms (and more to come!)
- [x] 🚀 [Serve with vLLM](https://outlines-dev.github.io/outlines/reference/vllm), with official Docker image, [`outlinesdev/outlines`](https://hub.docker.com/r/outlinesdev/outlines)!
- [x] 🚀 [Serve with vLLM](https://outlines-dev.github.io/outlines/reference/serve/vllm), with official Docker image, [`outlinesdev/outlines`](https://hub.docker.com/r/outlinesdev/outlines)!


Outlines 〰 has new releases and features coming every week. Make sure to ⭐ star and 👀 watch this repository, follow [@dottxtai][dottxt-twitter] to stay up to date!
106 changes: 81 additions & 25 deletions benchmarks/bench_processors.py
@@ -1,8 +1,13 @@
import mlx.core as mx
import numpy as np
import torch

from outlines.processors import OutlinesLogitsProcessor
import outlines.models as models
from outlines.processors import OutlinesLogitsProcessor, RegexLogitsProcessor

try:
import mlx.core as mx
except ImportError:
pass


def is_mlx_lm_allowed():
@@ -13,40 +18,91 @@ def is_mlx_lm_allowed():
return mx.metal.is_available()


def get_mock_processor_inputs(array_library, num_tokens=30000):
"""
logits: (4, 30,000 ) dtype=float
input_ids shape: (4, 2048) dtype=int
"""
if array_library == "torch":
logits = torch.rand((4, num_tokens), dtype=torch.float)
input_ids = torch.randint(
low=0, high=num_tokens, size=(4, 2048), dtype=torch.int
)
elif array_library == "torch_cuda":
logits = torch.rand((4, num_tokens), dtype=torch.float, device="cuda")
input_ids = torch.randint(
low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device="cuda"
)
elif array_library == "numpy":
logits = np.random.rand(4, num_tokens).astype(np.float32)
input_ids = np.random.randint(low=0, high=num_tokens, size=(4, 2048))
elif array_library == "mlx":
logits = mx.random.uniform(
low=-1e9, high=1e9, shape=(4, num_tokens), dtype=mx.float32
)
input_ids = mx.random.randint(
low=0, high=num_tokens, shape=(4, 2048), dtype=mx.int32
)
else:
raise ValueError

return logits, input_ids


class HalvingLogitsProcessor(OutlinesLogitsProcessor):
"""Simply halve the passed logits"""

def process_logits(self, input_ids, logits):
return logits / 2


class LogitsProcessorBenchmark:
class LogitsProcessorPassthroughBenchmark:
"""
Benchmark the time it takes to convert between array frameworks
This should be on the order of microseconds
"""

params = ["torch", "numpy"]
if mx.metal.is_available():
if is_mlx_lm_allowed():
params += ["mlx"]
if torch.cuda.is_available():
params += ["torch_cuda"]

def setup(self, array_library):
self.logits_processor = HalvingLogitsProcessor()

# logits: (4, 30,000 ) dtype=float
# input_ids shape: (4, 2048) dtype=int
if array_library == "torch":
self.logits = torch.rand((4, 30000), dtype=torch.float)
self.input_ids = torch.randint(
low=0, high=30000, size=(4, 2048), dtype=torch.int
)
elif array_library == "numpy":
self.logits = np.random.rand(4, 30000).astype(np.float32)
self.input_ids = np.random.randint(low=0, high=30000, size=(4, 2048))
elif array_library == "mlx":
self.logits = mx.random.uniform(
low=-1e9, high=1e9, shape=(4, 30000), dtype=mx.float32
)
self.input_ids = mx.random.randint(
low=0, high=30000, shape=(4, 2048), dtype=mx.int32
)
else:
raise ValueError

def time_logits_processor(self, array_library):
self.logits, self.input_ids = get_mock_processor_inputs(array_library)

def time_passthrough(self, *params):
self.logits_processor(self.input_ids, self.logits)


class LogitsProcessorStructuredBenchmark:
"""
Benchmark structured generation mask application for single decoder pass
"""

array_libraries = ["torch", "numpy"]
if is_mlx_lm_allowed():
array_libraries += ["mlx"]
# PR TODO
if torch.cuda.is_available():
array_libraries += ["torch_cuda"]

# accept very many or very few tokens, respectively
patterns = [r"[^Z]*", "Z*"]

params = [array_libraries, patterns]
param_names = ["array_library", "pattern"]

def setup(self, array_library, pattern):
tokenizer = models.transformers("facebook/opt-125m", device="cpu").tokenizer

self.logits_processor = RegexLogitsProcessor(pattern, tokenizer)

self.logits, self.input_ids = get_mock_processor_inputs(
array_library, len(tokenizer.vocabulary)
)

def time_structured_generation(self, array_library, pattern):
self.logits_processor(self.input_ids, self.logits)
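
As a rough standalone illustration (not part of the commit) of what these benchmarks measure, the new helpers can be driven directly outside the asv harness. The import path below is an assumption — in the repository these names live in `benchmarks/bench_processors.py` — and the call signature follows the `time_passthrough` method above.

```python
import torch

# Assumed import path; adjust to wherever bench_processors.py is importable from.
from bench_processors import HalvingLogitsProcessor, get_mock_processor_inputs

# Same mock inputs the asv setup() methods build:
# logits of shape (4, 30000) and input_ids of shape (4, 2048).
logits, input_ids = get_mock_processor_inputs("torch", num_tokens=30000)

# The passthrough benchmark times a single call on the processor.
processor = HalvingLogitsProcessor()
halved = processor(input_ids, logits)
print(halved.shape)  # (4, 30000); values equal logits / 2 per process_logits above
```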
10 changes: 7 additions & 3 deletions docs/community/contribute.md
@@ -66,17 +66,21 @@ You can run the benchmark test suite locally with the following command:
asv run --config benchmarks/asv.conf.json
```

Run a specific test:
Caveats:
- If you're on a device with CUDA, you must add the argument `--launch-method spawn`.
- Uncommitted code will not be benchmarked; you must first commit your changes.

#### Run a specific test:
```
asv run --config benchmarks/asv.conf.json -b bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm
```

Profile a specific test:
#### Profile a specific test:
```
asv run --config benchmarks/asv.conf.json --profile -b bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm
```

Compare to `origin/main`
#### Compare to `origin/main`
```
git fetch origin
asv continuous origin/main HEAD --config benchmarks/asv.conf.json
3 changes: 2 additions & 1 deletion docs/reference/prompting.md
@@ -260,7 +260,8 @@ pretty print a dictionary from within an Outlines prompt function
def my_prompt(response_model):
"""{{ response_model | schema }}"""

my_prompt(MyResponse)
prompt = my_prompt(MyResponse)
print(prompt)
# {
# "field1": "an int",
# "field2": "<field2>"
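
For context, the doc change above only adds the `prompt = ...` assignment and `print(prompt)`; a self-contained sketch of the surrounding example might look like the following. The `MyResponse` definition is an assumption, chosen to be consistent with the printed output shown in the diff.

```python
from pydantic import BaseModel, Field

import outlines


# Assumed response model; the real one is defined earlier in prompting.md.
class MyResponse(BaseModel):
    field1: int = Field(description="an int")
    field2: str


@outlines.prompt
def my_prompt(response_model):
    """{{ response_model | schema }}"""


prompt = my_prompt(MyResponse)
print(prompt)
# {
#   "field1": "an int",
#   "field2": "<field2>"
# }
```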
16 changes: 12 additions & 4 deletions outlines/fsm/guide.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union

import interegular
import torch
from lark import Lark

from outlines import grammars
@@ -146,6 +147,13 @@ def __init__(self, regex_string: str, tokenizer):
self.eos_token_id = tokenizer.eos_token_id
self.final_states = fsm_finals | {-1}

# cache the returned token masks
# this substantially increases the performance of mask lookup
self.states_to_token_mask = {
state: torch.tensor(list(next_tokens_to_end_states.keys()))
for state, next_tokens_to_end_states in self.states_to_token_maps.items()
}

def get_next_instruction(self, state: int) -> Instruction:
"""Return the next instruction for guided generation.
@@ -169,11 +177,11 @@ def get_next_instruction(self, state: int) -> Instruction:
A `Generate` instance that contains the model and the allowed token ids.
"""
next_tokens_to_end_states = self.states_to_token_maps.get(state)
if next_tokens_to_end_states is None:
return Write([self.eos_token_id])
next_tokens_mask = self.states_to_token_mask.get(state)
if next_tokens_mask is None:
return Write(torch.tensor([self.eos_token_id]))

return Generate(list(next_tokens_to_end_states.keys()))
return Generate(next_tokens_mask)

def get_next_state(self, state: int, token_id: int) -> int:
"""Update the state of the guide.
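
The guide.py change moves mask construction out of the per-token hot path: the allowed token ids for every FSM state are materialized as a `torch.Tensor` once in `__init__`, and `get_next_instruction` becomes a dictionary lookup instead of rebuilding a Python list on every step. A stripped-down sketch of the pattern, using a made-up transition table for illustration:

```python
import torch

# Hypothetical stand-in for the table RegexGuide builds: state -> {token_id: next_state}
states_to_token_maps = {
    0: {5: 1, 7: 1, 11: 2},
    1: {3: 2},
}
eos_token_id = 2

# One-time precomputation: the allowed-token mask for each state as a tensor.
states_to_token_mask = {
    state: torch.tensor(list(transitions.keys()))
    for state, transitions in states_to_token_maps.items()
}


def allowed_tokens(state: int) -> torch.Tensor:
    # Per-step work is now a dict lookup; unknown/final states fall back to EOS.
    mask = states_to_token_mask.get(state)
    if mask is None:
        return torch.tensor([eos_token_id])
    return mask


print(allowed_tokens(0))   # tensor([ 5,  7, 11])
print(allowed_tokens(99))  # tensor([2])
```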
4 changes: 2 additions & 2 deletions outlines/generate/api.py
@@ -231,7 +231,7 @@ def __call__(

# We reshape the output to (batch_size, sample_size)
output: List[List[FormattedOutput]] = list()
for i in range(batch_size):
for i in range(0, batch_size * num_samples, num_samples):
output.append(formatted[i : i + num_samples])

# We remove leading dimensions for the output
@@ -372,7 +372,7 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
previously_generated_sequences = generated_sequences
# We reshape the output to (batch_size, sample_size)
output: List[List[str]] = list()
for i in range(batch_size):
for i in range(0, batch_size * num_samples, num_samples):
output.append(next_tokens[i : i + num_samples])

# We remove leading dimensions for the output
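
The api.py fix changes the loop stride: `formatted` and `next_tokens` are flat sequences of length `batch_size * num_samples`, so slicing `i : i + num_samples` only partitions them correctly when `i` advances in steps of `num_samples`. A small illustration with made-up values:

```python
batch_size, num_samples = 2, 3
formatted = ["a0", "a1", "a2", "b0", "b1", "b2"]  # batch_size * num_samples items

# Old loop: i in (0, 1) -> overlapping, incorrect groups
old = [formatted[i : i + num_samples] for i in range(batch_size)]
# [['a0', 'a1', 'a2'], ['a1', 'a2', 'b0']]

# New loop: i in (0, 3) -> one group of num_samples per batch element
new = [
    formatted[i : i + num_samples]
    for i in range(0, batch_size * num_samples, num_samples)
]
# [['a0', 'a1', 'a2'], ['b0', 'b1', 'b2']]
```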