fix some of the edge cases for Jamba
winglian committed Mar 29, 2024
1 parent 02af082 commit 1c3e4a1
Showing 4 changed files with 87 additions and 14 deletions.
11 changes: 8 additions & 3 deletions examples/jamba/README.md
@@ -1,5 +1,10 @@
 # Jamba
 
-qlora w/ deepspeed needs at least 2x GPUs and 35GiB VRAM per GPU
-
-qlora single-gpu - training will start, but loss is off by an order of magnitude
+- ✅ qlora w/ deepspeed Zero-2 needs at least 2x GPUs and
+  - 35GiB VRAM per GPU w minimal context length
+  - 56GiB VRAM per GPU (w multipack enabled)
+- ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?)
+- ✅ qlora single-gpu, ~51GiB VRAM
+- ✅ multipack
+- ❓ FSDP
+- ❓ 8-bit LoRA
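
The figures above are per-GPU totals. As a quick, illustrative aid (not part of this commit, assuming a local PyTorch/CUDA install), you can compare them against what your hardware actually reports:

```python
import torch

# Illustrative helper: print total VRAM per visible GPU so the numbers can be
# compared against the per-GPU requirements listed in the README above.
def total_vram_gib(device_index: int) -> float:
    props = torch.cuda.get_device_properties(device_index)
    return props.total_memory / 2**30

if __name__ == "__main__":
    for idx in range(torch.cuda.device_count()):
        print(f"cuda:{idx}: {total_vram_gib(idx):.1f} GiB total")
```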
62 changes: 62 additions & 0 deletions examples/jamba/qlora.yaml
@@ -0,0 +1,62 @@
base_model: ai21labs/Jamba-v0.1
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./out

sequence_len: 4096
sample_packing: false
pad_to_sequence_len: false
eval_sample_packing: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

adapter: qlora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

low_cpu_mem_usage: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
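
For orientation, the `adapter`/`lora_*` block in this config maps roughly onto a peft `LoraConfig` as sketched below; this is a hand-written illustration of the mapping, not the construction code axolotl itself runs:

```python
from peft import LoraConfig

# Rough mapping of the adapter settings in qlora.yaml (illustrative only).
lora_config = LoraConfig(
    r=8,                          # lora_r
    lora_alpha=16,                # lora_alpha
    lora_dropout=0.05,            # lora_dropout
    target_modules="all-linear",  # lora_target_linear: true
    task_type="CAUSAL_LM",
)
```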
24 changes: 13 additions & 11 deletions src/axolotl/monkeypatch/multipack.py
@@ -48,14 +48,16 @@ def patch_for_multipack(model_type, model_name=None):
             get_unpad_data
         )
     elif model_type == "gemmoe":
-        model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-        # we need to load the model here in order for modeling_gemmoe to be available
-        with init_empty_weights():
-            AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-        module_name = model_config.__class__.__module__.replace(
-            ".configuration_gemmoe", ".modeling_gemmoe"
-        )
-        modeling_gemmoe = importlib.import_module(module_name)
-        modeling_gemmoe._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
+        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+    elif model_type == "jamba":
+        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
+
+
+def patch_remote(model_name, config_name, modeling_name):
+    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    # we need to load the model here in order for modeling_* to be available
+    with init_empty_weights():
+        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
+    modeling_arch = importlib.import_module(module_name)
+    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
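
The refactor moves the remote-code patching into a reusable `patch_remote` helper and reuses it for Jamba. For context, the entry point would be exercised roughly as below (an illustrative call, not part of the commit), before the model is instantiated, so flash-attention unpadding routes through axolotl's packing-aware `get_unpad_data`:

```python
# Illustrative usage: apply the monkeypatch before loading the model when
# sample packing is enabled for a trust_remote_code architecture like Jamba.
from axolotl.monkeypatch.multipack import patch_for_multipack

patch_for_multipack("jamba", model_name="ai21labs/Jamba-v0.1")
```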
4 changes: 4 additions & 0 deletions src/axolotl/utils/models.py
@@ -456,6 +456,10 @@ def load_model(
             "bnb_4bit_quant_type": "nf4",
             "bnb_4bit_quant_storage": torch.bfloat16,
         }
+        if cfg.model_config_type == "jamba" and not cfg.deepspeed:
+            # for some reason, this causes the loss to be off by an order of magnitude
+            # but deepspeed needs this still in bfloat16
+            bnb_config["bnb_4bit_quant_storage"] = torch.float32
 
         if cfg.bnb_config_kwargs:
             bnb_config.update(cfg.bnb_config_kwargs)
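
Taken on its own, the quantization logic this hunk touches amounts to the sketch below, built on transformers' `BitsAndBytesConfig`. Only the `nf4`/`bnb_4bit_quant_storage` fields and the Jamba override come from the diff; the remaining settings are assumptions about typical QLoRA defaults, and the real logic lives inside `load_model`:

```python
import torch
from transformers import BitsAndBytesConfig

def build_bnb_config(model_config_type: str, using_deepspeed: bool) -> BitsAndBytesConfig:
    """Illustrative sketch of the 4-bit quantization settings."""
    kwargs = {
        "load_in_4bit": True,                      # assumed QLoRA default
        "bnb_4bit_use_double_quant": True,         # assumed QLoRA default
        "bnb_4bit_compute_dtype": torch.bfloat16,  # assumed QLoRA default
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_quant_storage": torch.bfloat16,
    }
    if model_config_type == "jamba" and not using_deepspeed:
        # Per the commit: with bf16 quant storage, single-GPU Jamba loss ends up
        # off by an order of magnitude, while deepspeed still needs bf16 storage.
        kwargs["bnb_4bit_quant_storage"] = torch.float32
    return BitsAndBytesConfig(**kwargs)
```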
