From 1c3e4a1d134ac70b79274073daf8eee07d105085 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 29 Mar 2024 00:07:17 -0400
Subject: [PATCH] fix some of the edge cases for Jamba

---
 examples/jamba/README.md             | 11 +++--
 examples/jamba/qlora.yaml            | 62 ++++++++++++++++++++++++++++
 src/axolotl/monkeypatch/multipack.py | 24 +++++++++-----
 src/axolotl/utils/models.py          |  4 ++
 4 files changed, 87 insertions(+), 14 deletions(-)
 create mode 100644 examples/jamba/qlora.yaml

diff --git a/examples/jamba/README.md b/examples/jamba/README.md
index aa98c02450..54f5d1da9c 100644
--- a/examples/jamba/README.md
+++ b/examples/jamba/README.md
@@ -1,5 +1,10 @@
 # Jamba
 
-qlora w/ deepspeed needs at least 2x GPUs and 35GiB VRAM per GPU
-
-qlora single-gpu - training will start, but loss is off by an order of magnitude
+- ✅ qlora w/ deepspeed Zero-2 needs at least 2x GPUs and
+  - 35GiB VRAM per GPU w/ minimal context length
+  - 56GiB VRAM per GPU (w/ multipack enabled)
+- ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (unexpectedly high)
+- ✅ qlora single-gpu, ~51GiB VRAM
+- ✅ multipack
+- ❓ FSDP
+- ❓ 8-bit LoRA
diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml
new file mode 100644
index 0000000000..41a3854fe1
--- /dev/null
+++ b/examples/jamba/qlora.yaml
@@ -0,0 +1,62 @@
+base_model: ai21labs/Jamba-v0.1
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: false
+eval_sample_packing: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+adapter: qlora
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+
+low_cpu_mem_usage: true
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.00001
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+special_tokens:
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index fbcaf7a668..a8f5e7a84f 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -48,14 +48,16 @@ def patch_for_multipack(model_type, model_name=None):
             get_unpad_data
         )
     elif model_type == "gemmoe":
-        model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-        # we need to load the model here in order for modeling_gemmoe to be available
-        with init_empty_weights():
-            AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-        module_name = model_config.__class__.__module__.replace(
-            ".configuration_gemmoe", ".modeling_gemmoe"
-        )
-        modeling_gemmoe = importlib.import_module(module_name)
-        modeling_gemmoe._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
+        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+    elif model_type == "jamba":
+        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
+
+
+def patch_remote(model_name, config_name, modeling_name):
+    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    # we need to load the model here in order for modeling_* to be available
+    with init_empty_weights():
+        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
+    modeling_arch = importlib.import_module(module_name)
+    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 911a6c31be..31686f6006 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -456,6 +456,10 @@ def load_model(
             "bnb_4bit_quant_type": "nf4",
             "bnb_4bit_quant_storage": torch.bfloat16,
         }
+        if cfg.model_config_type == "jamba" and not cfg.deepspeed:
+            # with bf16 quant storage, the Jamba loss comes out off by an order of
+            # magnitude, but deepspeed still needs the quant storage kept in bfloat16
+            bnb_config["bnb_4bit_quant_storage"] = torch.float32
 
         if cfg.bnb_config_kwargs:
             bnb_config.update(cfg.bnb_config_kwargs)
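
For anyone reproducing the quant-storage workaround from the src/axolotl/utils/models.py hunk outside of axolotl, the sketch below shows the same rule expressed as a plain transformers BitsAndBytesConfig. It assumes transformers >= 4.39 (where bnb_4bit_quant_storage exists); the helper name and the using_deepspeed flag are illustrative stand-ins, not part of the patch or of axolotl's API.

# Sketch only: mirrors the Jamba quant-storage override from the patch above.
import torch
from transformers import BitsAndBytesConfig


def jamba_qlora_bnb_config(using_deepspeed: bool) -> BitsAndBytesConfig:
    # Per the patch: without deepspeed, keeping the 4-bit quant storage in bf16
    # throws the Jamba loss off by an order of magnitude, so fall back to fp32
    # storage; deepspeed sharding still expects the storage dtype to be bfloat16.
    quant_storage = torch.bfloat16 if using_deepspeed else torch.float32
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_quant_storage=quant_storage,
    )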