diff --git a/README.md b/README.md index a1b857f..bdfa04a 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@


-___ +--- `xTuring` provides fast, efficient and simple fine-tuning of LLMs, such as LLaMA, GPT-J, Galactica, and more. By providing an easy-to-use interface for fine-tuning LLMs to your own data and application, xTuring makes it @@ -25,6 +25,7 @@ simple to build, customize and control LLMs. The entire process can be done inside your computer or in your private cloud, ensuring data privacy and security. With `xTuring` you can, + - Ingest data from different sources and preprocess them to a format LLMs can understand - Scale from single to multiple GPUs for faster fine-tuning - Leverage memory-efficient methods (i.e. INT4, LoRA fine-tuning) to reduce hardware costs by up to 90% @@ -34,8 +35,11 @@
## 🌟 What's new? + We are excited to announce the latest enhancements to our `xTuring` library: -1. __`LLaMA 2` integration__ - You can use and fine-tune the _`LLaMA 2`_ model in different configurations: _off-the-shelf_, _off-the-shelf with INT8 precision_, _LoRA fine-tuning_, _LoRA fine-tuning with INT8 precision_ and _LoRA fine-tuning with INT4 precision_ using the `GenericModel` wrapper and/or you can use the `Llama2` class from `xturing.models` to test and finetune the model. + +1. **`LLaMA 2` integration** - You can use and fine-tune the _`LLaMA 2`_ model in different configurations: _off-the-shelf_, _off-the-shelf with INT8 precision_, _LoRA fine-tuning_, _LoRA fine-tuning with INT8 precision_ and _LoRA fine-tuning with INT4 precision_ using the `GenericModel` wrapper and/or you can use the `Llama2` class from `xturing.models` to test and finetune the model. + ```python from xturing.models import Llama2 model = Llama2() @@ -45,7 +49,9 @@ from xturing.models import BaseModel model = BaseModel.create('llama2') ``` -2. __`Evaluation`__ - Now you can evaluate any `Causal Language Model` on any dataset. The metrics currently supported is [`perplexity`](https://towardsdatascience.com/perplexity-in-language-models-87a196019a94). + +2. **`Evaluation`** - Now you can evaluate any `Causal Language Model` on any dataset. The metric currently supported is [`perplexity`](https://towardsdatascience.com/perplexity-in-language-models-87a196019a94). + ```python # Make the necessary imports from xturing.datasets import InstructionDataset @@ -64,7 +70,9 @@ result = model.evaluate(dataset) print(f"Perplexity of the evalution: {result}") ``` -3. __`INT4` Precision__ - You can now use and fine-tune any LLM with `INT4 Precision` using `GenericLoraKbitModel`. + +3. **`INT4` Precision** - You can now use and fine-tune any LLM with `INT4 Precision` using `GenericLoraKbitModel`. + ```python # Make the necessary imports from xturing.datasets import InstructionDataset @@ -80,7 +88,7 @@ model = GenericLoraKbitModel('tiiuae/falcon-7b') model.finetune(dataset) ``` -4. __CPU inference__ - The CPU, including laptop CPUs, is now fully equipped to handle LLM inference. We integrated [Intel® Extension for Transformers](https://github.com/intel/intel-extension-for-transformers) to conserve memory by compressing the model with [weight-only quantization algorithms](https://github.com/intel/intel-extension-for-transformers/blob/main/docs/weightonlyquant.md) and accelerate the inference by leveraging its highly optimized kernel on Intel platforms. +4. **CPU inference** - The CPU, including laptop CPUs, is now fully equipped to handle LLM inference. We integrated [Intel® Extension for Transformers](https://github.com/intel/intel-extension-for-transformers) to conserve memory by compressing the model with [weight-only quantization algorithms](https://github.com/intel/intel-extension-for-transformers/blob/main/docs/weightonlyquant.md) and accelerate the inference by leveraging its highly optimized kernel on Intel platforms. ```python # Make the necessary imports @@ -95,7 +103,8 @@ output = model.generate(texts=["Why LLM models are becoming so important?"]) print(output) ``` -5. __Batch integration__ - By tweaking the 'batch_size' in the .generate() and .evaluate() functions, you can expedite results. +5. **Batch integration** - By tweaking the 'batch_size' in the .generate() and .evaluate() functions, you can expedite results.
Using a 'batch_size' greater than 1 typically enhances processing efficiency. + ```python # Make the necessary imports from xturing.datasets import InstructionDataset @@ -119,6 +128,7 @@ For an extended insight, consider examining the [GenericModel working example](e
## ⚙️ Installation + ```bash pip install xturing ``` @@ -151,6 +161,7 @@ You can find the data folder [here](examples/models/llama/alpaca_data).
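That alpaca folder plugs straight into the instruction fine-tuning flow used throughout this README; a minimal sketch of the loop, assuming the linked folder as the dataset path (the prompt string is only illustrative):

```python
from xturing.datasets import InstructionDataset
from xturing.models import BaseModel

# Load the instruction dataset from the alpaca data folder linked above
dataset = InstructionDataset("./examples/models/llama/alpaca_data")

# Create a LoRA variant of LLaMA and fine-tune it on the dataset
model = BaseModel.create("llama_lora")
model.finetune(dataset=dataset)

# Run inference with the fine-tuned model
output = model.generate(texts=["Why are LLMs becoming so important?"])
print(output)
```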
## CLI playground + ```bash @@ -159,6 +170,7 @@ $ xturing chat -m "<path-to-model-folder>" ``` ## UI playground + ```python @@ -180,6 +192,7 @@ Playground().launch() ## launches localhost UI
## 📚 Tutorials + - [Preparing your dataset](examples/datasets/preparing_your_dataset.py) - [Cerebras-GPT fine-tuning with LoRA and INT8](examples/models/cerebras/cerebras_lora_int8.ipynb)   [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eKq3oF7dnK8KuIfsTE70Gvvniwr1O9D0?usp=sharing) - [Cerebras-GPT fine-tuning with LoRA](examples/models/cerebras/cerebras_lora.ipynb)   [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1VjqQhstm5pT4EjPjx4Je7b3W2X1V3vDo?usp=sharing) @@ -209,17 +222,18 @@ Fine-tuning parameters: } ``` -| LLaMA-7B | DeepSpeed + CPU Offloading | LoRA + DeepSpeed | LoRA + DeepSpeed + CPU Offloading | -| :---------: | :----: | :----: | :----: | -| GPU | 33.5 GB | 23.7 GB | 21.9 GB | -| CPU | 190 GB | 10.2 GB | 14.9 GB | -| Time/epoch | 21 hours | 20 mins | 20 mins | +| LLaMA-7B | DeepSpeed + CPU Offloading | LoRA + DeepSpeed | LoRA + DeepSpeed + CPU Offloading | +| :--------: | :------------------------: | :--------------: | :-------------------------------: | +| GPU | 33.5 GB | 23.7 GB | 21.9 GB | +| CPU | 190 GB | 10.2 GB | 14.9 GB | +| Time/epoch | 21 hours | 20 mins | 20 mins | Contribute to this by submitting your performance results on other GPUs by creating an issue with your hardware specifications, memory consumption and time per epoch.
## 📎 Fine-tuned model checkpoints + We have already fine-tuned some models that you can use as your base or start playing with. Here is how you would load them: @@ -228,44 +242,49 @@ from xturing.models import BaseModel model = BaseModel.load("x/distilgpt2_lora_finetuned_alpaca") ``` -| model | dataset | Path | -|---------------------|--------|---------------| -| DistilGPT-2 LoRA | alpaca | `x/distilgpt2_lora_finetuned_alpaca` | -| LLaMA LoRA | alpaca | `x/llama_lora_finetuned_alpaca` | +| model | dataset | Path | +| ---------------- | ------- | ------------------------------------ | +| DistilGPT-2 LoRA | alpaca | `x/distilgpt2_lora_finetuned_alpaca` | +| LLaMA LoRA | alpaca | `x/llama_lora_finetuned_alpaca` |
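Once one of the checkpoints listed above is loaded, it behaves like any other `BaseModel` instance; a minimal sketch (the prompt string is only illustrative):

```python
from xturing.models import BaseModel

# Load a published LoRA checkpoint from the table above
model = BaseModel.load("x/llama_lora_finetuned_alpaca")

# Generate text with the fine-tuned weights
output = model.generate(texts=["Give three tips for staying healthy."])
print(output)
```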
## Supported Models + Below is a list of all the supported models via `BaseModel` class of `xTuring` and their corresponding keys to load them. -| Model | Key | -| -- | -- | -|Bloom | bloom| -|Cerebras | cerebras| -|DistilGPT-2 | distilgpt2| -|Falcon-7B | falcon| -|Galactica | galactica| -|GPT-J | gptj| -|GPT-2 | gpt2| -|LlaMA | llama| -|LlaMA2 | llama2| -|OPT-1.3B | opt| +| Model | Key | +| ----------- | ---------- | +| Bloom | bloom | +| Cerebras | cerebras | +| DistilGPT-2 | distilgpt2 | +| Falcon-7B | falcon | +| Galactica | galactica | +| GPT-J | gptj | +| GPT-2 | gpt2 | +| LLaMA | llama | +| LLaMA2 | llama2 | +| OPT-1.3B | opt | +| Mistral-7B | mistral | The above mentioned are the base variants of the LLMs. Below are the templates to get their `LoRA`, `INT8`, `INT8 + LoRA` and `INT4 + LoRA` versions. -| Version | Template | -| -- | -- | -| LoRA| _lora| -| INT8| _int8| -| INT8 + LoRA| _lora_int8| +| Version | Template | +| ----------- | ---------------------- | +| LoRA | \_lora | +| INT8 | \_int8 | +| INT8 + LoRA | \_lora_int8 | + +\*\* In order to load any model's **`INT4+LoRA`** version, you will need to make use of `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it: -** In order to load any model's __`INT4+LoRA`__ version, you will need to make use of `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it: ```python model = GenericLoraKbitModel('<model_path>') ``` + The `model_path` can be replaced with your local directory or any HuggingFace library model like `facebook/opt-1.3b`. ## 📈 Roadmap + - [x] Support for `LLaMA`, `GPT-J`, `GPT-2`, `OPT`, `Cerebras-GPT`, `Galactica` and `Bloom` models - [x] Dataset generation using self-instruction - [x] Low-precision LoRA fine-tuning and unsupervised fine-tuning @@ -284,6 +303,7 @@ The `model_path` can be replaced with your local directory or any HuggingFace lib
## 🤝 Help and Support + If you have any questions, you can create an issue on this repository. You can also join our [Discord server](https://discord.gg/TgHXuSJEk6) and start a discussion in the `#xturing` channel. @@ -291,9 +311,11 @@
## 📝 License + This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
## 🌎 Contributing + As an open source project in a rapidly evolving field, we welcome contributions of all kinds, including new features and better documentation. Please read our [contributing guide](CONTRIBUTING.md) to learn how you can get involved. diff --git a/docs/docs/overview/quickstart/test.jsx b/docs/docs/overview/quickstart/test.jsx index c738db8..7464e8f 100644 --- a/docs/docs/overview/quickstart/test.jsx +++ b/docs/docs/overview/quickstart/test.jsx @@ -1,57 +1,56 @@ -import React, { useEffect, useState } from 'react' -import clsx from 'clsx' -import MDXContent from '@theme/MDXContent' -import CodeBlock from '@theme/CodeBlock' +import React, { useEffect, useState } from "react"; +import clsx from "clsx"; +import MDXContent from "@theme/MDXContent"; +import CodeBlock from "@theme/CodeBlock"; const trainingTechniques = { - base: 'Base', - lora: 'LoRA', - lora_int8: 'LoRA INT8', - int8: 'INT8', -} + base: "Base", + lora: "LoRA", + lora_int8: "LoRA INT8", + int8: "INT8", +}; const modelList = { - bloom: 'BLOOM', - cerebras: 'Cerebras', - distilgpt2: 'DistilGPT-2', - galactica: 'Galactica', - gptj: 'GPT-J', - gpt2: 'GPT-2', - llama: 'LLaMA', - llama2: 'LLaMA 2', - opt: 'OPT', -} + bloom: "BLOOM", + cerebras: "Cerebras", + distilgpt2: "DistilGPT-2", + galactica: "Galactica", + gptj: "GPT-J", + gpt2: "GPT-2", + llama: "LLaMA", + llama2: "LLaMA 2", + opt: "OPT", + mistral: "Mistral", +}; -export default function Test( - {instruction} -) { +export default function Test({ instruction }) { // const [code, setCode] = useState('llama'); const [code, setCode] = useState({ - model: '', - technique: 'base', - }) + model: "", + technique: "base", + }); - let finalKey = '' - if (code.technique === 'base') { - finalKey = `${code.model}` + let finalKey = ""; + if (code.technique === "base") { + finalKey = `${code.model}`; } else { - finalKey = `${code.model}_${code.technique}` + finalKey = `${code.model}_${code.technique}`; } - + useEffect(() => { setCode({ - model: 'llama', - technique: 'base' + model: "llama", + technique: "base", }); }, []); return ( -
- +
+ - +
- ) -} \ No newline at end of file + ); } diff --git a/docs/docs/overview/supported_models.md b/docs/docs/overview/supported_models.md index 02932dd..17de8bb 100644 --- a/docs/docs/overview/supported_models.md +++ b/docs/docs/overview/supported_models.md @@ -5,32 +5,39 @@ description: Models Supported by xTuring --- + ## Base versions -| Model | Model Key | LoRA | INT8 | LoRA + INT8 | LoRA + INT4 | -| ------ | --- | :---: | :---: | :---: | :---: | -| BLOOM 1.1B| bloom | ✅ | ✅ | ✅ | ✅ | -| Cerebras 1.3B| cerebras | ✅ | ✅ | ✅ | ✅ | -| DistilGPT-2 | distilgpt2 | ✅ | ✅ | ✅ | ✅ | -| Falcon 7B | falcon | ✅ | ✅ | ✅ | ✅ | -| Galactica 6.7B| galactica | ✅ | ✅ | ✅ | ✅ | -| GPT-J 6B | gptj | ✅ | ✅ | ✅ | ✅ | -| GPT-2 | gpt2 | ✅ | ✅ | ✅ | ✅ | -| LLaMA 7B | llama | ✅ | ✅ | ✅ | ✅ | -| LLaMA2 | llama2 | ✅ | ✅ | ✅ | ✅ | -| OPT 1.3B | opt | ✅ | ✅ | ✅ | ✅ | + +| Model | Model Key | LoRA | INT8 | LoRA + INT8 | LoRA + INT4 | +| -------------- | ---------- | :--: | :--: | :---------: | :---------: | +| BLOOM 1.1B | bloom | ✅ | ✅ | ✅ | ✅ | +| Cerebras 1.3B | cerebras | ✅ | ✅ | ✅ | ✅ | +| DistilGPT-2 | distilgpt2 | ✅ | ✅ | ✅ | ✅ | +| Falcon 7B | falcon | ✅ | ✅ | ✅ | ✅ | +| Galactica 6.7B | galactica | ✅ | ✅ | ✅ | ✅ | +| GPT-J 6B | gptj | ✅ | ✅ | ✅ | ✅ | +| GPT-2 | gpt2 | ✅ | ✅ | ✅ | ✅ | +| LLaMA 7B | llama | ✅ | ✅ | ✅ | ✅ | +| LLaMA2 | llama2 | ✅ | ✅ | ✅ | ✅ | +| OPT 1.3B | opt | ✅ | ✅ | ✅ | ✅ | +| Mistral 7B | mistral | ✅ | ✅ | ✅ | ✅ | ### Memory-efficient versions + > The above mentioned are the base variants of the LLMs. Below are the templates to get their `LoRA`, `INT8`, `INT8 + LoRA` and `INT4 + LoRA` versions. -| Version | Template | -| -- | -- | -| LoRA | _lora| -| INT8 | _int8| -| INT8 + LoRA | _lora_int8| +| Version | Template | +| ----------- | ---------------------- | +| LoRA | \_lora | +| INT8 | \_int8 | +| INT8 + LoRA | \_lora_int8 | ### INT4 Precision model versions -> In order to load any model's __`INT4+LoRA`__ version, you will need to make use of `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it: + +> In order to load any model's **`INT4+LoRA`** version, you will need to make use of `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it: + ```python model = GenericLoraKbitModel('/path/to/model') ``` + The `/path/to/model` can be replaced with your local directory or any HuggingFace library model like `facebook/opt-1.3b`. diff --git a/examples/README.md b/examples/README.md index 3b9026b..11abef9 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,7 +1,9 @@ # Navigating through examples -Here, is a brief about how to navigate through examples quick and efficiently, and get your hands dirty with `xTuring`. + +Here is a brief guide on how to navigate through the examples quickly and efficiently, and get your hands dirty with `xTuring`. ## Directory structure + ``` examples/ | datasets | features | models | playground_ui ``` ### datasets/ -This directory consists of multiple ways to generate your custom dataset from a given set of examples. + +This directory consists of multiple ways to generate your custom dataset from a given set of examples. ### features/ -This directory consists of files with exapmles highlighting speific major features of the library, which can be replicated to any LLM you want.
-For example, in `dataset_generation/`, you will find an example on how to generate your custom dataset from a .jsonl file. In `evaluation/`, you will find a specific exapmle on how to evaluate your finetuned model, which can then be extended to any LLM and any dataset. + +This directory consists of files with examples highlighting specific major features of the library, which can be replicated to any LLM you want. +For example, in `dataset_generation/`, you will find an example on how to generate your custom dataset from a .jsonl file. In `evaluation/`, you will find a specific example on how to evaluate your finetuned model, which can then be extended to any LLM and any dataset. ### models/ + This directory consists of examples specific to each model mentioned. ### playground_ui/ + This directory consists of an example which demonstrates how you can play around with your LLM through a web interface. ## Models + Below is a list of all the supported models via `BaseModel` class of `xTuring` and their corresponding keys to load them. -| Model | Key | -| -- | -- | -|Bloom | bloom| -|Cerebras | cerebras| -|DistilGPT-2 | distilgpt2| -|Falcon-7B | falcon| -|Galactica | galactica| -|GPT-J | gptj| -|GPT-2 | gpt2| -|LlaMA | llama| -|LlaMA2 | llama2| -|OPT-1.3B | opt| +| Model | Key | +| ----------- | ---------- | +| Bloom | bloom | +| Cerebras | cerebras | +| DistilGPT-2 | distilgpt2 | +| Falcon-7B | falcon | +| Galactica | galactica | +| GPT-J | gptj | +| GPT-2 | gpt2 | +| LLaMA | llama | +| LLaMA2 | llama2 | +| OPT-1.3B | opt | +| Mistral-7B | mistral | The above mentioned are the base variants of the LLMs. Below are the templates to get their `LoRA`, `INT8`, `INT8 + LoRA` and `INT4 + LoRA` versions. -| Version | Template | -| -- | -- | -| LoRA| _lora| -| INT8| _int8| -| INT8 + LoRA| _lora_int8| +| Version | Template | +| ----------- | ---------------------- | +| LoRA | \_lora | +| INT8 | \_int8 | +| INT8 + LoRA | \_lora_int8 | + +\*\* In order to load any model's **`INT4+LoRA`** version, you will need to make use of `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it: -** In order to load any model's __`INT4+LoRA`__ version, you will need to make use of `GenericLoraKbitModel` class from `xturing.models`. Below is how to use it: ```python model = GenericLoraKbitModel('<model_path>') ``` + The `model_path` can be replaced with your local directory or any HuggingFace library model like `facebook/opt-1.3b`.
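Putting the two tables together, a base model key plus a template suffix selects the memory-efficient variant; a minimal sketch using the newly added Mistral key (any other listed key composes the same way):

```python
from xturing.models import BaseModel

# <model key> + <template suffix>, e.g. Mistral with INT8 + LoRA
model = BaseModel.create("mistral_lora_int8")
```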
diff --git a/examples/models/mistral/mistral.py b/examples/models/mistral/mistral.py new file mode 100644 index 0000000..8fecc13 --- /dev/null +++ b/examples/models/mistral/mistral.py @@ -0,0 +1,21 @@ +# Make the necessary imports +from xturing.models import Mistral + +# Load the model +model = Mistral() +# Generate outputs from the model +outputs = model.generate(texts=["How are you?"]) +# Print the generated outputs +print(outputs) + +## or + +# Make the necessary imports +from xturing.models import BaseModel + +# Load the model +model = BaseModel.create("mistral") +# Generate outputs from the model +outputs = model.generate(texts=["How are you?"]) +# Print the generated outputs +print(outputs) diff --git a/examples/models/mistral/mistral_woq.py b/examples/models/mistral/mistral_woq.py new file mode 100644 index 0000000..76d4f51 --- /dev/null +++ b/examples/models/mistral/mistral_woq.py @@ -0,0 +1,10 @@ +# from xturing.datasets.instruction_dataset import InstructionDataset +from xturing.models import BaseModel + +# Initializes the model: Quantize model with weight only algorithms and +# replace the linear with itrex's qbits_linear kernel +model = BaseModel.create("mistral_int8") + +# Once the model has been quantized, you can do inferences directly +output = model.generate(texts=["Why LLM models are becoming so important?"]) +print(output) diff --git a/pyproject.toml b/pyproject.toml index 4a6267e..234db9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ keywords = [ dependencies = [ "torch >= 1.9.0", "pytorch-lightning", - "transformers==4.31.0", + "transformers==4.40.0", "datasets==2.14.5", "evaluate==0.4.0", "bitsandbytes==0.41.1", @@ -60,6 +60,7 @@ dependencies = [ "rouge-score >= 0.1.2", "accelerate==0.22.0", "wandb", + "scipy", ] [project.scripts] diff --git a/src/xturing/config/finetuning_config.yaml b/src/xturing/config/finetuning_config.yaml index 37b82ed..e3a102e 100644 --- a/src/xturing/config/finetuning_config.yaml +++ b/src/xturing/config/finetuning_config.yaml @@ -322,3 +322,33 @@ opt_int8: num_train_epochs: 3 batch_size: 8 max_length: 256 + +mistral: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam + +mistral_lora: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam + +mistral_lora_int8: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam + +mistral_int8: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam + +mistral_lora_kbit: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam diff --git a/src/xturing/config/generation_config.yaml b/src/xturing/config/generation_config.yaml index 2eba241..6f5a0a6 100644 --- a/src/xturing/config/generation_config.yaml +++ b/src/xturing/config/generation_config.yaml @@ -1,7 +1,6 @@ # For large models contrastive search works very well. # For smaller models top-p sampling is better. Contrastive search repeats the tokens in small models.
- # Contrastive search defaults: max_new_tokens: 256 @@ -275,3 +274,32 @@ opt_lora_int8: opt_int8: max_new_tokens: 256 do_sample: false + + # Contrastive search +mistral: + penalty_alpha: 0.6 + top_k: 4 + max_new_tokens: 256 + do_sample: false + +# Contrastive search +mistral_lora: + penalty_alpha: 0.6 + top_k: 4 + max_new_tokens: 256 + do_sample: false + +# Greedy search +mistral_int8: + max_new_tokens: 256 + do_sample: false + +# Greedy search +mistral_lora_int8: + max_new_tokens: 256 + do_sample: false + +# Greedy search +mistral_lora_kbit: + max_new_tokens: 256 + do_sample: false diff --git a/src/xturing/engines/__init__.py b/src/xturing/engines/__init__.py index 7422985..49cde7f 100644 --- a/src/xturing/engines/__init__.py +++ b/src/xturing/engines/__init__.py @@ -58,6 +58,13 @@ LlamaLoraInt8Engine, LlamaLoraKbitEngine, ) +from xturing.engines.mistral_engine import ( + MistralEngine, + MistralInt8Engine, + MistralLoraEngine, + MistralLoraInt8Engine, + MistralLoraKbitEngine, +) from xturing.engines.opt_engine import ( OPTEngine, OPTInt8Engine, @@ -111,3 +118,8 @@ BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine) BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine) BaseEngine.add_to_registry(OPTLoraInt8Engine.config_name, OPTLoraInt8Engine) +BaseEngine.add_to_registry(MistralEngine.config_name, MistralEngine) +BaseEngine.add_to_registry(MistralInt8Engine.config_name, MistralInt8Engine) +BaseEngine.add_to_registry(MistralLoraEngine.config_name, MistralLoraEngine) +BaseEngine.add_to_registry(MistralLoraInt8Engine.config_name, MistralLoraInt8Engine) +BaseEngine.add_to_registry(MistralLoraKbitEngine.config_name, MistralLoraKbitEngine) diff --git a/src/xturing/engines/mistral_engine.py b/src/xturing/engines/mistral_engine.py new file mode 100644 index 0000000..8776d1d --- /dev/null +++ b/src/xturing/engines/mistral_engine.py @@ -0,0 +1,86 @@ +from pathlib import Path +from typing import Optional, Union + +from xturing.engines.causal import CausalEngine, CausalLoraEngine, CausalLoraKbitEngine + + +class MistralEngine(CausalEngine): + config_name: str = "mistral_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + super().__init__( + model_name="MaziyarPanahi/Mistral-7B-Instruct-v0.3", + weights_path=weights_path, + trust_remote_code=True, + ) + + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + +class MistralLoraEngine(CausalLoraEngine): + config_name: str = "mistral_lora_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + super().__init__( + model_name="MaziyarPanahi/Mistral-7B-Instruct-v0.3", + weights_path=weights_path, + target_modules=[ + "q_proj", + "v_proj", + ], + trust_remote_code=True, + ) + + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + +class MistralInt8Engine(CausalEngine): + config_name: str = "mistral_int8_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + super().__init__( + model_name="MaziyarPanahi/Mistral-7B-Instruct-v0.3", + weights_path=weights_path, + load_8bit=True, + trust_remote_code=True, + ) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + +class MistralLoraInt8Engine(CausalLoraEngine): + config_name: str = "mistral_lora_int8_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + super().__init__( + 
model_name="MaziyarPanahi/Mistral-7B-Instruct-v0.3", + weights_path=weights_path, + load_8bit=True, + target_modules=[ + "q_proj", + "v_proj", + ], + trust_remote_code=True, + ) + + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + +class MistralLoraKbitEngine(CausalLoraKbitEngine): + config_name: str = "mistral_lora_kbit_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + model_name = "MaziyarPanahi/Mistral-7B-Instruct-v0.3" + super().__init__( + model_name=model_name, + weights_path=None, + target_modules=["q_proj", "v_proj"], + trust_remote_code=True, + load_4bit=True, + ) + + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id diff --git a/src/xturing/models/__init__.py b/src/xturing/models/__init__.py index 95be19c..7ef0204 100644 --- a/src/xturing/models/__init__.py +++ b/src/xturing/models/__init__.py @@ -43,6 +43,13 @@ Llama2LoraInt8, Llama2LoraKbit, ) +from xturing.models.mistral import ( + Mistral, + MistralInt8, + MistralLora, + MistralLoraInt8, + MistralLoraKbit, +) from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8 from xturing.models.stable_diffusion import StableDiffusion @@ -93,3 +100,8 @@ BaseModel.add_to_registry(OPTLora.config_name, OPTLora) BaseModel.add_to_registry(OPTLoraInt8.config_name, OPTLoraInt8) BaseModel.add_to_registry(StableDiffusion.config_name, StableDiffusion) +BaseModel.add_to_registry(Mistral.config_name, Mistral) +BaseModel.add_to_registry(MistralInt8.config_name, MistralInt8) +BaseModel.add_to_registry(MistralLora.config_name, MistralLora) +BaseModel.add_to_registry(MistralLoraInt8.config_name, MistralLoraInt8) +BaseModel.add_to_registry(MistralLoraKbit.config_name, MistralLoraKbit) diff --git a/src/xturing/models/mistral.py b/src/xturing/models/mistral.py new file mode 100644 index 0000000..7712245 --- /dev/null +++ b/src/xturing/models/mistral.py @@ -0,0 +1,51 @@ +from typing import Optional + +from xturing.engines.mistral_engine import ( + MistralEngine, + MistralInt8Engine, + MistralLoraEngine, + MistralLoraInt8Engine, + MistralLoraKbitEngine, +) +from xturing.models.causal import ( + CausalInt8Model, + CausalLoraInt8Model, + CausalLoraKbitModel, + CausalLoraModel, + CausalModel, +) + + +class Mistral(CausalModel): + config_name: str = "mistral" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(MistralEngine.config_name, weights_path) + + +class MistralLora(CausalLoraModel): + config_name: str = "mistral_lora" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(MistralLoraEngine.config_name, weights_path) + + +class MistralInt8(CausalInt8Model): + config_name: str = "mistral_int8" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(MistralInt8Engine.config_name, weights_path) + + +class MistralLoraInt8(CausalLoraInt8Model): + config_name: str = "mistral_lora_int8" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(MistralLoraInt8Engine.config_name, weights_path) + + +class MistralLoraKbit(CausalLoraKbitModel): + config_name: str = "mistral_lora_kbit" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(MistralLoraKbitEngine.config_name, weights_path)
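Taken together with the registry entries above, the new Mistral classes follow the same workflow as the other xTuring models; a minimal fine-tuning sketch, assuming an instruction dataset prepared as in the earlier examples (the dataset path is only illustrative):

```python
from xturing.datasets import InstructionDataset
from xturing.models import MistralLora

# Load an instruction dataset (illustrative path)
dataset = InstructionDataset("./alpaca_data")

# Fine-tune the LoRA variant of Mistral defined above
model = MistralLora()
model.finetune(dataset=dataset)

# Generate with the fine-tuned weights
output = model.generate(texts=["How are you?"])
print(output)
```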