Merge pull request #284 from mapmeld/mamba
Add Mamba to available LLMs
MarcosRiveraMartinez authored Sep 23, 2024
2 parents 6a0c18d + dabb5d1, commit 570a0d6
Showing 7 changed files with 46 additions and 1 deletion.
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -43,7 +43,7 @@ keywords = [
 dependencies = [
     "torch >= 1.9.0",
     "pytorch-lightning",
-    "transformers==4.31.0",
+    "transformers==4.39.3",
     "datasets==2.14.5",
     "evaluate==0.4.0",
     "bitsandbytes==0.41.1",
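
The pin bump is what makes the rest of the diff possible: the Hugging Face port of Mamba (MambaForCausalLM) first shipped in transformers 4.39, so the old 4.31.0 pin cannot import it. A minimal guard, as a sketch:

import transformers
from packaging import version

# Sketch: fail fast if the installed transformers predates the Mamba port.
assert version.parse(transformers.__version__) >= version.parse("4.39.0"), (
    "MambaForCausalLM requires transformers >= 4.39"
)
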
src/xturing/config/finetuning_config.yaml (4 changes: 4 additions & 0 deletions)

@@ -298,6 +298,10 @@ llama2_lora_kbit:
   num_train_epochs: 3
   optimizer_name: cpu_adam
 
+mamba:
+  learning_rate: 5e-5
+  weight_decay: 0.01
+
 opt:
   learning_rate: 5e-5
   weight_decay: 0.01
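
Only learning_rate and weight_decay are pinned for the new model; the remaining hyperparameters presumably fall back to the file's shared defaults. A fine-tuning sketch following xturing's documented flow (the dataset path is illustrative):

from xturing.datasets import InstructionDataset
from xturing.models import BaseModel

# Sketch: fine-tune mamba; learning_rate and weight_decay are read
# from the mamba block added to finetuning_config.yaml above.
dataset = InstructionDataset("./alpaca_data")  # illustrative local dataset
model = BaseModel.create("mamba")
model.finetune(dataset=dataset)
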
src/xturing/config/generation_config.yaml (4 changes: 4 additions & 0 deletions)

@@ -252,6 +252,10 @@ llama2_lora_kbit:
   max_new_tokens: 256
   do_sample: false
 
+# Greedy search
+mamba:
+  do_sample: false
+
 # Contrastive search
 opt:
   penalty_alpha: 0.6
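
The comment marks the intent: do_sample: false with no contrastive-search parameters means plain greedy decoding (argmax at each step) for mamba, unlike the opt block that follows. Roughly what the two blocks express in Hugging Face terms, as a sketch:

from transformers import GenerationConfig

# Sketch: the YAML blocks restated as transformers GenerationConfig objects.
mamba_greedy = GenerationConfig(do_sample=False)       # greedy search
opt_contrastive = GenerationConfig(penalty_alpha=0.6)  # contrastive search (also needs a top_k, truncated above)
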
src/xturing/engines/__init__.py (2 changes: 2 additions & 0 deletions)

@@ -58,6 +58,7 @@
     LlamaLoraInt8Engine,
     LlamaLoraKbitEngine,
 )
+from xturing.engines.mamba_engine import MambaEngine
 from xturing.engines.opt_engine import (
     OPTEngine,
     OPTInt8Engine,
@@ -107,6 +108,7 @@
 BaseEngine.add_to_registry(LLama2LoraEngine.config_name, LLama2LoraEngine)
 BaseEngine.add_to_registry(LLama2LoraInt8Engine.config_name, LLama2LoraInt8Engine)
 BaseEngine.add_to_registry(LLama2LoraKbitEngine.config_name, LLama2LoraKbitEngine)
+BaseEngine.add_to_registry(MambaEngine.config_name, MambaEngine)
 BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
 BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
 BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)
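
Both hunks follow the pattern used for every other engine: import the class, then register it under its config_name so string lookups can find it. A resolution sketch; the BaseEngine import path and its create() classmethod are assumptions inferred from the registry calls above:

from xturing.engines.base import BaseEngine  # assumed import path

engine = BaseEngine.create("mamba_engine")  # assumed registry-lookup API
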
src/xturing/engines/mamba_engine.py (22 changes: 22 additions & 0 deletions)

@@ -0,0 +1,22 @@
+import os
+from pathlib import Path
+from typing import Optional, Union
+
+from transformers import AutoTokenizer, MambaForCausalLM
+
+from xturing.engines.causal import CausalEngine
+
+class MambaEngine(CausalEngine):
+    config_name: str = "mamba_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        model_name = "state-spaces/mamba-2.8b-hf"
+        model = MambaForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        super().__init__(weights_path=weights_path, model=model, tokenizer=tokenizer)
+
+
+    def save(self, saving_path: Union[str, Path]):
+        self.model.save_pretrained(saving_path)
+        self.tokenizer.save_pretrained(saving_path)
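
A direct round-trip sketch for the engine (normally reached through BaseModel rather than instantiated by hand); whether weights_path is restored on construction is handled inside CausalEngine, so treat the last line as an assumption:

from xturing.engines.mamba_engine import MambaEngine

engine = MambaEngine()             # downloads state-spaces/mamba-2.8b-hf
engine.save("./mamba-checkpoint")  # writes model and tokenizer artifacts
engine = MambaEngine(weights_path="./mamba-checkpoint")  # assumes CausalEngine reloads from this path
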
src/xturing/models/__init__.py (2 changes: 2 additions & 0 deletions)

@@ -43,6 +43,7 @@
     Llama2LoraInt8,
     Llama2LoraKbit,
 )
+from xturing.models.mamba import Mamba
 from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
 from xturing.models.stable_diffusion import StableDiffusion
 
@@ -88,6 +89,7 @@
 BaseModel.add_to_registry(Llama2Lora.config_name, Llama2Lora)
 BaseModel.add_to_registry(Llama2LoraInt8.config_name, Llama2LoraInt8)
 BaseModel.add_to_registry(Llama2LoraKbit.config_name, Llama2LoraKbit)
+BaseModel.add_to_registry(Mamba.config_name, Mamba)
 BaseModel.add_to_registry(OPT.config_name, OPT)
 BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
 BaseModel.add_to_registry(OPTLora.config_name, OPTLora)
src/xturing/models/mamba.py (11 changes: 11 additions & 0 deletions)

@@ -0,0 +1,11 @@
+from typing import Optional
+
+from xturing.engines.mamba_engine import MambaEngine
+from xturing.models.causal import CausalModel
+
+
+class Mamba(CausalModel):
+    config_name: str = "mamba"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(MambaEngine.config_name, weights_path)
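
With the engine registered and the model class exposed, the new entry works end to end through the public API. A usage sketch in the style of xturing's other models (prompt and save path are illustrative):

from xturing.models import BaseModel

# Sketch: create, prompt, and persist the newly added model.
model = BaseModel.create("mamba")
output = model.generate(texts=["What is a state-space model?"])
print(output)
model.save("./mamba-weights")  # persists via MambaEngine.save
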