# mlx-lm

Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx-examples/tree/main/llms), allowing models to be run quickly on Apple Silicon via the [mlx](https://ml-explore.github.io/mlx/build/html/index.html) library.

## Installation

In addition to `outlines`, you must install the `mlx-lm` and `mlx` libraries. You must use a device that [supports Metal](https://support.apple.com/en-us/102894).
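
Both libraries are available on PyPI; a typical install (assuming a pip-based environment) looks like:

```shell
pip install mlx mlx-lm
```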

## Using `models.mlxlm`

```python
from outlines import models

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
```

With the loaded model, you can generate text or perform structured generation, e.g.

```python
from outlines import models, generate

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

phone_number_pattern = "\\+?[1-9][0-9]{7,14}"
generator = generate.regex(model, phone_number_pattern)

model_output = generator("What's Jenny's number?\n")
print(model_output)
# '8675309'
```
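
The same model instance can be reused with Outlines' other generators. As a sketch, here is a classification example using `generate.choice` (the prompt and labels are illustrative):

```python
from outlines import models, generate

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

# Constrain the output to exactly one of the given labels.
generator = generate.choice(model, ["Positive", "Negative"])
sentiment = generator("Is this review positive or negative? 'I love this phone!'\n")
print(sentiment)
# 'Positive'
```

Generators can also stream output token by token via their `stream` method.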

For more examples, see the [cookbook](cookbook/index.md).

`outlines/models/mlxlm.py` (new file):

import dataclasses
from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
    import mlx.core as mx
    import mlx.nn as nn
    from transformers import PreTrainedTokenizer

    from outlines.generate.api import GenerationParameters, SamplingParameters
    from outlines.processors import BaseLogitsProcessor


class MLXLM:
    """
    Represents an `mlx_lm` model.
    """

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "PreTrainedTokenizer",
    ):
        self.model = model
        self.mlx_tokenizer = tokenizer  # returns mlx tensors, used for encode()
        self.tokenizer = TransformerTokenizer(
            tokenizer._tokenizer
        )  # _tokenizer is the underlying HF tokenizer

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> str:
        streamer = self.stream(
            prompts, generation_parameters, logits_processor, sampling_parameters
        )
        return "".join(list(streamer))

    def stream(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> Iterator[str]:
        """Generate text using `mlx_lm`.

        Arguments
        ---------
        prompts
            A prompt or list of prompts.
        generation_parameters
            An instance of `GenerationParameters` that contains the prompt,
            the maximum number of tokens, stop sequences and seed. All the
            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
        logits_processor
            The logits processor to use when generating text.
        sampling_parameters
            An instance of `SamplingParameters`, a dataclass that contains
            the name of the sampler to use and related parameters as available
            in Outlines.

        Returns
        -------
        The generated text.
        """
        import mlx.core as mx

        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
            sampling_parameters
        )
        if max_tokens is None:
            max_tokens = int(1e9)

        if not isinstance(prompts, str):
            raise NotImplementedError(
                "The `mlx-lm` library does not support batch inference."
            )
        if sampler == "beam_search":
            raise NotImplementedError(
                "The `mlx-lm` library does not support Beam Search."
            )
        if num_samples != 1:
            raise NotImplementedError(
                "The `mlx-lm` library does not allow taking several samples."
            )
        if top_k is not None:
            raise NotImplementedError("The `mlx-lm` library does not support top_k.")
        if seed is not None:
            raise NotImplementedError("The `mlx-lm` library does not support seed.")
        if stop_at is not None:
            raise NotImplementedError("The `mlx-lm` library does not support stop_at.")

        generate_kwargs = {
            "temp": temperature,
            "top_p": top_p,
            "sampler": sampler,
            "logits_processor": logits_processor,
        }

        # Adapted from
        # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
        prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts))
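
        # zip() with range() caps the stream at max_tokens while tokens are
        # decoded and yielded one at a time.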
        for (token, prob), n in zip(
            self.generate_step(prompt_tokens, **generate_kwargs),
            range(max_tokens),
        ):
            if token == self.tokenizer.eos_token_id:
                break
            yield self.tokenizer.decode([token])[0]

    def generate_step(
        self,
        prompt: "mx.array",
        temp: Optional[float],
        top_p: Optional[float],
        sampler: str,
        logits_processor: "BaseLogitsProcessor",
    ) -> Generator[Tuple[int, float], None, None]:
        """
        Adapted from
        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129

        A generator producing token ids based on the given prompt from the model.

        Args:
            prompt (mx.array): The input prompt.
            temp (float): The temperature for sampling; if 0 the argmax is used.
                Default: ``0``.
            top_p (float, optional): Nucleus sampling; higher values make the
                model consider more low-probability tokens.
            sampler (str): The sampler string defined by SequenceGeneratorAdapter.
            logits_processor (BaseLogitsProcessor): Augments logits before sampling.
        """
        import mlx.core as mx
        import mlx_lm

        temperature: float = temp or 1.0

        def sample(logits: "mx.array") -> Tuple["mx.array", float]:
            softmax_logits = mx.softmax(logits)

            if temperature == 0.0 or sampler == "greedy":
                token = mx.argmax(logits, axis=-1)
            elif sampler == "multinomial":
                if top_p is not None and top_p > 0 and top_p < 1.0:
                    token = mlx_lm.sample_utils.top_p_sampling(
                        logits, top_p, temperature
                    )
                else:
                    token = mx.random.categorical(logits * (1 / temperature))
            else:
                raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`")

            prob = softmax_logits[0, token]
            return token, prob
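
        # One KV cache per transformer layer; n_kv_heads may be a single int
        # shared by all layers or a per-layer list.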
        kv_heads = (
            [self.model.n_kv_heads] * len(self.model.layers)
            if isinstance(self.model.n_kv_heads, int)
            else self.model.n_kv_heads
        )
        cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]

        # The KV cache contains the processed input IDs; we pass only the
        # unprocessed inputs along with the cache to model().
        unprocessed_input_ids = prompt
        generated_ids: List[int] = []

        while True:
            logits = self.model(unprocessed_input_ids[None], cache=cache)
            logits = logits[:, -1, :]

            if logits_processor is not None:
                # Convert to the 1D shape logits_processor expects, apply it,
                # then convert back.
                logits_1d = logits.reshape(-1)
                logits_1d = logits_processor(generated_ids, logits_1d)
                logits = logits_1d.reshape(1, -1)

            new_token_single, prob = sample(logits)
            new_token = new_token_single.item()
            yield new_token, prob

            generated_ids.append(new_token)
            unprocessed_input_ids = new_token_single


def mlxlm(
    model_name: str,
    tokenizer_config: dict = {},
    model_config: dict = {},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
):
    """Instantiate a model from the `mlx_lm` library and its tokenizer.

    Signature adapted from
    https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422

    Parameters
    ----------
    model_name
        The path or the Hugging Face repository to load the model from.
    tokenizer_config
        Configuration parameters specifically for the tokenizer. Defaults to an
        empty dictionary.
    model_config
        Configuration parameters specifically for the model. Defaults to an
        empty dictionary.
    adapter_path
        Path to the LoRA adapters. If provided, applies LoRA layers to the
        model. Default: ``None``.
    lazy
        If ``False``, evaluate the model parameters to make sure they are loaded
        in memory before returning; otherwise they will be loaded when needed.
        Default: ``False``.

    Returns
    -------
    A `MLXLM` model instance.
    """
    try:
        import mlx.core as mx
        import mlx_lm
    except ImportError:
        raise ImportError(
            "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
        )
    if not mx.metal.is_available():
        raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal).")

    model, tokenizer = mlx_lm.load(
        model_name,
        tokenizer_config=tokenizer_config,
        model_config=model_config,
        adapter_path=adapter_path,
        lazy=lazy,
    )
    return MLXLM(model, tokenizer)

`outlines/processors/__init__.py` (new file):

from .structured import (
    BaseLogitsProcessor,
    CFGLogitsProcessor,
    FSMLogitsProcessor,
    JSONLogitsProcessor,
    RegexLogitsProcessor,
)
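
For context, here is a minimal sketch of how these pieces fit together when driving `MLXLM.stream` directly rather than through `generate.regex`. The field order of the two parameter dataclasses is inferred from the `dataclasses.astuple` unpacking in `MLXLM.stream`, and `RegexLogitsProcessor(pattern, tokenizer)` is an assumed constructor signature:

```python
from outlines import models
from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.processors import RegexLogitsProcessor

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

# Assumed signature: the regex pattern plus the Outlines tokenizer wrapper.
processor = RegexLogitsProcessor("\\+?[1-9][0-9]{7,14}", model.tokenizer)

# Positional fields follow the astuple() unpacking in MLXLM.stream:
# (max_tokens, stop_at, seed) and (sampler, num_samples, top_p, top_k, temperature).
gen_params = GenerationParameters(64, None, None)
samp_params = SamplingParameters("multinomial", 1, None, None, 1.0)

for token in model.stream("What's Jenny's number?\n", gen_params, processor, samp_params):
    print(token, end="")
```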