-___
+---
`xTuring` provides fast, efficient and simple fine-tuning of LLMs, such as LLaMA, GPT-J, Galactica, and more.
By providing an easy-to-use interface for fine-tuning LLMs to your own data and application, xTuring makes it
@@ -25,6 +25,7 @@ simple to build, customize and control LLMs. The entire process can be done insi
private cloud, ensuring data privacy and security.
With `xTuring` you can:
+
- Ingest data from different sources and preprocess them to a format LLMs can understand
- Scale from single to multiple GPUs for faster fine-tuning
- Leverage memory-efficient methods (e.g., INT4, LoRA fine-tuning) to reduce hardware costs by up to 90%
@@ -34,8 +35,11 @@ With `xTuring` you can,
## 🌟 What's new?
+
We are excited to announce the latest enhancements to our `xTuring` library:
-1. __`LLaMA 2` integration__ - You can use and fine-tune the _`LLaMA 2`_ model in different configurations: _off-the-shelf_, _off-the-shelf with INT8 precision_, _LoRA fine-tuning_, _LoRA fine-tuning with INT8 precision_ and _LoRA fine-tuning with INT4 precision_ using the `GenericModel` wrapper and/or you can use the `Llama2` class from `xturing.models` to test and finetune the model.
+
+1. **`LLaMA 2` integration** - You can use and fine-tune the _`LLaMA 2`_ model in different configurations: _off-the-shelf_, _off-the-shelf with INT8 precision_, _LoRA fine-tuning_, _LoRA fine-tuning with INT8 precision_ and _LoRA fine-tuning with INT4 precision_ using the `GenericModel` wrapper, or you can use the `Llama2` class from `xturing.models` to test and fine-tune the model.
+
```python
from xturing.models import Llama2
model = Llama2()
@@ -45,7 +49,9 @@ from xturing.models import BaseModel
model = BaseModel.create('llama2')
```
-2. __`Evaluation`__ - Now you can evaluate any `Causal Language Model` on any dataset. The metrics currently supported is [`perplexity`](https://towardsdatascience.com/perplexity-in-language-models-87a196019a94).
+
+2. **`Evaluation`** - Now you can evaluate any `Causal Language Model` on any dataset. The only metric currently supported is [`perplexity`](https://towardsdatascience.com/perplexity-in-language-models-87a196019a94).
+
```python
# Make the necessary imports
from xturing.datasets import InstructionDataset
@@ -64,7 +70,9 @@ result = model.evaluate(dataset)
print(f"Perplexity of the evalution: {result}")
```
-3. __`INT4` Precision__ - You can now use and fine-tune any LLM with `INT4 Precision` using `GenericLoraKbitModel`.
+
+3. **`INT4` Precision** - You can now use and fine-tune any LLM with `INT4 Precision` using `GenericLoraKbitModel`.
+
```python
# Make the necessary imports
from xturing.datasets import InstructionDataset
@@ -80,7 +88,7 @@ model = GenericLoraKbitModel('tiiuae/falcon-7b')
model.finetune(dataset)
```
-4. __CPU inference__ - The CPU, including laptop CPUs, is now fully equipped to handle LLM inference. We integrated [Intel® Extension for Transformers](https://github.com/intel/intel-extension-for-transformers) to conserve memory by compressing the model with [weight-only quantization algorithms](https://github.com/intel/intel-extension-for-transformers/blob/main/docs/weightonlyquant.md) and accelerate the inference by leveraging its highly optimized kernel on Intel platforms.
+4. **CPU inference** - CPUs, including laptop CPUs, are now fully equipped to handle LLM inference. We integrated [Intel® Extension for Transformers](https://github.com/intel/intel-extension-for-transformers) to conserve memory by compressing the model with [weight-only quantization algorithms](https://github.com/intel/intel-extension-for-transformers/blob/main/docs/weightonlyquant.md) and to accelerate inference by leveraging its highly optimized kernels on Intel platforms.
```python
# Make the necessary imports
@@ -95,7 +103,8 @@ output = model.generate(texts=["Why LLM models are becoming so important?"])
print(output)
```
-5. __Batch integration__ - By tweaking the 'batch_size' in the .generate() and .evaluate() functions, you can expedite results. Using a 'batch_size' greater than 1 typically enhances processing efficiency.
+5. **Batch integration** - By tweaking the `batch_size` argument of the `.generate()` and `.evaluate()` functions, you can speed up results. Using a `batch_size` greater than 1 typically improves processing efficiency.
+
```python
# Make the necessary imports
from xturing.datasets import InstructionDataset
@@ -119,6 +128,7 @@ For an extended insight, consider examining the [GenericModel working example](e
## ⚙️ Installation
+
```bash
pip install xturing
```
@@ -151,6 +161,7 @@ You can find the data folder [here](examples/models/llama/alpaca_data).
## CLI playground
+
```bash
@@ -159,6 +170,7 @@ $ xturing chat -m ""
```
## UI playground
+
```python
@@ -180,6 +192,7 @@ Playground().launch() ## launches localhost UI
## 📚 Tutorials
+
- [Preparing your dataset](examples/datasets/preparing_your_dataset.py)
- [Cerebras-GPT fine-tuning with LoRA and INT8](examples/models/cerebras/cerebras_lora_int8.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eKq3oF7dnK8KuIfsTE70Gvvniwr1O9D0?usp=sharing)
- [Cerebras-GPT fine-tuning with LoRA](examples/models/cerebras/cerebras_lora.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1VjqQhstm5pT4EjPjx4Je7b3W2X1V3vDo?usp=sharing)
@@ -209,17 +222,18 @@ Fine-tuning parameters:
}
```
-| LLaMA-7B | DeepSpeed + CPU Offloading | LoRA + DeepSpeed | LoRA + DeepSpeed + CPU Offloading |
-| :---------: | :----: | :----: | :----: |
-| GPU | 33.5 GB | 23.7 GB | 21.9 GB |
-| CPU | 190 GB | 10.2 GB | 14.9 GB |
-| Time/epoch | 21 hours | 20 mins | 20 mins |
+| LLaMA-7B | DeepSpeed + CPU Offloading | LoRA + DeepSpeed | LoRA + DeepSpeed + CPU Offloading |
+| :--------: | :------------------------: | :--------------: | :-------------------------------: |
+| GPU | 33.5 GB | 23.7 GB | 21.9 GB |
+| CPU | 190 GB | 10.2 GB | 14.9 GB |
+| Time/epoch | 21 hours | 20 mins | 20 mins |
Contribute your performance results on other GPUs by creating an issue with your hardware specifications, memory consumption, and time per epoch.
## 📎 Fine-tuned model checkpoints
+
We have already fine-tuned some models that you can use as your base or start playing with.
Here is how you would load them:
@@ -228,44 +242,49 @@ from xturing.models import BaseModel
model = BaseModel.load("x/distilgpt2_lora_finetuned_alpaca")
```
-| model | dataset | Path |
-|---------------------|--------|---------------|
-| DistilGPT-2 LoRA | alpaca | `x/distilgpt2_lora_finetuned_alpaca` |
-| LLaMA LoRA | alpaca | `x/llama_lora_finetuned_alpaca` |
+| model | dataset | Path |
+| ---------------- | ------- | ------------------------------------ |
+| DistilGPT-2 LoRA | alpaca | `x/distilgpt2_lora_finetuned_alpaca` |
+| LLaMA LoRA | alpaca | `x/llama_lora_finetuned_alpaca` |
## Supported Models
+
Below is a list of all the models supported via the `BaseModel` class of `xTuring`, along with the corresponding keys used to load them.
-| Model | Key |
-| -- | -- |
-|Bloom | bloom|
-|Cerebras | cerebras|
-|DistilGPT-2 | distilgpt2|
-|Falcon-7B | falcon|
-|Galactica | galactica|
-|GPT-J | gptj|
-|GPT-2 | gpt2|
-|LlaMA | llama|
-|LlaMA2 | llama2|
-|OPT-1.3B | opt|
+| Model | Key |
+| ----------- | ---------- |
+| Bloom | bloom |
+| Cerebras | cerebras |
+| DistilGPT-2 | distilgpt2 |
+| Falcon-7B | falcon |
+| Galactica | galactica |
+| GPT-J | gptj |
+| GPT-2 | gpt2 |
+| LLaMA       | llama      |
+| LLaMA 2     | llama2     |
+| OPT-1.3B | opt |
+| Mistral-7B | mistral |
The keys listed above load the base variants of the LLMs. Below are the templates to get their `LoRA`, `INT8`, `INT8 + LoRA` and `INT4 + LoRA` versions; an illustrative example follows the table.
-| Version | Template |
-| -- | -- |
-| LoRA| _lora|
-| INT8| _int8|
-| INT8 + LoRA| _lora_int8|
+| Version | Template |
+| ----------- | ---------------------- |
+| LoRA | \_lora |
+| INT8 | \_int8 |
+| INT8 + LoRA | \_lora_int8 |
+
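For instance, appending a template to a model key from the table above loads the corresponding memory-efficient variant. A minimal sketch, using the `llama` key as an illustration (any key listed above works the same way):

```python
from xturing.models import BaseModel

# "llama" + "_lora_int8" loads the LoRA variant of LLaMA prepared for INT8 fine-tuning
model = BaseModel.create("llama_lora_int8")
```
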
+**Note:** In order to load any model's **`INT4+LoRA`** version, you will need to make use of the `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it:
-** In order to load any model's __`INT4+LoRA`__ version, you will need to make use of `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it:
```python
model = GenericLoraKbitModel('<model_path>')
```
+
The `model_path` can be replaced with your local directory or any Hugging Face Hub model like `facebook/opt-1.3b`.
## 📈 Roadmap
+
- [x] Support for `LLaMA`, `GPT-J`, `GPT-2`, `OPT`, `Cerebras-GPT`, `Galactica` and `Bloom` models
- [x] Dataset generation using self-instruction
- [x] Low-precision LoRA fine-tuning and unsupervised fine-tuning
@@ -284,6 +303,7 @@ The `model_path` can be replaced with you local directory or any HuggingFace lib
## 🤝 Help and Support
+
If you have any questions, you can create an issue on this repository.
You can also join our [Discord server](https://discord.gg/TgHXuSJEk6) and start a discussion in the `#xturing` channel.
@@ -291,9 +311,11 @@ You can also join our [Discord server](https://discord.gg/TgHXuSJEk6) and start
## 📝 License
+
This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
## 🌎 Contributing
+
As an open source project in a rapidly evolving field, we welcome contributions of all kinds, including new features and better documentation. Please read our [contributing guide](CONTRIBUTING.md) to learn how you can get involved.
diff --git a/docs/docs/overview/quickstart/test.jsx b/docs/docs/overview/quickstart/test.jsx
index c738db8..7464e8f 100644
--- a/docs/docs/overview/quickstart/test.jsx
+++ b/docs/docs/overview/quickstart/test.jsx
@@ -1,57 +1,56 @@
-import React, { useEffect, useState } from 'react'
-import clsx from 'clsx'
-import MDXContent from '@theme/MDXContent'
-import CodeBlock from '@theme/CodeBlock'
+import React, { useEffect, useState } from "react";
+import clsx from "clsx";
+import MDXContent from "@theme/MDXContent";
+import CodeBlock from "@theme/CodeBlock";
const trainingTechniques = {
- base: 'Base',
- lora: 'LoRA',
- lora_int8: 'LoRA INT8',
- int8: 'INT8',
-}
+ base: "Base",
+ lora: "LoRA",
+ lora_int8: "LoRA INT8",
+ int8: "INT8",
+};
const modelList = {
- bloom: 'BLOOM',
- cerebras: 'Cerebras',
- distilgpt2: 'DistilGPT-2',
- galactica: 'Galactica',
- gptj: 'GPT-J',
- gpt2: 'GPT-2',
- llama: 'LLaMA',
- llama2: 'LLaMA 2',
- opt: 'OPT',
-}
+ bloom: "BLOOM",
+ cerebras: "Cerebras",
+ distilgpt2: "DistilGPT-2",
+ galactica: "Galactica",
+ gptj: "GPT-J",
+ gpt2: "GPT-2",
+ llama: "LLaMA",
+ llama2: "LLaMA 2",
+ opt: "OPT",
+ mistral: "Mistral",
+};
-export default function Test(
- {instruction}
-) {
+export default function Test({ instruction }) {
// const [code, setCode] = useState('llama');
const [code, setCode] = useState({
- model: '',
- technique: 'base',
- })
+ model: "",
+ technique: "base",
+ });
- let finalKey = ''
- if (code.technique === 'base') {
- finalKey = `${code.model}`
+ let finalKey = "";
+ if (code.technique === "base") {
+ finalKey = `${code.model}`;
} else {
- finalKey = `${code.model}_${code.technique}`
+ finalKey = `${code.model}_${code.technique}`;
}
-
+
useEffect(() => {
setCode({
- model: 'llama',
- technique: 'base'
+ model: "llama",
+ technique: "base",
});
}, []);
return (
-
-
+
+
-
+
- )
-}
\ No newline at end of file
+ );
+}
diff --git a/docs/docs/overview/supported_models.md b/docs/docs/overview/supported_models.md
index 02932dd..17de8bb 100644
--- a/docs/docs/overview/supported_models.md
+++ b/docs/docs/overview/supported_models.md
@@ -5,32 +5,39 @@ description: Models Supported by xTuring
---
+
## Base versions
-| Model | Model Key | LoRA | INT8 | LoRA + INT8 | LoRA + INT4 |
-| ------ | --- | :---: | :---: | :---: | :---: |
-| BLOOM 1.1B| bloom | ✅ | ✅ | ✅ | ✅ |
-| Cerebras 1.3B| cerebras | ✅ | ✅ | ✅ | ✅ |
-| DistilGPT-2 | distilgpt2 | ✅ | ✅ | ✅ | ✅ |
-| Falcon 7B | falcon | ✅ | ✅ | ✅ | ✅ |
-| Galactica 6.7B| galactica | ✅ | ✅ | ✅ | ✅ |
-| GPT-J 6B | gptj | ✅ | ✅ | ✅ | ✅ |
-| GPT-2 | gpt2 | ✅ | ✅ | ✅ | ✅ |
-| LLaMA 7B | llama | ✅ | ✅ | ✅ | ✅ |
-| LLaMA2 | llama2 | ✅ | ✅ | ✅ | ✅ |
-| OPT 1.3B | opt | ✅ | ✅ | ✅ | ✅ |
+
+| Model | Model Key | LoRA | INT8 | LoRA + INT8 | LoRA + INT4 |
+| -------------- | ---------- | :--: | :--: | :---------: | :---------: |
+| BLOOM 1.1B     | bloom      | ✅   | ✅   | ✅          | ✅          |
+| Cerebras 1.3B  | cerebras   | ✅   | ✅   | ✅          | ✅          |
+| DistilGPT-2    | distilgpt2 | ✅   | ✅   | ✅          | ✅          |
+| Falcon 7B      | falcon     | ✅   | ✅   | ✅          | ✅          |
+| Galactica 6.7B | galactica  | ✅   | ✅   | ✅          | ✅          |
+| GPT-J 6B       | gptj       | ✅   | ✅   | ✅          | ✅          |
+| GPT-2          | gpt2       | ✅   | ✅   | ✅          | ✅          |
+| LLaMA 7B       | llama      | ✅   | ✅   | ✅          | ✅          |
+| LLaMA2         | llama2     | ✅   | ✅   | ✅          | ✅          |
+| OPT 1.3B       | opt        | ✅   | ✅   | ✅          | ✅          |
+| Mistral 7B     | mistral    | ✅   | ✅   | ✅          | ✅          |
### Memory-efficient versions
+
> The keys listed above load the base variants of the LLMs. Below are the templates to get their `LoRA`, `INT8`, `INT8 + LoRA` and `INT4 + LoRA` versions; an illustrative example follows the table.
-| Version | Template |
-| -- | -- |
-| LoRA | _lora|
-| INT8 | _int8|
-| INT8 + LoRA | _lora_int8|
+| Version | Template |
+| ----------- | ---------------------- |
+| LoRA | \_lora |
+| INT8 | \_int8 |
+| INT8 + LoRA | \_lora_int8 |
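
For example, combining a model key from the base table with one of these templates loads the corresponding memory-efficient version. A minimal sketch, shown here with the `mistral` key:

```python
from xturing.models import BaseModel

# "mistral" + "_lora_int8" loads the LoRA + INT8 variant of Mistral 7B
model = BaseModel.create("mistral_lora_int8")
```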
### INT4 Precision model versions
-> In order to load any model's __`INT4+LoRA`__ version, you will need to make use of `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it:
+
+> In order to load any model's **`INT4+LoRA`** version, you will need to make use of the `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it:
+
```python
model = GenericLoraKbitModel('/path/to/model')
```
+
The `/path/to/model` can be replaced with your local directory or any Hugging Face Hub model like `facebook/opt-1.3b`.
diff --git a/examples/README.md b/examples/README.md
index 3b9026b..11abef9 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,7 +1,9 @@
# Navigating through examples
-Here, is a brief about how to navigate through examples quick and efficiently, and get your hands dirty with `xTuring`.
+
+Here is a brief guide on how to navigate through the examples quickly and efficiently, and get your hands dirty with `xTuring`.
## Directory structure
+
```
examples/
| datasets
@@ -15,44 +17,52 @@ examples/
```
### datasets/
-This directory consists of multiple ways to generate your custom dataset from a given set of examples.
+
+This directory shows multiple ways to generate your custom dataset from a given set of examples.
### features/
-This directory consists of files with exapmles highlighting speific major features of the library, which can be replicated to any LLM you want.
-For example, in `dataset_generation/`, you will find an example on how to generate your custom dataset from a .jsonl file. In `evaluation/`, you will find a specific exapmle on how to evaluate your finetuned model, which can then be extended to any LLM and any dataset.
+
+This directory contains files with examples highlighting specific major features of the library, which can be replicated with any LLM you want.
+For example, in `dataset_generation/`, you will find an example of how to generate your custom dataset from a .jsonl file. In `evaluation/`, you will find a specific example of how to evaluate your fine-tuned model, which can then be extended to any LLM and any dataset.
### models/
-This directory consists of examples specific to each model mentioned.
+
+This directory contains examples specific to each of the models listed below.
### playground_ui/
+
This directory contains an example that demonstrates how you can play around with your LLM through a web interface.
## Models
+
Below is a list of all the models supported via the `BaseModel` class of `xTuring`, along with the corresponding keys used to load them.
-| Model | Key |
-| -- | -- |
-|Bloom | bloom|
-|Cerebras | cerebras|
-|DistilGPT-2 | distilgpt2|
-|Falcon-7B | falcon|
-|Galactica | galactica|
-|GPT-J | gptj|
-|GPT-2 | gpt2|
-|LlaMA | llama|
-|LlaMA2 | llama2|
-|OPT-1.3B | opt|
+| Model | Key |
+| ----------- | ---------- |
+| Bloom | bloom |
+| Cerebras | cerebras |
+| DistilGPT-2 | distilgpt2 |
+| Falcon-7B | falcon |
+| Galactica | galactica |
+| GPT-J | gptj |
+| GPT-2 | gpt2 |
+| LLaMA       | llama      |
+| LLaMA 2     | llama2     |
+| OPT-1.3B | opt |
+| Mistral-7B | mistral |
The keys listed above load the base variants of the LLMs. Below are the templates to get their `LoRA`, `INT8`, `INT8 + LoRA` and `INT4 + LoRA` versions; an illustrative example follows the table.
-| Version | Template |
-| -- | -- |
-| LoRA| _lora|
-| INT8| _int8|
-| INT8 + LoRA| _lora_int8|
+| Version | Template |
+| ----------- | ---------------------- |
+| LoRA | \_lora |
+| INT8 | \_int8 |
+| INT8 + LoRA | \_lora_int8 |
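
As a quick illustration (a sketch, not one of the bundled examples), a key from the table combined with a template loads the corresponding variant:

```python
from xturing.models import BaseModel

# "gpt2" + "_lora" loads the LoRA variant of GPT-2
model = BaseModel.create("gpt2_lora")
```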
+
+**Note:** In order to load any model's **`INT4+LoRA`** version, you will need to make use of the `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it:
-** In order to load any model's __`INT4+LoRA`__ version, you will need to make use of `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it:
```python
model = GenericLoraKbitModel('<model_path>')
```
+
The `model_path` can be replaced with your local directory or any Hugging Face Hub model like `facebook/opt-1.3b`.
diff --git a/examples/models/mistral/mistral.py b/examples/models/mistral/mistral.py
new file mode 100644
index 0000000..8fecc13
--- /dev/null
+++ b/examples/models/mistral/mistral.py
@@ -0,0 +1,21 @@
+# Make the necessary imports
+from xturing.models import Mistral
+
+# Load the model
+model = Mistral()
+# Generate outputs from the model
+outputs = model.generate(texts=["How are you?"])
+# Print the generated outputs
+print(outputs)
+
+## or
+
+# Make the necessary imports
+from xturing.models import BaseModel
+
+# Load the model
+model = BaseModel.create("mistral")
+# Generate outputs from the model
+outputs = model.generate(texts=["How are you?"])
+# Print the generated outputs
+print(outputs)
diff --git a/examples/models/mistral/mistral_woq.py b/examples/models/mistral/mistral_woq.py
new file mode 100644
index 0000000..76d4f51
--- /dev/null
+++ b/examples/models/mistral/mistral_woq.py
@@ -0,0 +1,10 @@
+# from xturing.datasets.instruction_dataset import InstructionDataset
+from xturing.models import BaseModel
+
+# Initialize the model: quantize it with weight-only algorithms and
+# replace the linear layers with ITREX's qbits_linear kernel
+model = BaseModel.create("mistral_int8")
+
+# Once the model has been quantized, you can do inferences directly
+output = model.generate(texts=["Why LLM models are becoming so important?"])
+print(output)
diff --git a/pyproject.toml b/pyproject.toml
index 4a6267e..234db9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,7 +43,7 @@ keywords = [
dependencies = [
"torch >= 1.9.0",
"pytorch-lightning",
- "transformers==4.31.0",
+ "transformers==4.40.0",
"datasets==2.14.5",
"evaluate==0.4.0",
"bitsandbytes==0.41.1",
@@ -60,6 +60,7 @@ dependencies = [
"rouge-score >= 0.1.2",
"accelerate==0.22.0",
"wandb",
+ "scipy",
]
[project.scripts]
diff --git a/src/xturing/config/finetuning_config.yaml b/src/xturing/config/finetuning_config.yaml
index 37b82ed..e3a102e 100644
--- a/src/xturing/config/finetuning_config.yaml
+++ b/src/xturing/config/finetuning_config.yaml
@@ -322,3 +322,33 @@ opt_int8:
num_train_epochs: 3
batch_size: 8
max_length: 256
+
+mistral:
+  learning_rate: 5e-5
+  weight_decay: 0.01
+  num_train_epochs: 3
+  optimizer_name: cpu_adam
+
+mistral_lora:
+  learning_rate: 5e-5
+  weight_decay: 0.01
+  num_train_epochs: 3
+  optimizer_name: cpu_adam
+
+mistral_lora_int8:
+  learning_rate: 5e-5
+  weight_decay: 0.01
+  num_train_epochs: 3
+  optimizer_name: cpu_adam
+
+mistral_int8:
+  learning_rate: 5e-5
+  weight_decay: 0.01
+  num_train_epochs: 3
+  optimizer_name: cpu_adam
+
+mistral_lora_kbit:
+  learning_rate: 5e-5
+  weight_decay: 0.01
+  num_train_epochs: 3
+  optimizer_name: cpu_adam
diff --git a/src/xturing/config/generation_config.yaml b/src/xturing/config/generation_config.yaml
index 2eba241..6f5a0a6 100644
--- a/src/xturing/config/generation_config.yaml
+++ b/src/xturing/config/generation_config.yaml
@@ -1,7 +1,6 @@
# For large models contrastive search works very well.
# For smaller models top-p sampling is better. Contrastive search repeats the tokens in small models.
-
# Contrastive search
defaults:
max_new_tokens: 256
@@ -275,3 +274,32 @@ opt_lora_int8:
opt_int8:
max_new_tokens: 256
do_sample: false
+
+# Contrastive search
+mistral:
+  penalty_alpha: 0.6
+  top_k: 4
+  max_new_tokens: 256
+  do_sample: false
+
+# Contrastive search
+mistral_lora:
+  penalty_alpha: 0.6
+  top_k: 4
+  max_new_tokens: 256
+  do_sample: false
+
+# Greedy search
+mistral_int8:
+  max_new_tokens: 256
+  do_sample: false
+
+# Greedy search
+mistral_lora_int8:
+  max_new_tokens: 256
+  do_sample: false
+
+# Greedy search
+mistral_lora_kbit:
+  max_new_tokens: 256
+  do_sample: false
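
For context, the `penalty_alpha`/`top_k` settings above enable contrastive search, while the blocks with only `max_new_tokens` and `do_sample: false` fall back to greedy decoding. A rough illustration (not part of this diff) of what the `mistral` defaults correspond to in plain Hugging Face `transformers`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "MaziyarPanahi/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Why are LLM models becoming so important?", return_tensors="pt")
# penalty_alpha + top_k trigger contrastive search; do_sample=False keeps decoding deterministic
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```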
diff --git a/src/xturing/engines/__init__.py b/src/xturing/engines/__init__.py
index 7422985..49cde7f 100644
--- a/src/xturing/engines/__init__.py
+++ b/src/xturing/engines/__init__.py
@@ -58,6 +58,13 @@
LlamaLoraInt8Engine,
LlamaLoraKbitEngine,
)
+from xturing.engines.mistral_engine import (
+ MistralEngine,
+ MistralInt8Engine,
+ MistralLoraEngine,
+ MistralLoraInt8Engine,
+ MistralLoraKbitEngine,
+)
from xturing.engines.opt_engine import (
OPTEngine,
OPTInt8Engine,
@@ -111,3 +118,8 @@
BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)
BaseEngine.add_to_registry(OPTLoraInt8Engine.config_name, OPTLoraInt8Engine)
+BaseEngine.add_to_registry(MistralEngine.config_name, MistralEngine)
+BaseEngine.add_to_registry(MistralInt8Engine.config_name, MistralInt8Engine)
+BaseEngine.add_to_registry(MistralLoraEngine.config_name, MistralLoraEngine)
+BaseEngine.add_to_registry(MistralLoraInt8Engine.config_name, MistralLoraInt8Engine)
+BaseEngine.add_to_registry(MistralLoraKbitEngine.config_name, MistralLoraKbitEngine)
diff --git a/src/xturing/engines/mistral_engine.py b/src/xturing/engines/mistral_engine.py
new file mode 100644
index 0000000..8776d1d
--- /dev/null
+++ b/src/xturing/engines/mistral_engine.py
@@ -0,0 +1,86 @@
+from pathlib import Path
+from typing import Optional, Union
+
+from xturing.engines.causal import CausalEngine, CausalLoraEngine, CausalLoraKbitEngine
+
+
+class MistralEngine(CausalEngine):
+    config_name: str = "mistral_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name="MaziyarPanahi/Mistral-7B-Instruct-v0.3",
+            weights_path=weights_path,
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class MistralLoraEngine(CausalLoraEngine):
+    config_name: str = "mistral_lora_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name="MaziyarPanahi/Mistral-7B-Instruct-v0.3",
+            weights_path=weights_path,
+            target_modules=[
+                "q_proj",
+                "v_proj",
+            ],
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class MistralInt8Engine(CausalEngine):
+    config_name: str = "mistral_int8_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name="MaziyarPanahi/Mistral-7B-Instruct-v0.3",
+            weights_path=weights_path,
+            load_8bit=True,
+            trust_remote_code=True,
+        )
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class MistralLoraInt8Engine(CausalLoraEngine):
+    config_name: str = "mistral_lora_int8_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name="MaziyarPanahi/Mistral-7B-Instruct-v0.3",
+            weights_path=weights_path,
+            load_8bit=True,
+            target_modules=[
+                "q_proj",
+                "v_proj",
+            ],
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class MistralLoraKbitEngine(CausalLoraKbitEngine):
+    config_name: str = "mistral_lora_kbit_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        model_name = "MaziyarPanahi/Mistral-7B-Instruct-v0.3"
+        super().__init__(
+            model_name=model_name,
+            weights_path=None,
+            target_modules=["q_proj", "v_proj"],
+            trust_remote_code=True,
+            load_4bit=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
diff --git a/src/xturing/models/__init__.py b/src/xturing/models/__init__.py
index 95be19c..7ef0204 100644
--- a/src/xturing/models/__init__.py
+++ b/src/xturing/models/__init__.py
@@ -43,6 +43,13 @@
Llama2LoraInt8,
Llama2LoraKbit,
)
+from xturing.models.mistral import (
+ Mistral,
+ MistralInt8,
+ MistralLora,
+ MistralLoraInt8,
+ MistralLoraKbit,
+)
from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
from xturing.models.stable_diffusion import StableDiffusion
@@ -93,3 +100,8 @@
BaseModel.add_to_registry(OPTLora.config_name, OPTLora)
BaseModel.add_to_registry(OPTLoraInt8.config_name, OPTLoraInt8)
BaseModel.add_to_registry(StableDiffusion.config_name, StableDiffusion)
+BaseModel.add_to_registry(Mistral.config_name, Mistral)
+BaseModel.add_to_registry(MistralInt8.config_name, MistralInt8)
+BaseModel.add_to_registry(MistralLora.config_name, MistralLora)
+BaseModel.add_to_registry(MistralLoraInt8.config_name, MistralLoraInt8)
+BaseModel.add_to_registry(MistralLoraKbit.config_name, MistralLoraKbit)
diff --git a/src/xturing/models/mistral.py b/src/xturing/models/mistral.py
new file mode 100644
index 0000000..7712245
--- /dev/null
+++ b/src/xturing/models/mistral.py
@@ -0,0 +1,51 @@
+from typing import Optional
+
+from xturing.engines.mistral_engine import (
+ MistralEngine,
+ MistralInt8Engine,
+ MistralLoraEngine,
+ MistralLoraInt8Engine,
+ MistralLoraKbitEngine,
+)
+from xturing.models.causal import (
+ CausalInt8Model,
+ CausalLoraInt8Model,
+ CausalLoraKbitModel,
+ CausalLoraModel,
+ CausalModel,
+)
+
+
+class Mistral(CausalModel):
+    config_name: str = "mistral"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MistralEngine.config_name, weights_path)
+
+
+class MistralLora(CausalLoraModel):
+    config_name: str = "mistral_lora"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MistralLoraEngine.config_name, weights_path)
+
+
+class MistralInt8(CausalInt8Model):
+    config_name: str = "mistral_int8"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MistralInt8Engine.config_name, weights_path)
+
+
+class MistralLoraInt8(CausalLoraInt8Model):
+    config_name: str = "mistral_lora_int8"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MistralLoraInt8Engine.config_name, weights_path)
+
+
+class MistralLoraKbit(CausalLoraKbitModel):
+    config_name: str = "mistral_lora_kbit"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MistralLoraKbitEngine.config_name, weights_path)