From 3ebf22464b30390220d22a7b5fee04815cdbc0d9 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 30 Jul 2024 19:21:38 -0400
Subject: [PATCH] qlora-fsdp ram efficient loading with hf trainer (#1791)

* fix 405b with lower cpu ram requirements
* make sure to use double quant and only skip output embeddings
* set model attributes
* more fixes for sharded fsdp loading
* update the base model in example to use pre-quantized nf4-bf16 weights
* upstream fixes for qlora+fsdp
---
 docker/Dockerfile-cloud                          |  1 -
 docker/Dockerfile-cloud-no-tmux                  |  1 -
 examples/llama-3/qlora-fsdp-405b.yaml            |  7 ++++---
 requirements.txt                                 |  4 ++--
 src/axolotl/cli/__init__.py                      |  4 +++-
 src/axolotl/core/trainer_builder.py              |  4 +++-
 .../config/models/input/v0_4_1/__init__.py       |  8 ++++++++
 src/axolotl/utils/model_shard_quant.py           | 18 ++++++++++++++++++
 src/axolotl/utils/models.py                      |  9 ++++++++-
 src/axolotl/utils/trainer.py                     | 10 ++++++----
 10 files changed, 52 insertions(+), 14 deletions(-)

diff --git a/docker/Dockerfile-cloud b/docker/Dockerfile-cloud
index 69ce143bb2..c0bb266d28 100644
--- a/docker/Dockerfile-cloud
+++ b/docker/Dockerfile-cloud
@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
 
 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
-ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
 ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 ENV HF_HUB_ENABLE_HF_TRANSFER="1"
 
diff --git a/docker/Dockerfile-cloud-no-tmux b/docker/Dockerfile-cloud-no-tmux
index efeffef8e6..3e59d41191 100644
--- a/docker/Dockerfile-cloud-no-tmux
+++ b/docker/Dockerfile-cloud-no-tmux
@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
 
 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
-ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
 ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 ENV HF_HUB_ENABLE_HF_TRANSFER="1"
 
diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml
index 385b7f91d6..6eeec01c9b 100644
--- a/examples/llama-3/qlora-fsdp-405b.yaml
+++ b/examples/llama-3/qlora-fsdp-405b.yaml
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3.1-405B
+base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
 tokenizer_type: AutoTokenizer
 
 load_in_4bit: true
@@ -10,10 +10,11 @@ datasets:
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: ./outputs/out/qlora-llama3_1-405b
+save_safetensors: true
 
 adapter: qlora
 
-sequence_len: 1024
+sequence_len: 2048
 sample_packing: true
 pad_to_sequence_len: true
 
@@ -25,7 +26,7 @@ lora_target_linear: true
 
 gradient_accumulation_steps: 4
 micro_batch_size: 1
-num_epochs: 4
+num_epochs: 2
 optimizer: adamw_torch
 lr_scheduler: cosine
 learning_rate: 0.00001
diff --git a/requirements.txt b/requirements.txt
index 5825ee1903..d2ec9266c2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,9 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.11.1
-transformers==4.43.3
+transformers @ git+https://github.com/huggingface/transformers.git@026a173a64372e9602a16523b8fae9de4b0ff428
 tokenizers==0.19.1
-bitsandbytes==0.43.1
+bitsandbytes==0.43.3
 accelerate==0.32.0
 deepspeed==0.14.4
 pydantic==2.6.3
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 5966d59313..a05ee84e97 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -40,7 +40,7 @@
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.models import load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import prepare_optim_env
+from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
 from axolotl.utils.wandb_ import setup_wandb_env_vars
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -382,6 +382,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 
     prepare_optim_env(cfg)
 
+    prepare_opinionated_env(cfg)
+
     normalize_config(cfg)
 
     normalize_cfg_datasets(cfg)
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index ff4804b104..cf2866d81d 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1243,7 +1243,9 @@ def build(self, total_num_steps):
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
             if self.cfg.fsdp_config:
-                training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
+                training_arguments_kwargs["fsdp_config"] = {
+                    k.removeprefix("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
+                }
 
         if self.cfg.adapter == "qlora":
             training_arguments_kwargs["qlora"] = True
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index e92c794859..3b9dbb1a1c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -235,6 +235,12 @@ class LoraConfig(BaseModel):
     peft_use_rslora: Optional[bool] = None
     peft_layer_replication: Optional[List[Tuple[int, int]]] = None
 
+    qlora_sharded_model_loading: Optional[bool] = Field(
+        default=False,
+        metadata={
+            "help": "load qlora model in sharded format for FSDP using answer.ai technique."
+        },
+    )
     lora_on_cpu: Optional[bool] = None
     gptq: Optional[bool] = None
     bnb_config_kwargs: Optional[Dict[str, Any]] = None
@@ -939,6 +945,8 @@ def check_evals(cls, data):
     @model_validator(mode="before")
     @classmethod
     def check_eval_packing(cls, data):
+        # TODO also should check test_datasets and val_set_size as we can skip
+        # if there are no eval datasets/splits
         if (
             data.get("sample_packing")
             and data.get("eval_table_size")
diff --git a/src/axolotl/utils/model_shard_quant.py b/src/axolotl/utils/model_shard_quant.py
index 65f23b9e0f..9ed7ae471d 100644
--- a/src/axolotl/utils/model_shard_quant.py
+++ b/src/axolotl/utils/model_shard_quant.py
@@ -13,6 +13,7 @@
 from torch import Tensor, nn
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
+from transformers.quantizers import AutoHfQuantizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
 
 
@@ -173,6 +174,7 @@ def load_sharded_model_quant(
     low_memory=True,
     verbose=False,
     loading_workers=2,
+    quantization_config=None,
 ):
     with init_empty_weights():
         model = AutoModelForCausalLM.from_config(
@@ -186,15 +188,26 @@ def load_sharded_model_quant(
                 compute_dtype=compute_dtype,
                 quant_type="nf4",
                 quant_storage=quant_storage,
+                compress_statistics=True,  # bnb_4bit_use_double_quant
+                skip_modules=[
+                    "lm_head",
+                    "embed_out",
+                ],
             )
         else:
             # this is the more common case with HF transformers
+            # TODO can we detect the model arch and dynamically set skip_modules
             model.model = _replace_linear(
                 model.model,
                 Linear4bit,
                 compute_dtype=compute_dtype,
                 quant_type="nf4",
                 quant_storage=quant_storage,
+                compress_statistics=True,  # bnb_4bit_use_double_quant
+                skip_modules=[
+                    "lm_head",
+                    "embed_out",
+                ],
             )
 
     model.is_loaded_in_4bit = True
@@ -251,6 +264,11 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
         quant_method=quant_method,
     )
 
+    # these attributes are needed to inform transformers/peft of the quantization
+    model.is_quantized = True
+    model.quantization_method = "bitsandbytes"
+    model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
+
    if cfg.local_rank == 0 and verbose:
        print(f"Loaded model weights in {time.time()-start:.3f} seconds")
    # cleanup any extra memory usage from parallel loading
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8a50631ef8..f65da71d44 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -624,14 +624,21 @@ def load_model(
     elif (
         qlora_fsdp
         and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        and cfg.model_config_type == "dbrx"
+        and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
     ):
         quant_storage = cfg.torch_dtype
+        quantization_config = hasattr(
+            model_config, "quantization_config"
+        ) and getattr(model_config, "quantization_config")
+        quantization_config = (
+            quantization_config or model_kwargs["quantization_config"]
+        )
         model = load_sharded_model_quant(
             base_model,
             model_config,
             cfg,
             quant_storage=quant_storage,
+            quantization_config=quantization_config,
         )
         skip_move_to_device = True
     elif (
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index bb96240514..7a9cf2fbbd 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -393,10 +393,6 @@ def calc_sample_packing_eff_est(estimates: List[float]):
 def setup_deepspeed_env(cfg, stage=None):
     os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
     os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
-    if cfg.bf16:
-        os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
-    elif cfg.fp16:
-        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
     if stage:
         os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
         if stage == 3:
@@ -444,6 +440,12 @@ def prepare_optim_env(cfg):
         os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
 
 
+def prepare_opinionated_env(cfg):
+    if cfg.qlora_sharded_model_loading:
+        # model loading is forked after the tokenizer
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
     if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
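
Below is an illustrative axolotl YAML sketch, not applied by the patch itself, showing how the new RAM-efficient QLoRA+FSDP load path would be enabled end to end. The keys base_model, load_in_4bit, adapter, qlora_sharded_model_loading, and fsdp_cpu_ram_efficient_loading come from this diff; the remaining fsdp/fsdp_config keys are assumed to follow axolotl's existing FSDP examples and may need adjusting for your axolotl version.

    # pre-quantized NF4/BF16 weights, as in the updated 405B example above
    base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
    load_in_4bit: true
    adapter: qlora
    # new option added in LoraConfig above: answer.ai-style sharded 4-bit loading
    qlora_sharded_model_loading: true

    fsdp:
      - full_shard
      - auto_wrap
    fsdp_config:
      # required by the load_model() branch added above
      fsdp_cpu_ram_efficient_loading: true
      fsdp_state_dict_type: FULL_STATE_DICT
      fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer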