Add support for Mistral 7B (#313)
* Add support for Mistral-7B.

* Fix prompt replacement.

* Update spacy_llm/models/hf/mistral.py
rmitsch authored Oct 5, 2023
1 parent f2b4183 commit 8a2ea45
Showing 4 changed files with 169 additions and 1 deletion.
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -11,7 +11,7 @@ langchain==0.0.302; python_version>="3.9"
 openai>=0.27; python_version>="3.9"

 # Necessary for running all local models on GPU.
-transformers[sentencepiece]>=4.0.0,<4.30
+transformers[sentencepiece]>=4.0.0
 torch
 einops>=0.4

2 changes: 2 additions & 0 deletions spacy_llm/models/hf/__init__.py
@@ -2,6 +2,7 @@
 from .dolly import dolly_hf
 from .falcon import falcon_hf
 from .llama2 import llama2_hf
+from .mistral import mistral_hf
 from .openllama import openllama_hf
 from .stablelm import stablelm_hf

@@ -10,6 +11,7 @@
     "dolly_hf",
     "falcon_hf",
     "llama2_hf",
+    "mistral_hf",
     "openllama_hf",
     "stablelm_hf",
 ]
99 changes: 99 additions & 0 deletions spacy_llm/models/hf/mistral.py
@@ -0,0 +1,99 @@
from typing import Any, Callable, Dict, Iterable, Optional, Tuple

from confection import SimpleFrozenDict

from ...compat import Literal, transformers
from ...registry.util import registry
from .base import HuggingFace


class Mistral(HuggingFace):
MODEL_NAMES = Literal["Mistral-7B-v0.1", "Mistral-7B-Instruct-v0.1"] # noqa: F722

def __init__(
self,
name: MODEL_NAMES,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
):
self._tokenizer: Optional["transformers.AutoTokenizer"] = None
self._device: Optional[str] = None
        self._is_instruct = "instruct" in name.lower()  # match "-Instruct-" checkpoints case-insensitively
super().__init__(name=name, config_init=config_init, config_run=config_run)

assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase)

# Instantiate GenerationConfig object from config dict.
self._hf_config_run = transformers.GenerationConfig.from_pretrained(
self._name, **self._config_run
)
# To avoid deprecation warning regarding usage of `max_length`.
self._hf_config_run.max_new_tokens = self._hf_config_run.max_length

def init_model(self) -> Any:
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._name)
init_cfg = self._config_init
if "device" in init_cfg:
self._device = init_cfg.pop("device")

model = transformers.AutoModelForCausalLM.from_pretrained(
self._name, **init_cfg, resume_download=True
)
if self._device:
model.to(self._device)

return model

@property
def hf_account(self) -> str:
return "mistralai"

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
assert callable(self._tokenizer)
assert hasattr(self._model, "generate")
assert hasattr(self._tokenizer, "batch_decode")
prompts = list(prompts)

        # Instruct-tuned checkpoints expect the prompt wrapped in Mistral's [INST] ... [/INST] chat template.
        tokenized_input_ids = [
            self._tokenizer(
                prompt if not self._is_instruct else f"<s>[INST] {prompt} [/INST]",
                return_tensors="pt",
            ).input_ids
            for prompt in prompts
        ]
if self._device:
tokenized_input_ids = [tp.to(self._device) for tp in tokenized_input_ids]

        # Generate per prompt, strip the echoed prompt tokens from the output, then decode to text.
        return [
            self._tokenizer.decode(
                self._model.generate(
                    input_ids=tok_ii, generation_config=self._hf_config_run
                )[:, tok_ii.shape[1] :][0],
                skip_special_tokens=True,
            )
            for tok_ii in tokenized_input_ids
        ]

@staticmethod
def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs()
return (
default_cfg_init,
default_cfg_run,
)


@registry.llm_models("spacy.Mistral.v1")
def mistral_hf(
name: Mistral.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
"""Generates Mistral instance that can execute a set of prompts and return the raw responses.
    name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
    RETURNS (Callable[[Iterable[str]], Iterable[str]]): Mistral instance that can execute a set of prompts and return
        the raw responses.
"""
return Mistral(name=name, config_init=config_init, config_run=config_run)
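
For context, a minimal usage sketch of the new factory (not part of the commit): it assumes the checkpoint can be downloaded from the Hugging Face Hub and that sufficient memory is available; the config_init device entry is optional.

from spacy_llm.models.hf import mistral_hf

# Build the instruct-tuned variant; __call__ wraps each prompt in [INST] ... [/INST] internally.
mistral = mistral_hf(
    name="Mistral-7B-Instruct-v0.1",
    config_init={"device": "cuda:0"},  # optional; init_model() moves the model to this device
)
responses = list(mistral(["Write one sentence about spaCy."]))
print(responses[0])

The base checkpoint ("Mistral-7B-v0.1") takes the same route but is fed the prompt verbatim, without the chat template.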
67 changes: 67 additions & 0 deletions spacy_llm/tests/models/test_mistral.py
@@ -0,0 +1,67 @@
import copy

import pytest
import spacy
from confection import Config # type: ignore[import]
from thinc.compat import has_torch_cuda_gpu

from ...compat import torch

_PIPE_CFG = {
"model": {
"@llm_models": "spacy.Mistral.v1",
"name": "Mistral-7B-v0.1",
},
"task": {"@llm_tasks": "spacy.NoOp.v1"},
}

_NLP_CONFIG = """
[nlp]
lang = "en"
pipeline = ["llm"]
batch_size = 128
[components]
[components.llm]
factory = "llm"
[components.llm.task]
@llm_tasks = "spacy.NoOp.v1"
[components.llm.model]
@llm_models = "spacy.Mistral.v1"
name = "Mistral-7B-v0.1"
"""


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init():
"""Test initialization and simple run."""
nlp = spacy.blank("en")
cfg = copy.deepcopy(_PIPE_CFG)
nlp.add_pipe("llm", config=cfg)
nlp("This is a test.")
torch.cuda.empty_cache()


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init_from_config():
orig_config = Config().from_str(_NLP_CONFIG)
nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True)
assert nlp.pipe_names == ["llm"]
torch.cuda.empty_cache()


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_invalid_model():
orig_config = Config().from_str(_NLP_CONFIG)
config = copy.deepcopy(orig_config)
config["components"]["llm"]["model"]["name"] = "x"
with pytest.raises(ValueError, match="unexpected value; permitted"):
spacy.util.load_model_from_config(config, auto_fill=True)
torch.cuda.empty_cache()
