From eac9c2aa043a9186f80663ee07ea1b5bc4bcd48b Mon Sep 17 00:00:00 2001
From: Chirag Jain
Date: Tue, 12 Dec 2023 09:18:23 +0000
Subject: [PATCH] WIP: Enable qlora with deepspeed

---
 checkpoint_utils.py |  99 ++++++++++++---------------
 data_utils.py       |  40 ++++-------
 dist_utils.py       |  36 ++++++++++
 train.py            | 159 +++++++++++++++++++++++++++++++++-----------
 4 files changed, 212 insertions(+), 122 deletions(-)
 create mode 100644 dist_utils.py

diff --git a/checkpoint_utils.py b/checkpoint_utils.py
index 4a81312..c870791 100644
--- a/checkpoint_utils.py
+++ b/checkpoint_utils.py
@@ -5,9 +5,9 @@
 import tempfile
 from typing import Optional, Union
 
-from accelerate.state import AcceleratorState
 from transformers.trainer_utils import get_last_checkpoint
 
+from dist_utils import DistributedState
 from mlfoundry_utils import (
     download_mlfoundry_artifact,
     get_checkpoint_artifact_version_with_step_or_none,
@@ -104,69 +104,56 @@ def get_best_checkpoint_for_resume_if_any(
 
 
 def get_last_checkpoint_for_resume_if_any(
-    cache_dir,
     output_dir,
     resume_from_checkpoint: Optional[Union[bool, str]],
     mlfoundry_enable_reporting: bool,
     mlfoundry_ml_repo: Optional[str],
     mlfoundry_checkpoint_artifact_name: Optional[str],
 ) -> Optional[str]:
-    accelerator_s = AcceleratorState()
-    last_checkpoint_info_path = os.path.join(cache_dir, "last_checkpoint_info.json")
     last_checkpoint_dir = None
-    if accelerator_s.is_main_process:
-        check_mlfoundry = False
-        # resume_from_checkpoint can be None/true/false/string, None is default
-        if resume_from_checkpoint is None:
-            # If no explicit choice has been made we will try and check with mlfoundry we are allowed to
+    check_mlfoundry = False
+    # resume_from_checkpoint can be None/true/false/string, None is default
+    if resume_from_checkpoint is None:
+        # If no explicit choice has been made we will try and check with mlfoundry we are allowed to
+        check_mlfoundry = True
+    elif isinstance(resume_from_checkpoint, str):
+        # If an explicit choice has been made we will check if the checkpoint exists on disk
+        if os.path.exists(resume_from_checkpoint):
+            last_checkpoint_dir = resume_from_checkpoint
+        else:
+            raise ValueError(f"Provided path for --resume_from_checkpoint `{resume_from_checkpoint}` does not exist!")
+    # TODO (chiragjn): Add support for resuming from an already saved checkpoint outside of the job run
+    # Although this is risky, because all other args (model, data, state) should remain same for a "correct" resume
+    # Note: Instead if we just want to resume from last checkpoint of the same job run then just use --mlfoundry_enable_reporting true --mlfoundry_checkpoint_artifact_name
+    # elif _is_mlfoundry_artifact(training_arguments.resume_from_checkpoint):
+    #     _download_mlfoundry_artifact(...)
+    elif resume_from_checkpoint is True:
+        # If set to true, we will automatically locate the latest checkpoint, first checking output dir, next mlfoundry if we are allowed to
+        if os.path.exists(output_dir):
+            possible_last_checkpoint_dir = get_last_checkpoint(output_dir)
+            if possible_last_checkpoint_dir:
+                last_checkpoint_dir = possible_last_checkpoint_dir
+
+        if not last_checkpoint_dir:
             check_mlfoundry = True
-        elif isinstance(resume_from_checkpoint, str):
-            # If an explicit choice has been made we will check if the checkpoint exists on disk
-            if os.path.exists(resume_from_checkpoint):
-                last_checkpoint_dir = resume_from_checkpoint
-            else:
-                raise ValueError(
-                    f"Provided path for --resume_from_checkpoint `{resume_from_checkpoint}` does not exist!"
-                )
-        # TODO (chiragjn): Add support for resuming from an already saved checkpoint outside of the job run
-        # Although this is risky, because all other args (model, data, state) should remain same for a "correct" resume
-        # Note: Instead if we just want to resume from last checkpoint of the same job run then just use --mlfoundry_enable_reporting true --mlfoundry_checkpoint_artifact_name
-        # elif _is_mlfoundry_artifact(training_arguments.resume_from_checkpoint):
-        #     _download_mlfoundry_artifact(...)
-        elif resume_from_checkpoint is True:
-            # If set to true, we will automatically locate the latest checkpoint, first checking output dir, next mlfoundry if we are allowed to
-            if os.path.exists(output_dir):
-                possible_last_checkpoint_dir = get_last_checkpoint(output_dir)
-                if possible_last_checkpoint_dir:
-                    last_checkpoint_dir = possible_last_checkpoint_dir
-
-            if not last_checkpoint_dir:
-                check_mlfoundry = True
-
-        if check_mlfoundry and mlfoundry_enable_reporting and mlfoundry_checkpoint_artifact_name:
-            logger.info("Checking for any past checkpoints from same job run...")
-            last_checkpoint_dir = download_last_checkpoint_if_present(
-                ml_repo=mlfoundry_ml_repo,
-                checkpoint_artifact_name=mlfoundry_checkpoint_artifact_name,
-                local_dir=output_dir,
-            )
-
-        with open(last_checkpoint_info_path, "w") as f:
-            last_checkpoint_info = {"last_checkpoint_dir": last_checkpoint_dir}
-            json.dump(last_checkpoint_info, f)
-
-        if last_checkpoint_dir:
-            _ = get_best_checkpoint_for_resume_if_any(
-                output_dir=output_dir,
-                last_checkpoint_dir=last_checkpoint_dir,
-                mlfoundry_enable_reporting=mlfoundry_enable_reporting,
-                mlfoundry_ml_repo=mlfoundry_ml_repo,
-                mlfoundry_checkpoint_artifact_name=mlfoundry_checkpoint_artifact_name,
-            )
-    else:
-        with open(last_checkpoint_info_path, "r") as f:
-            last_checkpoint_info = json.load(f)
-            last_checkpoint_dir = last_checkpoint_info["last_checkpoint_dir"]
+
+    if check_mlfoundry and mlfoundry_enable_reporting and mlfoundry_checkpoint_artifact_name:
+        logger.info("Checking for any past checkpoints from same job run...")
+        last_checkpoint_dir = download_last_checkpoint_if_present(
+            ml_repo=mlfoundry_ml_repo,
+            checkpoint_artifact_name=mlfoundry_checkpoint_artifact_name,
+            local_dir=output_dir,
+        )
+
+    if last_checkpoint_dir:
+        _ = get_best_checkpoint_for_resume_if_any(
+            output_dir=output_dir,
+            last_checkpoint_dir=last_checkpoint_dir,
+            mlfoundry_enable_reporting=mlfoundry_enable_reporting,
+            mlfoundry_ml_repo=mlfoundry_ml_repo,
+            mlfoundry_checkpoint_artifact_name=mlfoundry_checkpoint_artifact_name,
+        )
+
     return last_checkpoint_dir
diff --git a/data_utils.py b/data_utils.py
index deea7ee..a5053ab 100644
--- a/data_utils.py
+++ b/data_utils.py
@@ -7,7 +7,6 @@
 from urllib.parse import parse_qsl, urlparse
 
 import torch
-from accelerate.state import AcceleratorState
 from cloudfiles import CloudFile
 from datasets import Dataset, DatasetDict
 from sklearn.model_selection import train_test_split
@@ -249,30 +248,17 @@ def build_dataset(
     tokenizer,
     max_length: int,
     train_on_prompt: bool,
-    cache_dir: str,
 ):
-    accelerator_s = AcceleratorState()
-    logger.info("Building dataset...")
-    dataset_cache_path = os.path.join(cache_dir, "dataset")
-    if accelerator_s.is_main_process:
-        builder = CausalDatasetBuilder(tokenizer=tokenizer, max_length=max_length, train_on_prompt=train_on_prompt)
-        dataset_dict = DatasetDict(train=Dataset.from_list(train_data), eval=Dataset.from_list(eval_data))
-        # TODO (chiragjn): Read cpu limits from cgroup, cpu_count is not usable in containers environment
-        num_proc = max(1, min(4, os.cpu_count()))
-        num_proc = num_proc if num_proc > 1 else None
-        dataset_dict = dataset_dict.map(
-            builder.construct_dataset,
-            remove_columns=[PROMPT_KEY, COMPLETION_KEY],
-            batched=True,
-            batch_size=32,
-            num_proc=num_proc,
-        )
-        dataset_dict.save_to_disk(dataset_cache_path)
-    else:
-        logger.info("Loading datasets from cache ...")
-        dataset_dict = DatasetDict.load_from_disk(dataset_cache_path)
-    dataset_dict = dataset_dict.with_format("torch")
-    train_dataset, eval_dataset = dataset_dict["train"], dataset_dict["eval"]
-    logger.info(f"Train data size: {len(train_dataset)}")
-    logger.info(f"Eval data size: {len(eval_dataset)}")
-    return train_dataset, eval_dataset
+    builder = CausalDatasetBuilder(tokenizer=tokenizer, max_length=max_length, train_on_prompt=train_on_prompt)
+    dataset_dict = DatasetDict(train=Dataset.from_list(train_data), eval=Dataset.from_list(eval_data))
+    # TODO (chiragjn): Read cpu limits from cgroup, cpu_count is not usable in containers environment
+    num_proc = max(1, min(4, os.cpu_count()))
+    num_proc = num_proc if num_proc > 1 else None
+    dataset_dict = dataset_dict.map(
+        builder.construct_dataset,
+        remove_columns=[PROMPT_KEY, COMPLETION_KEY],
+        batched=True,
+        batch_size=32,
+        num_proc=num_proc,
+    )
+    return dataset_dict
diff --git a/dist_utils.py b/dist_utils.py
new file mode 100644
index 0000000..c783faa
--- /dev/null
+++ b/dist_utils.py
@@ -0,0 +1,36 @@
+import contextlib
+import logging
+
+import torch.distributed
+
+logger = logging.getLogger("truefoundry-finetune")
+
+
+class DistributedState:
+    def __init__(self, world_size: int, local_rank: int):
+        self.world_size = world_size
+        self.local_rank = local_rank
+
+    @property
+    def is_distributed(self):
+        return self.world_size > 1
+
+    @property
+    def is_main_process(self):
+        return self.local_rank <= 0
+
+    @contextlib.contextmanager
+    def main_process_first(self):
+        if self.is_distributed:
+            if not self.is_main_process:
+                torch.distributed.barrier()
+            yield
+            if self.is_main_process:
+                logger.info("Getting other ranks in sync with main process")
+                torch.distributed.barrier()
+        else:
+            yield
+
+    def wait_for_everyone(self):
+        if self.is_distributed:
+            torch.distributed.barrier()
diff --git a/train.py b/train.py
index 2bebef2..9f8cf73 100644
--- a/train.py
+++ b/train.py
@@ -12,8 +12,8 @@
 import bitsandbytes as bnb
 import mlfoundry
 import torch
-from accelerate import Accelerator, infer_auto_device_map, init_empty_weights
-from accelerate.state import AcceleratorState
+from accelerate import infer_auto_device_map, init_empty_weights
+from datasets import DatasetDict
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from peft import (
     AutoPeftModelForCausalLM,
@@ -40,6 +40,7 @@
 
 from checkpoint_utils import cleanup_checkpoints, get_last_checkpoint_for_resume_if_any
 from data_utils import SequenceDataCollator, build_dataset, get_data
+from dist_utils import DistributedState
 from mlfoundry_utils import MLFoundryCallback, log_model_to_mlfoundry
 
 # TODO (chiragjn):
@@ -203,6 +204,17 @@ class OtherArguments:
     )
 
 
+class HFTrainer(Trainer):
+    def _inner_training_loop(
+        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+    ):
+        # Hack to fix: https://github.com/huggingface/transformers/issues/24558
+        if self.args.auto_find_batch_size:
+            self.model_wrapped = self.model
+            self.deepspeed = None
+        return super()._inner_training_loop(batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
+
+
 def get_torch_dtype(training_arguments: HFTrainingArguments):
     torch_dtype = None
     if training_arguments.bf16:
@@ -381,7 +393,10 @@ def get_model(
     other_arguments: OtherArguments,
     device_map=None,
 ):
-    accelerator_s = AcceleratorState()
+    dist_s = DistributedState(
+        world_size=training_arguments.world_size,
+        local_rank=training_arguments.local_rank,
+    )
     logger.info("Loading model...")
     model_load_kwargs = {}
     model_load_kwargs["use_cache"] = False if training_arguments.gradient_checkpointing else True
@@ -415,12 +430,12 @@ def get_model(
             device_map=device_map,
             **model_load_kwargs,
         )
-        if accelerator_s.is_main_process:
+        if dist_s.is_main_process:
             _log_model_parameters(model)
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=training_arguments.gradient_checkpointing
         )
-        if accelerator_s.is_main_process:
+        if dist_s.is_main_process:
             _log_model_parameters(model)
         # TODO (chiragjn): This is disabled because resuming does not work: https://github.com/TimDettmers/bitsandbytes/issues/782
         # training_arguments.optim = "paged_adamw_32bit"
@@ -445,7 +460,10 @@ def get_peft_wrapped_model(
     _device_map=None,
     _checkpoint_dir: Optional[str] = None,
 ):
-    acclerator_s = AcceleratorState()
+    dist_s = DistributedState(
+        world_size=training_arguments.world_size,
+        local_rank=training_arguments.local_rank,
+    )
     # if _checkpoint_dir:
     #     model = PeftModel.from_pretrained(
     #         model=model,
@@ -502,7 +520,7 @@ def get_peft_wrapped_model(
 
     model.enable_input_require_grads()
     model.print_trainable_parameters()
-    if acclerator_s.is_main_process:
+    if dist_s.is_main_process:
         _log_model_parameters(model)
     return model
 
@@ -576,20 +594,85 @@ def check_if_model_will_fit_only_with_gpus(
     )
 
 
+def dist_get_last_checkpoint_for_resume_if_any(
+    training_arguments: HFTrainingArguments,
+    other_arguments: OtherArguments,
+):
+    last_checkpoint_dir = None
+    dist_s = DistributedState(
+        world_size=training_arguments.world_size,
+        local_rank=training_arguments.local_rank,
+    )
+    last_checkpoint_info_path = os.path.join(CACHE_DIR, "last_checkpoint_info.json")
+    if dist_s.is_main_process:
+        last_checkpoint_dir = get_last_checkpoint_for_resume_if_any(
+            output_dir=training_arguments.output_dir,
+            resume_from_checkpoint=training_arguments.resume_from_checkpoint,
+            mlfoundry_enable_reporting=other_arguments.mlfoundry_enable_reporting,
+            mlfoundry_ml_repo=other_arguments.mlfoundry_ml_repo,
+            mlfoundry_checkpoint_artifact_name=other_arguments.mlfoundry_checkpoint_artifact_name,
+        )
+        with open(last_checkpoint_info_path, "w") as f:
+            last_checkpoint_info = {"last_checkpoint_dir": last_checkpoint_dir}
+            json.dump(last_checkpoint_info, f)
+    else:
+        with open(last_checkpoint_info_path, "r") as f:
+            last_checkpoint_info = json.load(f)
+            last_checkpoint_dir = last_checkpoint_info["last_checkpoint_dir"]
+
+    return last_checkpoint_dir
+
+
+def dist_build_dataset(
+    train_data,
+    eval_data,
+    tokenizer,
+    max_length,
+    train_on_prompt: bool,
+    training_arguments: HFTrainingArguments,
+):
+    dist_s = DistributedState(
+        world_size=training_arguments.world_size,
+        local_rank=training_arguments.local_rank,
+    )
+    logger.info("Building dataset...")
+    dataset_cache_path = os.path.join(CACHE_DIR, "dataset")
+    if dist_s.is_main_process:
+        dataset_dict = build_dataset(
+            train_data=train_data,
+            eval_data=eval_data,
+            tokenizer=tokenizer,
+            max_length=max_length,
+            train_on_prompt=train_on_prompt,
+        )
+        dataset_dict.save_to_disk(dataset_cache_path)
+    else:
+        logger.info("Loading datasets from cache ...")
+        dataset_dict = DatasetDict.load_from_disk(dataset_cache_path)
+    dataset_dict = dataset_dict.with_format("torch")
+    train_dataset, eval_dataset = dataset_dict["train"], dataset_dict["eval"]
+    logger.info(f"Train data size: {len(train_dataset)}")
+    logger.info(f"Eval data size: {len(eval_dataset)}")
+    return train_dataset, eval_dataset
+
+
 def _train(
     *,
     training_arguments: HFTrainingArguments,
     other_arguments: OtherArguments,
     run: Optional[mlfoundry.MlFoundryRun] = None,
 ):
-    accelerator_s = AcceleratorState()
+    dist_s = DistributedState(
+        world_size=training_arguments.world_size,
+        local_rank=training_arguments.local_rank,
+    )
     set_seed(training_arguments.seed)
 
-    if not accelerator_s.is_main_process:
+    if not dist_s.is_main_process:
         logger.info("Waiting for main process to load data, process it and fetch any checkpoints ...")
 
-    with accelerator_s.main_process_first():
-        if accelerator_s.is_main_process:
+    with dist_s.main_process_first():
+        if dist_s.is_main_process:
             train_data, eval_data = get_data(
                 train_data=other_arguments.train_data,
                 eval_data=other_arguments.eval_data,
@@ -600,13 +683,8 @@ def _train(
         else:
             train_data, eval_data = None, None
 
-        last_checkpoint_dir = get_last_checkpoint_for_resume_if_any(
-            cache_dir=CACHE_DIR,
-            output_dir=training_arguments.output_dir,
-            resume_from_checkpoint=training_arguments.resume_from_checkpoint,
-            mlfoundry_enable_reporting=other_arguments.mlfoundry_enable_reporting,
-            mlfoundry_ml_repo=other_arguments.mlfoundry_ml_repo,
-            mlfoundry_checkpoint_artifact_name=other_arguments.mlfoundry_checkpoint_artifact_name,
+        last_checkpoint_dir = dist_get_last_checkpoint_for_resume_if_any(
+            training_arguments=training_arguments, other_arguments=other_arguments
         )
 
         logger.info("Loading config ...")
@@ -625,22 +703,21 @@ def _train(
             model_config=model_config,
         )
 
-        train_dataset, eval_dataset = build_dataset(
+        train_dataset, eval_dataset = dist_build_dataset(
             train_data=train_data,
             eval_data=eval_data,
             tokenizer=tokenizer,
             max_length=max_length,
             train_on_prompt=other_arguments.train_on_prompt,
-            cache_dir=CACHE_DIR,
+            training_arguments=training_arguments,
         )
 
-        if accelerator_s.is_main_process:
-            logger.info("Getting other ranks in sync with main process")
-
     no_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
     device_map = None
-    if not training_arguments.deepspeed:
-        if other_arguments.use_ddp and no_of_gpus > 1:
+    if training_arguments.deepspeed:
+        device_map = None
+    elif other_arguments.use_ddp:
+        if no_of_gpus > 1:
             device_map = {"": "cuda:" + str(training_arguments.local_rank)}
         else:
             device_map = "auto"
@@ -700,7 +777,7 @@ def _train(
                 early_stopping_threshold=other_arguments.early_stopping_threshold,
             )
         )
-    trainer = Trainer(
+    trainer = HFTrainer(
         model=model,
         tokenizer=tokenizer,
         train_dataset=train_dataset,
@@ -712,7 +789,7 @@ def _train(
 
     trainer.train(resume_from_checkpoint=last_checkpoint_dir)
 
-    accelerator_s.wait_for_everyone()
+    dist_s.wait_for_everyone()
 
     logger.info("Saving model...")
     if training_arguments.deepspeed and is_deepspeed_zero3_enabled() and EXPORT_ZERO3_CHECKPOINT_TO_FP32:
@@ -722,29 +799,33 @@ def _train(
         # then an additional pytorch_model.bin is saved as a 16-bit checkpoint
        # if we want fp32 pytorch_model.bin then we would have to export separately from the checkpoint in zero format
         trainer.save_model(output_dir=training_arguments.output_dir)
-        if accelerator_s.is_main_process:
+        if dist_s.is_main_process:
             fp32_weights_path = os.path.join(training_arguments.output_dir, WEIGHTS_NAME)
             convert_zero_checkpoint_to_fp32_state_dict(trainer.state.best_model_checkpoint, fp32_weights_path)
             cleanup_checkpoints(output_dir=training_arguments.output_dir)
     else:
-        if accelerator_s.is_main_process:
+        if dist_s.is_main_process:
             cleanup_checkpoints(output_dir=training_arguments.output_dir)
         trainer.save_model(output_dir=training_arguments.output_dir)
 
-    accelerator_s.wait_for_everyone()
+    dist_s.wait_for_everyone()
 
 
 def train(training_arguments: HFTrainingArguments, other_arguments: OtherArguments):
-    Accelerator()
-    accelerator_s = AcceleratorState()
+    dist_s = DistributedState(
+        world_size=training_arguments.world_size,
+        local_rank=training_arguments.local_rank,
+    )
     timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
     logger.info(f"Training Arguments: {training_arguments}")
     logger.info(f"Arguments: {other_arguments}")
 
-    if other_arguments.use_lora or other_arguments.use_qlora:
+    if other_arguments.use_qlora:
         if not torch.cuda.is_available() or torch.cuda.device_count() < 1:
             raise RuntimeError("No GPUs detected. We need at least one gpu available for Lora/QLora finetuning!")
-        # TODO (chiragjn): Support LoRA and QLoRA with deepspeed
+
+    if other_arguments.use_lora:
+        # TODO (chiragjn): Support LoRA with deepspeed
         if training_arguments.deepspeed:
             raise ValueError(
                 "deepspeed is currently not supported with lora/qlora fine-tuning please try fine-tuning without deepspeed"
@@ -753,7 +834,7 @@ def train(training_arguments: HFTrainingArguments, other_arguments: OtherArgumen
     setup(training_arguments=training_arguments, other_arguments=other_arguments)
 
     run = None
-    if accelerator_s.is_main_process and other_arguments.mlfoundry_enable_reporting:
+    if dist_s.is_main_process and other_arguments.mlfoundry_enable_reporting:
         mlfoundry_client = mlfoundry.get_client()
         if not other_arguments.mlfoundry_run_name:
             fallback_run_name = f"finetune-{timestamp}"
@@ -786,12 +867,12 @@ def train(training_arguments: HFTrainingArguments, other_arguments: OtherArgumen
         # run.log_params(training_arguments.to_sanitized_dict(), flatten_params=True)
 
     # Disk space management
-    if accelerator_s.is_main_process:
+    if dist_s.is_main_process:
         if other_arguments.cleanup_output_dir_on_start and os.path.exists(training_arguments.output_dir):
             logger.warning(f"--cleanup_output_dir_on_start was to set to True, wiping {training_arguments.output_dir}")
             shutil.rmtree(training_arguments.output_dir)
 
-    if accelerator_s.is_main_process:
+    if dist_s.is_main_process:
         if other_arguments.use_lora or other_arguments.use_qlora:
             check_if_model_will_fit_only_with_gpus(
                 training_arguments=training_arguments, other_arguments=other_arguments
@@ -804,11 +885,11 @@ def train(training_arguments: HFTrainingArguments, other_arguments: OtherArgumen
     )
     _cleanup_gpus()
 
-    if accelerator_s.is_main_process:
+    if dist_s.is_main_process:
         if other_arguments.use_lora or other_arguments.use_qlora:
             merge_adapters_if_any(training_arguments=training_arguments, other_arguments=other_arguments)
 
-    if accelerator_s.is_main_process and run:
+    if dist_s.is_main_process and run:
         *_, model_name = other_arguments.model_id.rsplit("/", 1)
         model_name = "-".join(["finetuned", model_name, timestamp])
         model_name = model_name.replace(".", "-")
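
Illustrative sketch, not part of the patch: the DistributedState helper added in dist_utils.py stands in for accelerate's AcceleratorState, and the new dist_get_last_checkpoint_for_resume_if_any / dist_build_dataset helpers depend on its main_process_first() ordering (the main process does the work and writes a cache under CACHE_DIR, the other ranks wait at a barrier and then read it). The sketch below shows that pattern under an assumed torchrun/deepspeed-style launch that sets WORLD_SIZE and LOCAL_RANK and provides the process-group rendezvous; the cache path is hypothetical.

import json
import os

import torch.distributed

from dist_utils import DistributedState

# Assumption: launched via torchrun/deepspeed, which set WORLD_SIZE, LOCAL_RANK
# and the rendezvous variables consumed by init_process_group.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if world_size > 1 and not torch.distributed.is_initialized():
    torch.distributed.init_process_group(backend="nccl")

dist_s = DistributedState(world_size=world_size, local_rank=local_rank)
cache_path = "/tmp/shared_cache.json"  # hypothetical node-local cache path

# Non-main ranks block at a barrier inside main_process_first(); the main process
# runs the body first, and the barrier it hits on exit releases the other ranks
# to run the same body afterwards.
with dist_s.main_process_first():
    if dist_s.is_main_process:
        with open(cache_path, "w") as f:
            json.dump({"status": "ready"}, f)
    else:
        with open(cache_path) as f:
            cache = json.load(f)

dist_s.wait_for_everyone()  # final barrier so every rank leaves this section together

This mirrors how _train now prepares data: the main process writes last_checkpoint_info.json and the dataset cache, and the remaining ranks load them only after the barrier releases them.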