Introduce outlines.models.mlxlm
#956
Merged
@@ -0,0 +1,32 @@
# mlx-lm

Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx-examples/tree/main/llms), allowing models to be run quickly on Apple Silicon via the [mlx](https://ml-explore.github.io/mlx/build/html/index.html) library.

## Installation

In addition to `outlines`, you must install the `mlx-lm` and `mlx` libraries. You must also use a device that [supports Metal](https://support.apple.com/en-us/102894).
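For example, using `pip` (package names as published on PyPI):

```bash
pip install outlines mlx mlx-lm
```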

## Using `models.mlxlm`

```python
from outlines import models

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
```

With the loaded model, you can generate text or perform structured generation, e.g.

```python
from outlines import models, generate

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

phone_number_pattern = "\\+?[1-9][0-9]{7,14}"
generator = generate.regex(model, phone_number_pattern)

model_output = generator("What's Jenny's number?\n")
print(model_output)
# '8675309'
```
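
Structured JSON generation works the same way. Below is a minimal sketch using `generate.json` with a Pydantic schema invented for illustration:

```python
from pydantic import BaseModel

from outlines import generate, models


# Hypothetical schema, for illustration only.
class Contact(BaseModel):
    name: str
    phone: str


model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
generator = generate.json(model, Contact)

# The generator returns a validated instance of the schema.
contact = generator("Give me Jenny's contact details in JSON.\n")
print(contact)
# e.g. Contact(name='Jenny', phone='8675309')
```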

For more examples, see the [cookbook](cookbook/index.md).
@@ -0,0 +1,240 @@
import dataclasses
from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
    import mlx.core as mx
    import mlx.nn as nn
    from transformers import PreTrainedTokenizer

    from outlines.generate.api import GenerationParameters, SamplingParameters
    from outlines.processors import BaseLogitsProcessor


class MLXLM:
    """
    Represents an `mlx_lm` model
    """

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "PreTrainedTokenizer",
    ):
        self.model = model
        self.mlx_tokenizer = tokenizer  # returns mlx tensors, used for encode()
        self.tokenizer = TransformerTokenizer(
            tokenizer._tokenizer
        )  # _tokenizer is HF Tokenizer

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> str:
        streamer = self.stream(
            prompts, generation_parameters, logits_processor, sampling_parameters
        )
        return "".join(list(streamer))

    def stream(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> Iterator[str]:
        """Generate text using `mlx_lm`.

        Arguments
        ---------
        prompts
            A prompt or list of prompts.
        generation_parameters
            An instance of `GenerationParameters` that contains the prompt,
            the maximum number of tokens, stop sequences and seed. All the
            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
        logits_processor
            The logits processor to use when generating text.
        sampling_parameters
            An instance of `SamplingParameters`, a dataclass that contains
            the name of the sampler to use and related parameters as
            available in Outlines.

        Returns
        -------
        The generated text.
        """
        import mlx.core as mx

        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
            sampling_parameters
        )
        if max_tokens is None:
            max_tokens = int(1e9)

        if not isinstance(prompts, str):
            raise NotImplementedError(
                "The `mlx-lm` library does not support batch inference."
            )
        if sampler == "beam_search":
            raise NotImplementedError(
                "The `mlx-lm` library does not support Beam Search."
            )
        if num_samples != 1:
            raise NotImplementedError(
                "The `mlx-lm` library does not allow taking multiple samples."
            )
        if top_k is not None:
            raise NotImplementedError("The `mlx-lm` library does not support top_k.")
        if seed is not None:
            raise NotImplementedError("The `mlx-lm` library does not support seed.")
        if stop_at is not None:
            raise NotImplementedError("The `mlx-lm` library does not support stop_at.")

        generate_kwargs = {
            "temp": temperature,
            "top_p": top_p,
            "sampler": sampler,
            "logits_processor": logits_processor,
        }

        # Adapted from
        # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
        prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts))

        for (token, prob), n in zip(
            self.generate_step(prompt_tokens, **generate_kwargs),
            range(max_tokens),
        ):
            if token == self.tokenizer.eos_token_id:
                break
            yield self.tokenizer.decode([token])[0]

    def generate_step(
        self,
        prompt: "mx.array",
        temp: Optional[float],
        top_p: Optional[float],
        sampler: str,
        logits_processor: "BaseLogitsProcessor",
    ) -> Generator[Tuple[int, float], None, None]:
        """
        Adapted from
        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129

        A generator producing token ids based on the given prompt from the model.

        Args:
            prompt (mx.array): The input prompt.
            temp (float): The temperature for sampling; if 0, the argmax is used.
                Default: ``0``.
            top_p (float, optional): Nucleus sampling; higher means the model
                considers more less-likely words.
            sampler (str): The sampler string defined by SequenceGeneratorAdapter.
            logits_processor (BaseLogitsProcessor): Augment logits before sampling.
        """
        import mlx.core as mx
        import mlx_lm

        temperature: float = temp or 1.0

        def sample(logits: "mx.array") -> Tuple["mx.array", float]:
            softmax_logits = mx.softmax(logits)

            if temperature == 0.0 or sampler == "greedy":
                token = mx.argmax(logits, axis=-1)
            elif sampler == "multinomial":
                if top_p is not None and top_p > 0 and top_p < 1.0:
                    token = mlx_lm.sample_utils.top_p_sampling(
                        logits, top_p, temperature
                    )
                else:
                    token = mx.random.categorical(logits * (1 / temperature))
            else:
                raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`")

            prob = softmax_logits[0, token]
            return token, prob

        kv_heads = (
            [self.model.n_kv_heads] * len(self.model.layers)
            if isinstance(self.model.n_kv_heads, int)
            else self.model.n_kv_heads
        )
        cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]

        # kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model()
        unprocessed_input_ids = prompt
        generated_ids: List[int] = []

        while True:
            logits = self.model(unprocessed_input_ids[None], cache=cache)
            logits = logits[:, -1, :]

            if logits_processor is not None:
                # convert to logits_processor 1d expectation, apply, then convert back
                logits_1d = logits.reshape(-1)
                logits_1d = logits_processor(generated_ids, logits_1d)
                logits = logits_1d.reshape(1, -1)

            new_token_single, prob = sample(logits)
            new_token = new_token_single.item()
            yield new_token, prob

            generated_ids.append(new_token)
            unprocessed_input_ids = new_token_single


def mlxlm(
    model_name: str,
    tokenizer_config: dict = {},
    model_config: dict = {},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
):
    """Instantiate a model from the `mlx_lm` library and its tokenizer.

    Signature adapted from
    https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422

    Parameters
    ----------
    model_name (str): The path or the Hugging Face repository to load the model from.
    tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
        Defaults to an empty dictionary.
    model_config (dict, optional): Configuration parameters specifically for the model.
        Defaults to an empty dictionary.
    adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
        to the model. Default: ``None``.
    lazy (bool): If False, eval the model parameters to make sure they are
        loaded in memory before returning, otherwise they will be loaded
        when needed. Default: ``False``

    Returns
    -------
    A `MLXLM` model instance.

    """
    try:
        import mlx.core as mx
        import mlx_lm
    except ImportError:
        raise ImportError(
            "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
        )
    if not mx.metal.is_available():
        raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal)")

    model, tokenizer = mlx_lm.load(
        model_name,
        tokenizer_config=tokenizer_config,
        model_config=model_config,
        adapter_path=adapter_path,
        lazy=lazy,
    )
    return MLXLM(model, tokenizer)
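
For reviewers, a minimal sketch of how the new module is used end to end. The model name is illustrative; `generate.text` and `.stream()` are the existing Outlines entry points, which per the docstrings above route through `SequenceGeneratorAdapter` into `MLXLM.generate` and `MLXLM.stream`:

```python
# Illustrative usage sketch, not part of the diff.
from outlines import generate, models

# mlxlm() loads the weights and tokenizer via mlx_lm.load() and wraps
# them in the MLXLM class defined above.
model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

generator = generate.text(model)
print(generator("Write a haiku about Apple Silicon.", max_tokens=60))

# Token-by-token output goes through MLXLM.stream():
for chunk in generator.stream("Count to five: ", max_tokens=20):
    print(chunk, end="")
```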
@@ -0,0 +1,7 @@
from .structured import (
    BaseLogitsProcessor,
    CFGLogitsProcessor,
    FSMLogitsProcessor,
    JSONLogitsProcessor,
    RegexLogitsProcessor,
)
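
These re-exports are what `MLXLM.generate_step` consumes: the loop above calls `logits_processor(generated_ids, logits_1d)` on a flattened logits array. As a rough illustration of that calling convention only, here is a toy duck-typed processor (my own invention, not part of this PR or the Outlines API):

```python
from typing import List

import mlx.core as mx


class ForceEosAfterLimit:
    """Toy logits processor: after `limit` generated tokens, only EOS is allowed.

    It follows the calling convention used by MLXLM.generate_step above:
    it is called with the token ids generated so far and a 1-D logits array.
    """

    def __init__(self, eos_token_id: int, limit: int):
        self.eos_token_id = eos_token_id
        self.limit = limit

    def __call__(self, generated_ids: List[int], logits: mx.array) -> mx.array:
        if len(generated_ids) < self.limit:
            return logits  # leave the distribution untouched
        # Mask every token except EOS so the next sample must be EOS.
        eos_positions = mx.arange(logits.shape[-1]) == self.eos_token_id
        return mx.where(eos_positions, logits, -float("inf"))
```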
Review comment: `mlx-community` is repeated twice.