Update of MLX-LM generate_step to support repetition_penalty #1134

Open · wants to merge 2 commits into base: main
81 changes: 74 additions & 7 deletions outlines/models/mlxlm.py
@@ -1,5 +1,16 @@
import dataclasses
from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Iterator,
List,
Optional,
Tuple,
TypedDict,
Union,
Generator,
)

from typing_extensions import Unpack

from .transformers import TransformerTokenizer
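
For background, the `Unpack[...]` annotation used on the new `**mlx_lm_params` signatures follows PEP 692: pairing `Unpack` with a `TypedDict` lets type checkers validate keyword arguments against a fixed key set. A minimal, self-contained sketch (the names here are illustrative only, not part of this diff):

```python
from typing import TypedDict

from typing_extensions import Unpack


class Knobs(TypedDict, total=False):
    top_p: float


def run(**kwargs: Unpack[Knobs]) -> None:
    # A type checker accepts run(top_p=0.9) but rejects run(foo=1).
    print(kwargs.get("top_p"))


run(top_p=0.9)
```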

@@ -12,6 +23,14 @@
from outlines.processors import OutlinesLogitsProcessor


class MLXLMParams(TypedDict, total=False):
    top_p: float  # so top_p can be passed to generate() without defining a sampler
    repetition_penalty: float
    repetition_context_size: int
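
As a hedged usage sketch, this is how a caller could forward these parameters through `generate()`; the surrounding objects (`model`, `generation_parameters`, `logits_processor`, `sampling_parameters`) are assumed to come from the usual Outlines plumbing and are not defined in this diff:

```python
# Hypothetical caller-side sketch: forward mlx_lm-specific knobs as kwargs.
params: MLXLMParams = {
    "repetition_penalty": 1.3,      # > 1.0 penalizes repeated tokens
    "repetition_context_size": 30,  # window of recent tokens to penalize
}
text = model.generate(
    "Write a haiku about the ocean.",
    generation_parameters,
    logits_processor,
    sampling_parameters,
    **params,
)
```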



class MLXLM:
"""
Represents an `mlx_lm` model
@@ -28,24 +47,28 @@ def __init__(
tokenizer._tokenizer
) # _tokenizer is HF Tokenizer


def generate(
self,
prompts: Union[str, List[str]],
generation_parameters: "GenerationParameters",
logits_processor,
sampling_parameters: "SamplingParameters",
**mlx_lm_params: Unpack[MLXLMParams],
) -> str:
streamer = self.stream(
    prompts,
    generation_parameters,
    logits_processor,
    sampling_parameters,
    **mlx_lm_params,
)
return "".join(list(streamer))


def stream(
self,
prompts: Union[str, List[str]],
generation_parameters: "GenerationParameters",
logits_processor,
sampling_parameters: "SamplingParameters",
**mlx_lm_params: Unpack[MLXLMParams],
) -> Iterator[str]:
"""Generate text using `mlx_lm`.

@@ -63,6 +86,9 @@ def stream(
An instance of `SamplingParameters`, a dataclass that contains
the name of the sampler to use and related parameters as available
in Outlines.
mlx_lm_params
    Of type `MLXLMParams`: optional `top_p`, `repetition_penalty`, and
    `repetition_context_size` values forwarded to `mlx_lm` generation.

Returns
-------
The generated text.
@@ -100,6 +126,7 @@ def stream(
"top_p": top_p,
"sampler": sampler,
"logits_processor": logits_processor,
**mlx_lm_params,
}
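
Note that because `**mlx_lm_params` is unpacked last in the dict literal, any key it shares with the explicit entries (such as `top_p`) takes precedence; this is standard Python dict-literal behavior:

```python
# In a dict literal, later entries override earlier ones on key collision.
base = {"top_p": 1.0, "sampler": "multinomial"}
merged = {**base, "top_p": 0.9}
assert merged == {"top_p": 0.9, "sampler": "multinomial"}
```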

# Adapted from
@@ -121,40 +148,61 @@
detokenizer.finalize()
yield detokenizer.last_segment


def generate_step(
self,
prompt: "mx.array",
temp: Optional[float],
top_p: Optional[float],
sampler: str,
logits_processor: "OutlinesLogitsProcessor",
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
) -> Generator[Tuple[int, float], None, None]:
"""
Adapted from
https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129
and updated (in September 2024, to add the repetition_* arguments) from
https://github.com/ml-explore/mlx-examples/blob/bd29aec299c8fa59c161a9c1207bfc59db31d845/llms/mlx_lm/utils.py#L149

A generator producing token ids based on the given prompt from the model.

Args:
prompt (mx.array): The input prompt.
temp (float): The temperature for sampling; if 0, the argmax is used.
    Default: ``0``.
top_p (float, optional): Nucleus sampling; higher values make the model
    consider more less-likely words.
sampler (str): The sampler string defined by SequenceGeneratorAdapter.
logits_processor (OutlinesLogitsProcessor): Augment logits before sampling.
repetition_penalty (float, optional): The penalty factor for repeated
    tokens; 1.0 means no penalty, values above 1.0 penalize repetition.
    Default: ``None``.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.

Yields:
    Generator[Tuple[int, float], None, None]: A generator producing
    one token id and its probability.
"""
import mlx.core as mx
import mlx_lm

if repetition_penalty:
    if not isinstance(repetition_penalty, float) or repetition_penalty <= 0:
        raise ValueError(
            f"repetition_penalty must be a positive float, got {repetition_penalty}"
        )
    if not isinstance(repetition_context_size, int) or repetition_context_size <= 2:
        raise ValueError(
            f"repetition_context_size must be an integer greater than 2, "
            f"got {repetition_context_size}"
        )


def sample(logits: "mx.array") -> Tuple["mx.array", float]:
softmax_logits = mx.softmax(logits)

if temp == 0.0 or sampler == "greedy":  # compare temp: unlike temperature, it can be 0
    token = mx.argmax(logits, axis=-1)
elif sampler == "multinomial":
    temperature: float = temp or 1.0
if top_p is not None and top_p > 0 and top_p < 1.0:
token = mlx_lm.sample_utils.top_p_sampling(
logits, top_p, temperature
@@ -167,13 +215,22 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
prob = softmax_logits[0, token]
return token, prob
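
For intuition, nucleus (top-p) filtering keeps the smallest set of most-likely tokens whose cumulative probability reaches `top_p`, then renormalizes before sampling. A standalone sketch of the idea (not `mlx_lm`'s actual `top_p_sampling` implementation):

```python
import numpy as np


def top_p_filter(probs: np.ndarray, top_p: float) -> np.ndarray:
    # Sort descending and keep a token while the probability mass *before*
    # it is still below top_p, so the token crossing the threshold stays in.
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    keep_sorted = (cumulative - probs[order]) < top_p
    mask = np.zeros_like(probs, dtype=bool)
    mask[order[keep_sorted]] = True
    filtered = np.where(mask, probs, 0.0)
    return filtered / filtered.sum()
```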


# Create the KV cache for generation
kv_heads = (
[self.model.n_kv_heads] * len(self.model.layers)
if isinstance(self.model.n_kv_heads, int)
else self.model.n_kv_heads
)
cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]


# Init the repetition context with the prompt tokens
repetition_context = prompt.tolist()
if repetition_context_size:
    repetition_context = repetition_context[-repetition_context_size:]


# The KV cache holds already-processed input IDs; pass only the unprocessed inputs plus the cache to model()
unprocessed_input_ids = prompt
generated_ids: List[int] = []
@@ -182,6 +239,10 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
logits = self.model(unprocessed_input_ids[None], cache=cache)
logits = logits[:, -1, :]

if repetition_penalty:
    logits = mlx_lm.utils.apply_repetition_penalty(
        logits, repetition_context, repetition_penalty
    )

if logits_processor is not None:
# convert to logits_processor 1d expectation, apply, then convert back
logits_1d = logits.reshape(-1)
Expand All @@ -191,11 +252,17 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
new_token_single, prob = sample(logits)
new_token = new_token_single.item()
yield new_token, prob

if repetition_penalty:
    repetition_context.append(new_token)
    if repetition_context_size and len(repetition_context) > repetition_context_size:
        repetition_context = repetition_context[-repetition_context_size:]

generated_ids.append(new_token)
unprocessed_input_ids = new_token_single
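
For reference, the `mlx_lm.utils.apply_repetition_penalty` helper called above implements the CTRL-style rule: for every token id in the context window, positive logits are divided by the penalty and negative logits are multiplied by it. A close paraphrase of the upstream helper at the revision linked in the docstring (hedged, since the upstream code may have changed since):

```python
import mlx.core as mx


def apply_repetition_penalty_sketch(
    logits: "mx.array", generated_tokens: list, penalty: float
) -> "mx.array":
    # Penalize every token id already present in the repetition context.
    if len(generated_tokens) > 0:
        indices = mx.array(generated_tokens)
        selected = logits[:, indices]
        # Push negative logits further down; pull positive logits toward 0.
        logits[:, indices] = mx.where(
            selected < 0, selected * penalty, selected / penalty
        )
    return logits
```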


def mlxlm(
model_name: str,
tokenizer_config: dict = {},