Add LR scheduler and Fix Granite Checkpoint Loading (#80)
This PR adds the LR scheduler configuration and fixes the Granite checkpoint loading issue in multi-node settings.

This PR subsumes the changes in @aldo-pareja's lr_scheduler branch, so a separate merge of that branch is not necessary.
The strategy for converting the Granite checkpoint is now that each node elects one local rank to perform the conversion.
The conversion is done under args.output_dir. No assumption is made about whether that directory is shared across nodes; since it might be, each node's temporary directory gets a .{group_rank} suffix (see torchrun's GROUP_RANK) so that nodes cannot collide.
Tested on multiple nodes.
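
As a condensed sketch of that strategy (the actual implementation is in the utils.py diff below), the per-node temporary directory and the elected converter can both be derived from environment variables that torchrun sets for every worker; the output_dir value here is only a placeholder:

```python
import os
from pathlib import Path

# torchrun exports these for every worker process:
#   LOCAL_RANK - rank of the process within its node
#   GROUP_RANK - index of the node within the job
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
group_rank = int(os.environ.get("GROUP_RANK", "0"))

output_dir = "/path/to/output_dir"  # stands in for args.output_dir

# one conversion directory per node, so nodes that share output_dir cannot collide
tmpdir = Path(output_dir) / f"tmp.{group_rank}"

# each node elects its local rank 0 to perform the conversion
is_converter = local_rank == 0
print(tmpdir, is_converter)
```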

---------

Signed-off-by: aldo pareja-cardona <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Co-authored-by: aldo pareja-cardona <[email protected]>
fabianlim and aldo-pareja authored Jun 28, 2024
1 parent afb251d commit 0d88f30
Showing 2 changed files with 58 additions and 8 deletions.
30 changes: 28 additions & 2 deletions src/instructlab/training/main_ds.py
@@ -95,7 +95,9 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
)

if args.is_granite:
with ensure_loadable_granite_checkpoint(args.model_name_or_path) as path:
with ensure_loadable_granite_checkpoint(
args.model_name_or_path, args.output_dir
) as path:
model = GPTDolomiteForCausalLM.from_pretrained(
path,
attn_implementation="flash_attention_2",
@@ -231,7 +233,7 @@ def make_inputs_require_grad(module, input, output):
)

lr_scheduler = get_scheduler(
name="cosine",
name=args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.num_epochs * len(train_loader),
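
For reference, a minimal standalone sketch of what the new --lr_scheduler value feeds into (toy model and step counts, not taken from this repository): transformers.get_scheduler wraps the optimizer in a warmup-then-decay schedule.

```python
import torch
from transformers import get_scheduler

model = torch.nn.Linear(4, 4)  # toy model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

scheduler = get_scheduler(
    name="cosine",           # any of the names exposed by --lr_scheduler
    optimizer=optimizer,
    num_warmup_steps=10,     # LR ramps linearly from 0 up to 1e-4 over 10 steps
    num_training_steps=100,  # then follows a cosine decay toward 0
)

for step in range(100):
    optimizer.step()
    scheduler.step()
    if step % 25 == 0:
        print(step, scheduler.get_last_lr())
```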
@@ -447,6 +449,13 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
if local_rank == 0:
inner_pb.update(1)
torch.cuda.empty_cache()
if args.save_last:
save_hf_format_ds(
args,
model,
tokenizer,
global_step * args.samples_per_gpu * world_size,
)
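
The sample count passed to this final save is global_step * samples_per_gpu * world_size, as in the lines above; a quick worked example with illustrative numbers:

```python
# e.g. 2 nodes x 8 GPUs, 8 samples per GPU per optimizer step, 1000 steps
world_size = 16
samples_per_gpu = 8
global_step = 1000

samples_seen = global_step * samples_per_gpu * world_size
print(samples_seen)  # 128000
```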


def main(args):
@@ -651,6 +660,20 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
# parser.add_argument("--samples_per_gpu", type=int, default=8)
parser.add_argument("--effective_batch_size", type=int, default=3840)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument(
"--lr_scheduler",
type=str,
default="cosine",
help="The scheduler type to use.",
choices=[
"linear",
"cosine",
"cosine_with_restarts",
"polynomial",
"constant",
"constant_with_warmup",
],
)
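
These choices mirror the scheduler names accepted by transformers.get_scheduler; assuming a transformers release where those names are backed by the SchedulerType enum (true for recent versions), the full set supported by an installed version can be listed with:

```python
from transformers import SchedulerType

# every scheduler name that get_scheduler() will accept in this install
print([s.value for s in SchedulerType])
```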
parser.add_argument("--num_warmup_steps", type=int, default=1000)
# parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--save_samples", type=int)
@@ -660,6 +683,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
help="for saving in ds native format",
default=None,
)
parser.add_argument(
"--save_last", action="store_true", help="save after finishing training"
)
parser.add_argument("--log_level", type=str, default="INFO")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--mock_data", action="store_true")
36 changes: 30 additions & 6 deletions src/instructlab/training/utils.py
@@ -2,7 +2,7 @@
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory, mktemp
from tempfile import TemporaryDirectory
from typing import Any, List, Optional
import importlib
import inspect
@@ -482,7 +482,13 @@ class UniversalCheckpointArgs:


@contextmanager
def ensure_loadable_granite_checkpoint(model_name_or_path: str):
def ensure_loadable_granite_checkpoint(
model_name_or_path: str,
tmpdir: str,
):
local_rank = int(os.environ["LOCAL_RANK"])
group_rank = int(os.environ["GROUP_RANK"])

try:
GPTDolomiteConfig.from_pretrained(model_name_or_path)
yield model_name_or_path
@@ -493,15 +499,33 @@ def ensure_loadable_granite_checkpoint(model_name_or_path: str):
)
# if the load failed then it must not be a granite checkpoint
# for now just assume it's a llama
# with TemporaryDirectory("w") as tmpdir:
# make a temp directory name, but do not create it
tmpdir = mktemp()
if not dist.is_initialized() or dist.get_rank() == 0:
# previously we used mktemp, but it caused problems in multi node settings
# so now we use a provided tmpdir
# Assumption: tmpdir should be accessible by all ranks, even those
# in different nodes
tmpdir = Path(tmpdir) / f"tmp.{group_rank}"
if os.path.exists(tmpdir) and (not dist.is_initialized() or local_rank == 0):
# need to delete it if it exists because the import doesn't like it to already exist
shutil.rmtree(tmpdir, ignore_errors=True)

if not dist.is_initialized() or local_rank == 0:
import_from_huggingface(model_name_or_path, tmpdir)

if dist.is_initialized():
# the first barrier is to wait for local rank 0 to finish converting the model
# and place into tmpdir
dist.barrier()

# return tmpdir out for loading
yield tmpdir
if not dist.is_initialized() or dist.get_rank() == 0:

if dist.is_initialized():
# the second barrier is to wait for all the models to finish loading
dist.barrier()

if not dist.is_initialized() or local_rank == 0:
# at this point, we can be confident that the tmpdir is no longer needed
shutil.rmtree(tmpdir, ignore_errors=True)


