
Commit eac9c2a
WIP: Enable qlora with deepspeed
chiragjn committed Dec 12, 2023 (1 parent: 3c81ed0)
Showing 4 changed files with 212 additions and 122 deletions.
99 changes: 43 additions & 56 deletions checkpoint_utils.py
@@ -5,9 +5,9 @@
import tempfile
from typing import Optional, Union

from accelerate.state import AcceleratorState
from transformers.trainer_utils import get_last_checkpoint

from dist_utils import DistributedState
from mlfoundry_utils import (
download_mlfoundry_artifact,
get_checkpoint_artifact_version_with_step_or_none,
@@ -104,69 +104,56 @@ def get_best_checkpoint_for_resume_if_any(


def get_last_checkpoint_for_resume_if_any(
cache_dir,
output_dir,
resume_from_checkpoint: Optional[Union[bool, str]],
mlfoundry_enable_reporting: bool,
mlfoundry_ml_repo: Optional[str],
mlfoundry_checkpoint_artifact_name: Optional[str],
) -> Optional[str]:
accelerator_s = AcceleratorState()
last_checkpoint_info_path = os.path.join(cache_dir, "last_checkpoint_info.json")
last_checkpoint_dir = None
if accelerator_s.is_main_process:
check_mlfoundry = False
# resume_from_checkpoint can be None/true/false/string, None is default
if resume_from_checkpoint is None:
# If no explicit choice has been made we will try and check with mlfoundry we are allowed to
check_mlfoundry = False
# resume_from_checkpoint can be None/true/false/string, None is default
if resume_from_checkpoint is None:
# If no explicit choice has been made we will try and check with mlfoundry we are allowed to
check_mlfoundry = True
elif isinstance(resume_from_checkpoint, str):
# If an explicit choice has been made we will check if the checkpoint exists on disk
if os.path.exists(resume_from_checkpoint):
last_checkpoint_dir = resume_from_checkpoint
else:
raise ValueError(f"Provided path for --resume_from_checkpoint `{resume_from_checkpoint}` does not exist!")
# TODO (chiragjn): Add support for resuming from an already saved checkpoint outside of the job run
# Although this is risky, because all other args (model, data, state) should remain same for a "correct" resume
# Note: Instead if we just want to resume from last checkpoint of the same job run then just use --mlfoundry_enable_reporting true --mlfoundry_checkpoint_artifact_name <name>
# elif _is_mlfoundry_artifact(training_arguments.resume_from_checkpoint):
# _download_mlfoundry_artifact(...)
elif resume_from_checkpoint is True:
# If set to true, we will automatically locate the latest checkpoint, first checking output dir, next mlfoundry if we are allowed to
if os.path.exists(output_dir):
possible_last_checkpoint_dir = get_last_checkpoint(output_dir)
if possible_last_checkpoint_dir:
last_checkpoint_dir = possible_last_checkpoint_dir

if not last_checkpoint_dir:
check_mlfoundry = True
elif isinstance(resume_from_checkpoint, str):
# If an explicit choice has been made we will check if the checkpoint exists on disk
if os.path.exists(resume_from_checkpoint):
last_checkpoint_dir = resume_from_checkpoint
else:
raise ValueError(
f"Provided path for --resume_from_checkpoint `{resume_from_checkpoint}` does not exist!"
)
# TODO (chiragjn): Add support for resuming from an already saved checkpoint outside of the job run
# Although this is risky, because all other args (model, data, state) should remain same for a "correct" resume
# Note: Instead if we just want to resume from last checkpoint of the same job run then just use --mlfoundry_enable_reporting true --mlfoundry_checkpoint_artifact_name <name>
# elif _is_mlfoundry_artifact(training_arguments.resume_from_checkpoint):
# _download_mlfoundry_artifact(...)
elif resume_from_checkpoint is True:
# If set to true, we will automatically locate the latest checkpoint, first checking output dir, next mlfoundry if we are allowed to
if os.path.exists(output_dir):
possible_last_checkpoint_dir = get_last_checkpoint(output_dir)
if possible_last_checkpoint_dir:
last_checkpoint_dir = possible_last_checkpoint_dir

if not last_checkpoint_dir:
check_mlfoundry = True

if check_mlfoundry and mlfoundry_enable_reporting and mlfoundry_checkpoint_artifact_name:
logger.info("Checking for any past checkpoints from same job run...")
last_checkpoint_dir = download_last_checkpoint_if_present(
ml_repo=mlfoundry_ml_repo,
checkpoint_artifact_name=mlfoundry_checkpoint_artifact_name,
local_dir=output_dir,
)

with open(last_checkpoint_info_path, "w") as f:
last_checkpoint_info = {"last_checkpoint_dir": last_checkpoint_dir}
json.dump(last_checkpoint_info, f)

if last_checkpoint_dir:
_ = get_best_checkpoint_for_resume_if_any(
output_dir=output_dir,
last_checkpoint_dir=last_checkpoint_dir,
mlfoundry_enable_reporting=mlfoundry_enable_reporting,
mlfoundry_ml_repo=mlfoundry_ml_repo,
mlfoundry_checkpoint_artifact_name=mlfoundry_checkpoint_artifact_name,
)
else:
with open(last_checkpoint_info_path, "r") as f:
last_checkpoint_info = json.load(f)
last_checkpoint_dir = last_checkpoint_info["last_checkpoint_dir"]

if check_mlfoundry and mlfoundry_enable_reporting and mlfoundry_checkpoint_artifact_name:
logger.info("Checking for any past checkpoints from same job run...")
last_checkpoint_dir = download_last_checkpoint_if_present(
ml_repo=mlfoundry_ml_repo,
checkpoint_artifact_name=mlfoundry_checkpoint_artifact_name,
local_dir=output_dir,
)

if last_checkpoint_dir:
_ = get_best_checkpoint_for_resume_if_any(
output_dir=output_dir,
last_checkpoint_dir=last_checkpoint_dir,
mlfoundry_enable_reporting=mlfoundry_enable_reporting,
mlfoundry_ml_repo=mlfoundry_ml_repo,
mlfoundry_checkpoint_artifact_name=mlfoundry_checkpoint_artifact_name,
)

return last_checkpoint_dir
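
For context, a minimal sketch of how the simplified helper could be driven from a training entrypoint after this change, assuming all ranks share a filesystem and that the cache_dir parameter was dropped along with the last_checkpoint_info.json handling; the DistributedState wiring and the argument values below are illustrative and not part of this commit:

import os

from checkpoint_utils import get_last_checkpoint_for_resume_if_any
from dist_utils import DistributedState

dist_state = DistributedState(
    world_size=int(os.getenv("WORLD_SIZE", "1")),
    local_rank=int(os.getenv("LOCAL_RANK", "-1")),
)
# The main process resolves (and possibly downloads) the checkpoint first;
# the other ranks block on the barrier, then run the same lookup and find
# the files already present in output_dir.
with dist_state.main_process_first():
    last_checkpoint_dir = get_last_checkpoint_for_resume_if_any(
        output_dir="./output",
        resume_from_checkpoint=None,
        mlfoundry_enable_reporting=False,
        mlfoundry_ml_repo=None,
        mlfoundry_checkpoint_artifact_name=None,
    )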


40 changes: 13 additions & 27 deletions data_utils.py
@@ -7,7 +7,6 @@
from urllib.parse import parse_qsl, urlparse

import torch
from accelerate.state import AcceleratorState
from cloudfiles import CloudFile
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
@@ -249,30 +248,17 @@ def build_dataset(
tokenizer,
max_length: int,
train_on_prompt: bool,
cache_dir: str,
):
accelerator_s = AcceleratorState()
logger.info("Building dataset...")
dataset_cache_path = os.path.join(cache_dir, "dataset")
if accelerator_s.is_main_process:
builder = CausalDatasetBuilder(tokenizer=tokenizer, max_length=max_length, train_on_prompt=train_on_prompt)
dataset_dict = DatasetDict(train=Dataset.from_list(train_data), eval=Dataset.from_list(eval_data))
# TODO (chiragjn): Read cpu limits from cgroup, cpu_count is not usable in containers environment
num_proc = max(1, min(4, os.cpu_count()))
num_proc = num_proc if num_proc > 1 else None
dataset_dict = dataset_dict.map(
builder.construct_dataset,
remove_columns=[PROMPT_KEY, COMPLETION_KEY],
batched=True,
batch_size=32,
num_proc=num_proc,
)
dataset_dict.save_to_disk(dataset_cache_path)
else:
logger.info("Loading datasets from cache ...")
dataset_dict = DatasetDict.load_from_disk(dataset_cache_path)
dataset_dict = dataset_dict.with_format("torch")
train_dataset, eval_dataset = dataset_dict["train"], dataset_dict["eval"]
logger.info(f"Train data size: {len(train_dataset)}")
logger.info(f"Eval data size: {len(eval_dataset)}")
return train_dataset, eval_dataset
builder = CausalDatasetBuilder(tokenizer=tokenizer, max_length=max_length, train_on_prompt=train_on_prompt)
dataset_dict = DatasetDict(train=Dataset.from_list(train_data), eval=Dataset.from_list(eval_data))
# TODO (chiragjn): Read cpu limits from cgroup, cpu_count is not usable in containers environment
num_proc = max(1, min(4, os.cpu_count()))
num_proc = num_proc if num_proc > 1 else None
dataset_dict = dataset_dict.map(
builder.construct_dataset,
remove_columns=[PROMPT_KEY, COMPLETION_KEY],
batched=True,
batch_size=32,
num_proc=num_proc,
)
return dataset_dict
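
Since build_dataset now returns the mapped DatasetDict directly, the torch formatting and the train/eval split presumably move to the call site. A small sketch of that consumption, using placeholder data and a placeholder tokenizer; the keyword names mirror the signature shown above but are otherwise assumptions:

from transformers import AutoTokenizer

from data_utils import COMPLETION_KEY, PROMPT_KEY, build_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
train_data = [{PROMPT_KEY: "Question: 2 + 2 = ?\n", COMPLETION_KEY: "4"}]
eval_data = [{PROMPT_KEY: "Question: 3 + 3 = ?\n", COMPLETION_KEY: "6"}]

dataset_dict = build_dataset(
    train_data=train_data,
    eval_data=eval_data,
    tokenizer=tokenizer,
    max_length=512,
    train_on_prompt=False,
)
# The formatting and splitting that used to live inside build_dataset now
# happen wherever the datasets are consumed.
dataset_dict = dataset_dict.with_format("torch")
train_dataset, eval_dataset = dataset_dict["train"], dataset_dict["eval"]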
36 changes: 36 additions & 0 deletions dist_utils.py
@@ -0,0 +1,36 @@
import contextlib
import logging

import torch.distributed

logger = logging.getLogger("truefoundry-finetune")


class DistributedState:
def __init__(self, world_size: int, local_rank: int):
self.world_size = world_size
self.local_rank = local_rank

@property
def is_distributed(self):
return self.world_size > 1

@property
def is_main_process(self):
return self.local_rank <= 0

@contextlib.contextmanager
def main_process_first(self):
if self.is_distributed:
if not self.is_main_process:
torch.distributed.barrier()
yield
if self.is_main_process:
logger.info("Getting other ranks in sync with main process")
torch.distributed.barrier()
else:
yield

def wait_for_everyone(self):
if self.is_distributed:
torch.distributed.barrier()
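
A small usage sketch for the new DistributedState helper, assuming torchrun (or a similar launcher) sets WORLD_SIZE and LOCAL_RANK and that the process group is initialized before the context manager is entered; the cache-building helper is hypothetical:

import os

import torch.distributed

from dist_utils import DistributedState

world_size = int(os.getenv("WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "-1"))
if world_size > 1 and not torch.distributed.is_initialized():
    # gloo shown for a CPU-friendly sketch; GPU training would typically use nccl
    torch.distributed.init_process_group(backend="gloo")

dist_state = DistributedState(world_size=world_size, local_rank=local_rank)

with dist_state.main_process_first():
    # Rank 0 runs the body first; the other ranks block on the first barrier
    # until rank 0 is done, then run the same body against the warm cache.
    if not os.path.exists("./cache/tokenized"):
        build_and_save_tokenized_dataset("./cache/tokenized")  # hypothetical helper

# Rank 0 does not wait for the others inside the context manager, so add an
# explicit barrier if later steps need every rank to have finished.
dist_state.wait_for_everyone()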

0 comments on commit eac9c2a
