Skip to content

Commit

Permalink
fix models.mlxlm whitespace prefix handling
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Jun 23, 2024
1 parent f20c774 commit 566fe3c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
9 changes: 8 additions & 1 deletion outlines/models/mlxlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,20 @@ def stream(
# https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts))

detokenizer = self.mlx_tokenizer.detokenizer
detokenizer.reset()

for (token, prob), n in zip(
self.generate_step(prompt_tokens, **generate_kwargs),
range(max_tokens),
):
if token == self.tokenizer.eos_token_id:
break
yield self.tokenizer.decode([token])[0]
detokenizer.add_token(token)
yield detokenizer.last_segment

detokenizer.finalize()
yield detokenizer.last_segment

def generate_step(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/generate/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def pytest_collection_modifyitems(config, items):
for item in items:
if "model_fixture" in item.fixturenames:
model_param = item.callspec.params.get("model_fixture", None)
if model_param == "model_mlxlm":
if model_param.startswith("model_mlxlm"):
item.add_marker(skip_marker)


Expand Down
10 changes: 8 additions & 2 deletions tests/generate/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,19 @@ def model_mlxlm(tmp_path_factory):
return models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")


@pytest.fixture(scope="session")
def model_mlxlm_phi3(tmp_path_factory):
return models.mlxlm("mlx-community/Phi-3-mini-4k-instruct-4bit")


@pytest.fixture(scope="session")
def model_transformers(tmp_path_factory):
return models.transformers("Locutusque/TinyMistral-248M-v2-Instruct", device="cpu")


@pytest.mark.parametrize(
"model_fixture",
("model_llamacpp", "model_mlxlm", "model_transformers"),
("model_llamacpp", "model_mlxlm", "model_transformers", "model_mlxlm_phi3"),
)
def test_generate_text(request, model_fixture):
model = request.getfixturevalue(model_fixture)
Expand All @@ -37,11 +42,12 @@ def test_generate_text(request, model_fixture):

@pytest.mark.parametrize(
"model_fixture",
("model_llamacpp", "model_mlxlm", "model_transformers"),
("model_llamacpp", "model_mlxlm", "model_transformers", "model_mlxlm_phi3"),
)
@pytest.mark.parametrize(
"pattern",
(
"a b c d e", # test model tokenizer whitespace prefix handling
"[0-9]",
"abc*",
"\\+?[1-9][0-9]{7,14}",
Expand Down

0 comments on commit 566fe3c

Please sign in to comment.