diff --git a/pyproject.toml b/pyproject.toml
index 4a6267e..2f36a2c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,7 +43,7 @@ keywords = [
 dependencies = [
     "torch >= 1.9.0",
     "pytorch-lightning",
-    "transformers==4.31.0",
+    "transformers==4.39.3",
     "datasets==2.14.5",
     "evaluate==0.4.0",
     "bitsandbytes==0.41.1",
diff --git a/src/xturing/config/finetuning_config.yaml b/src/xturing/config/finetuning_config.yaml
index 37b82ed..3ab1ef9 100644
--- a/src/xturing/config/finetuning_config.yaml
+++ b/src/xturing/config/finetuning_config.yaml
@@ -298,6 +298,11 @@ llama2_lora_kbit:
   num_train_epochs: 3
   optimizer_name: cpu_adam
 
+mamba:
+  learning_rate: 5e-5
+  weight_decay: 0.01
+  num_train_epochs: 3
+
 opt:
   learning_rate: 5e-5
   weight_decay: 0.01
diff --git a/src/xturing/config/generation_config.yaml b/src/xturing/config/generation_config.yaml
index 2eba241..3e472cf 100644
--- a/src/xturing/config/generation_config.yaml
+++ b/src/xturing/config/generation_config.yaml
@@ -252,6 +252,11 @@ llama2_lora_kbit:
   max_new_tokens: 256
   do_sample: false
 
+# Greedy search
+mamba:
+  max_new_tokens: 256
+  do_sample: false
+
 # Contrastive search
 opt:
   penalty_alpha: 0.6
diff --git a/src/xturing/engines/__init__.py b/src/xturing/engines/__init__.py
index 7422985..a97842b 100644
--- a/src/xturing/engines/__init__.py
+++ b/src/xturing/engines/__init__.py
@@ -58,6 +58,7 @@
     LlamaLoraInt8Engine,
     LlamaLoraKbitEngine,
 )
+from xturing.engines.mamba_engine import MambaEngine
 from xturing.engines.opt_engine import (
     OPTEngine,
     OPTInt8Engine,
@@ -107,6 +108,7 @@
 BaseEngine.add_to_registry(LLama2LoraEngine.config_name, LLama2LoraEngine)
 BaseEngine.add_to_registry(LLama2LoraInt8Engine.config_name, LLama2LoraInt8Engine)
 BaseEngine.add_to_registry(LLama2LoraKbitEngine.config_name, LLama2LoraKbitEngine)
+BaseEngine.add_to_registry(MambaEngine.config_name, MambaEngine)
 BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
 BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
 BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)
diff --git a/src/xturing/engines/mamba_engine.py b/src/xturing/engines/mamba_engine.py
new file mode 100644
index 0000000..651379c
--- /dev/null
+++ b/src/xturing/engines/mamba_engine.py
@@ -0,0 +1,25 @@
+from pathlib import Path
+from typing import Optional, Union
+
+from transformers import AutoTokenizer, MambaForCausalLM
+
+from xturing.engines.causal import CausalEngine
+
+
+class MambaEngine(CausalEngine):
+    """Causal-LM engine wrapping the Hugging Face Mamba checkpoint."""
+
+    config_name: str = "mamba_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        model_name = "state-spaces/mamba-2.8b-hf"
+        model = MambaForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        # weights_path is forwarded to CausalEngine like the sibling engines do.
+        super().__init__(weights_path=weights_path, model=model, tokenizer=tokenizer)
+
+    def save(self, saving_path: Union[str, Path]):
+        """Persist model and tokenizer to ``saving_path`` in HF format."""
+        self.model.save_pretrained(saving_path)
+        self.tokenizer.save_pretrained(saving_path)
diff --git a/src/xturing/models/__init__.py b/src/xturing/models/__init__.py
index 95be19c..611ef2c 100644
--- a/src/xturing/models/__init__.py
+++ b/src/xturing/models/__init__.py
@@ -43,6 +43,7 @@
     Llama2LoraInt8,
     Llama2LoraKbit,
 )
+from xturing.models.mamba import Mamba
 from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
 from xturing.models.stable_diffusion import StableDiffusion
 
@@ -88,6 +89,7 @@
 BaseModel.add_to_registry(Llama2Lora.config_name, Llama2Lora)
 BaseModel.add_to_registry(Llama2LoraInt8.config_name, Llama2LoraInt8)
 BaseModel.add_to_registry(Llama2LoraKbit.config_name, Llama2LoraKbit)
+BaseModel.add_to_registry(Mamba.config_name, Mamba)
 BaseModel.add_to_registry(OPT.config_name, OPT)
 BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
 BaseModel.add_to_registry(OPTLora.config_name, OPTLora)
diff --git a/src/xturing/models/mamba.py b/src/xturing/models/mamba.py
new file mode 100644
index 0000000..9e27662
--- /dev/null
+++ b/src/xturing/models/mamba.py
@@ -0,0 +1,13 @@
+from typing import Optional
+
+from xturing.engines.mamba_engine import MambaEngine
+from xturing.models.causal import CausalModel
+
+
+class Mamba(CausalModel):
+    """Mamba causal language model, registered under the ``mamba`` config key."""
+
+    config_name: str = "mamba"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MambaEngine.config_name, weights_path)