From f223f9d2f409d347eec46c81dc29c001744197bc Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 14 May 2024 21:44:11 -0700 Subject: [PATCH] Add Pipeline Parallel (and 2D PP+FSDP) support runs PP+DP and PP+TP without issue, runs PP+TP+DP with decreasing loss, but fails DCP save Supports only simple schedules currently, gpipe and 1f1b. Ads cmdline/toml arg for specifiying split points, in a unified way between tracer or manual frontend. e.g. user can specifiy "layers.2,layers.4" as split points. Currently uses manual frontend by default, but allows specifying tracer frontend. Tracer frontend requires working around additional compatibility limitations, indicated by raising assertions, and is not ready for wider use yet. ghstack-source-id: 7a1b6ea024726bc7bf2430854c8088b77ff4e29e Pull Request resolved: https://github.com/pytorch/torchtitan/pull/318 --- .github/workflows/unit_test_4gpu.yaml | 3 +- create_seed_checkpoint.sh | 2 +- test_runner.py | 97 ++++++++- torchtitan/config_manager.py | 58 +++++- torchtitan/parallelisms/__init__.py | 6 +- torchtitan/parallelisms/parallelize_llama.py | 200 ++++++++++++++++++- torchtitan/parallelisms/pipelining_utils.py | 51 +++++ train.py | 79 ++++++-- train_configs/debug_model.toml | 4 +- 9 files changed, 464 insertions(+), 36 deletions(-) create mode 100644 torchtitan/parallelisms/pipelining_utils.py diff --git a/.github/workflows/unit_test_4gpu.yaml b/.github/workflows/unit_test_4gpu.yaml index b62dbe73..6fbbcc2c 100644 --- a/.github/workflows/unit_test_4gpu.yaml +++ b/.github/workflows/unit_test_4gpu.yaml @@ -25,7 +25,8 @@ jobs: python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 python -m pip install -r requirements.txt python -m pip install -r dev-requirements.txt - mkdir artifacts-to-be-uploaded + python -m pip install git + python -m pip install git+https://github.com/pytorch/pippy python ./test_runner.py artifacts-to-be-uploaded # upload-coverage: # - name: Upload Coverage to Codecov diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh index 38bab219..1abc77ec 100755 --- a/create_seed_checkpoint.sh +++ b/create_seed_checkpoint.sh @@ -25,7 +25,7 @@ LOG_RANK=0 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} seed_checkpoint="--checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint" -force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --training.pipeline_parallel_degree 1" +force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1" overrides="" if [ $# -ne 0 ]; then overrides="$*" diff --git a/test_runner.py b/test_runner.py index 03b0a485..c1b729d6 100755 --- a/test_runner.py +++ b/test_runner.py @@ -30,6 +30,8 @@ class OverrideDefinitions: override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) test_descr: str = "default" + requires_seed_checkpoint: bool = False + ngpu: int = 4 CONFIG_DIR = "./train_configs" @@ -88,13 +90,83 @@ class OverrideDefinitions: ], "Checkpoint Integration Test - Save Model Weights Only bf16", ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--training.data_parallel_degree 1", + "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue + ], + ], + "PP 1D test", + requires_seed_checkpoint=True, + ngpu=2, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_dp/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--training.data_parallel_degree 2", + "--model.norm_type fused_rmsnorm", + ], + ], + "PP+DP 2D test", + requires_seed_checkpoint=True, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_tp/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--training.tensor_parallel_degree 2", + "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue + ], + ], + "PP+TP 2D test", + requires_seed_checkpoint=True, + ), + # oh.. not enough GPUs? + # OverrideDefinitions( + # [ + # [ + # "--checkpoint.enable_checkpoint", + # f"--job.dump_folder {args.output_dir}/pp_dp_tp/", + # "--experimental.pipeline_parallel_degree 2", + # "--experimental.pipeline_parallel_split_points layers.1", + # "--training.data_parallel_degree 2", + # "--training.tensor_parallel_degree 2", + # "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue + # ], + # ], + # "PP+DP+TP 3D test", + # requires_seed_checkpoint=True, + # ), ] +def _run_cmd(cmd): + return subprocess.run( + [cmd], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + shell=True, + ) + + def run_test(test_flavor: OverrideDefinitions, full_path: str): # run_test supports sequence of tests. for override_arg in test_flavor.override_args: - cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh" + + cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh" if override_arg: cmd += ( " " + " ".join(override_arg) + f" --job.dump_folder {args.output_dir}" @@ -102,13 +174,22 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str): print( f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" ) - result = subprocess.run( - [cmd], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - shell=True, - ) + + if test_flavor.requires_seed_checkpoint: + dump_folder_arg = None + for arg in override_arg: + if "--job.dump_folder" in arg: + dump_folder_arg = arg + assert ( + dump_folder_arg is not None + ), "Can't use seed checkpoint if folder is not specified" + print("Creating seed checkpoint") + result = _run_cmd( + f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}" + ) + print(result.stdout) + + result = _run_cmd(cmd) print(result.stdout) if result.returncode != 0: raise Exception( diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 1de3c82c..cbcccde2 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -17,6 +17,10 @@ from torchtitan.logging_utils import logger +def string_list(raw_arg): + return raw_arg.split(",") + + class JobConfig: """ A helper class to manage the train configuration. @@ -202,10 +206,56 @@ def __init__(self): help="Whether to apply loss parallel when sequence parallel is enabled", ) self.parser.add_argument( - "--training.pipeline_parallel_degree", + "--experimental.pipeline_parallel_degree", type=int, default=1, - help="Pipeline Parallelism degree. 1 means disabled.", + help=""" + Pipeline Parallelism degree, or number of ranks. 1 means disabled. + If using looped schedules, this still specifies the number of physical ranks, not the number + of stages. Stages per rank are inferred from split points degree, and schedule.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_split_points", + type=string_list, + nargs="+", + default=[], + help=""" + Specify comma-separated names of modules to use as the beginning of a split point. + + e.g. "layers.0,layers.2" will cause the model to be split into 3 stages, + the first containing all the layers up to layers.0, + the second containing layers.0 and up to layers.2, + the third containing layers.2 and all the remaining layers. + + Note: fully-automated splitting may be enabled in the future, + but currently the split points must be specified manually for both manual and tracer.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_schedule", + type=str, + choices=["1f1b", "gpipe"], + default="1f1b", + help=""" + Specify the Pipeline Parallel schedule to use. + + The schedule must be compatible with the split points and stages_per_rank. + + Looped schedules are not yet supported in torchtitan.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_split_mode", + type=str, + choices=["manual", "tracer"], + default="manual", + help=""" + Specify the split method (e.g. the Pipeline Parallelism Front End) + + "manual" means each rank will construct an nn.Module with the appropriate layers and .forward + implementation manually, and then wrap it in a PipelineStage. + + "tracer" means the full model will be initialized (via meta device) and then traced into a graph, + split via the provided split points, unflattened into an nn.Module, + and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""", ) self.parser.add_argument( "--training.compile", @@ -408,6 +458,10 @@ def parse_args_from_command_line( aux_parser.add_argument( "--" + arg, action="store_true" if val else "store_false" ) + elif arg == "experimental.pipeline_parallel_split_points": + # type inference breaks here, since the type is just 'list' and it ends up flattening + # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...] + aux_parser.add_argument("--" + arg, type=string_list) else: aux_parser.add_argument("--" + arg, type=type(val)) diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index e791b832..7e1b21c7 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -9,12 +9,16 @@ from torch.distributed.device_mesh import init_device_mesh from torchtitan.logging_utils import logger -from torchtitan.parallelisms.parallelize_llama import parallelize_llama +from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama models_parallelize_fns = { "llama2": parallelize_llama, "llama3": parallelize_llama, } +models_pipelining_fns = { + "llama2": pipeline_llama, + "llama3": pipeline_llama, +} @dataclass diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 9c8d0a29..4b4ce0ba 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -8,10 +8,20 @@ # llama model, i.e. activation checkpointing, etc. from collections import defaultdict -from typing import Tuple +from typing import Dict, Tuple import torch +try: + from pippy import ManualPipelineStage, pipeline, SplitPoint + from pippy._PipelineStage import _PipelineStage +except ImportError as exc: + raise ImportError( + "pippy is not installed. Please install it to use pipeline parallelism. " + "`pip install git+https://github.com/pytorch/pippy`" + ) from exc + + from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy from torch.distributed._tensor import Replicate, Shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -30,7 +40,7 @@ from torchtitan.config_manager import JobConfig from torchtitan.logging_utils import logger - +from torchtitan.parallelisms.pipelining_utils import split_stage_fqns # for selective AC no_recompute_list = { @@ -129,15 +139,179 @@ def get_tp_parallel_strategy( return RowwiseParallel, ColwiseParallel +def _llama_fqns(num_layers): + return ( + [ + "tok_embeddings", + ] + + [f"layers.{i}" for i in range(num_layers)] + + [ + "norm", + "output", + ] + ) + + +def pipeline_llama( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + if job_config.experimental.pipeline_parallel_split_mode == "manual": + return pipeline_llama_manual( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + elif job_config.experimental.pipeline_parallel_split_mode == "tracer": + return pipeline_llama_tracer( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + else: + raise NotImplementedError( + f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode" + ) + + +def _llama_trace_input(job_config, model_config, device="meta"): + """Get meta tensors with the right input shapes used for tracing""" + tokens_shape = (job_config.training.batch_size, job_config.training.seq_len) + tokens = torch.randint( + model_config.vocab_size, tokens_shape, dtype=torch.int64, device=device + ) + return (tokens,) + + +def pipeline_llama_manual( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + """ + This API gets individual torch.nn.Module objects for each pipeline stage (including virtual stages). + + The SPMD parallelisms should be applied to + """ + pp_mesh = world_mesh["pp"] + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + # heuristically == PP dim but should be a config + microbatches = parallel_dims.pp + stage_idx = pp_rank # TODO support virtual stages + this_stage_layer_names = split_stage_fqns( + _llama_fqns(len(model.layers)), + job_config.experimental.pipeline_parallel_split_points, + pp_rank, + ) + + if pp_rank == 0: + model.norm = None + model.output = None + elif pp_rank == pp_size - 1: + model.tok_embeddings = None + names = list(model.layers.keys()) + for name in names: + if f"layers.{name}" not in this_stage_layer_names: + del model.layers[name] + + logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}") + + # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and + # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the + # layers of the model that map to this stage, not the whole model. + + # Get example input + if pp_rank == 0: + input_shape = (job_config.training.batch_size, job_config.training.seq_len) + input = torch.randint( + model_config.vocab_size, input_shape, dtype=torch.int64, device=device + ) + + # HACK- can't use shape inference via execution of the PP stage inside ManualPipelineStage API, becuase the + # real output shapes will change after applying TP. So we hardcode output shapes here, and thus bypass doing + # shape inference. + # the real fix is to use lazy shape inference during first PP forward, and not need to specify anything here. + output_shape = ( + job_config.training.batch_size, + int(job_config.training.seq_len // parallel_dims.tp), + model_config.dim, + ) + output = torch.empty(output_shape, dtype=torch.float32, device=device) + else: + # TODO(whc) can we rely on shape inference so that user doesn't have to compute TP impact on seq_len + input_shape = ( + job_config.training.batch_size, + int(job_config.training.seq_len // parallel_dims.tp), + model_config.dim, + ) + input = torch.randint( + model_config.vocab_size, input_shape, dtype=torch.float32, device=device + ) + # TODO wrong shape, need to consider output layer + output_shape = ( + job_config.training.batch_size, + int(job_config.training.seq_len // parallel_dims.tp), + model_config.dim, + ) + output = torch.empty(output_shape, dtype=torch.float32, device=device) + + model.to_empty(device=device) + stage = ManualPipelineStage( + model, + pp_rank, + pp_size, + device, + microbatches, + input_args=input.chunk(microbatches)[0], + output_args=output.chunk(microbatches)[0], + group=pp_mesh.get_group("pp"), + ) + return (stage, model) + + +def pipeline_llama_tracer( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + if job_config.model.norm_type == "fused_rmsnorm": + # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode + # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm + raise NotImplementedError( + "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm." + ) + + # TODO(whc) maybe we can just fix this by feeding bf16 into the tracer for its input shapes? + raise NotImplementedError( + "pipeline tracer doesn't work with fsdp mixed precision currently. " + "To work around, edit fsdp mixed precision config to use fp32." + ) + pp_mesh = world_mesh["pp"] + pp_rank = pp_mesh.get_local_rank() + stage_idx = pp_mesh.get_local_rank() + layers_per_rank = len(model.layers) // parallel_dims.pp + split_spec = { + f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING + for i in range(1, parallel_dims.pp) + } + + # Create a pipeline representation from the model + pipe = pipeline( + model, + parallel_dims.pp, + example_args=_llama_trace_input(job_config, model_config), + split_spec=split_spec, + ) + model = pipe.get_stage_module(stage_idx) + stage = _PipelineStage( + stage_module=model, + stage_index=pp_rank, + pipe_info=pipe.pipe_info, + device=device, + group=pp_mesh.get_group(), + ) + return (stage, model) + + def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): """ - Apply parallelisms and activation checkpointing to the model. + Apply SPMD parallelisms and activation checkpointing to the model. NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - if parallel_dims.pp_enabled: - raise NotImplementedError("PP not implemented yet.") if parallel_dims.tp_enabled: if job_config.model.norm_type == "fused_rmsnorm": @@ -211,7 +385,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names # TODO: Expose `reduce_dtype` as a config option. mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32 + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, ) ac_mode = job_config.activation_checkpoint.mode fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} @@ -221,15 +396,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): transformer_block, job_config.activation_checkpoint ) # As an optimization, do not reshard after forward for the last - # transformer block since FSDP would prefetch it immediately - reshard_after_forward = int(layer_id) < len(model.layers) - 1 + # transformer block since FSDP would prefetch it immediately. + # When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings + # per microbatch. + reshard_after_forward = ( + int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled + ) fully_shard( transformer_block, **fsdp_config, reshard_after_forward=reshard_after_forward, ) model.layers[layer_id] = transformer_block - model = fully_shard(model, **fsdp_config) + + model = fully_shard( + model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled + ) if ac_mode in ("full", "selective"): logger.info(f"Applied {ac_mode} activation checkpointing to the model") logger.info("Applied FSDP to the model") diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py new file mode 100644 index 00000000..5c2ecff3 --- /dev/null +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# TODO(whc) this can be removed after pippy migration into pytorch core is complete. +try: + from pippy import Schedule1F1B, ScheduleGPipe +except ImportError as exc: + raise ImportError( + "pippy is not installed. Please install it to use pipeline parallelism. " + "`pip install git+https://github.com/pytorch/pippy`" + ) from exc + + +def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn): + if job_config.experimental.pipeline_parallel_schedule == "1f1b": + schedule_class = Schedule1F1B + elif job_config.experimental.pipeline_parallel_schedule == "gpipe": + schedule_class = ScheduleGPipe + else: + raise NotImplementedError( + f"{job_config.experimental.pipeline_parallel_schedule} is not implemented" + ) + return schedule_class( + stage, + n_microbatches=parallel_dims.pp, + loss_fn=loss_fn, + ) + + +def split_stage_fqns(fqns, split_points, stage_id): + """Helper for splitting ordered list of layer names into layers per stage. + + split_points is a list of layer names, each layer will be the first layer in a stage + """ + stages = [] + cur = [] + + for name in fqns: + if name in split_points: + assert len( + cur + ), f"{name} is not a valid split point, do not specify the first layer of stage 0" + stages.append(cur) + cur = [] + cur.append(name) + + stages.append(cur) + return stages[stage_id] diff --git a/train.py b/train.py index 318c7174..f6844fe7 100644 --- a/train.py +++ b/train.py @@ -20,6 +20,7 @@ import torch import torch.nn.functional as F from torch.distributed import destroy_process_group +from torch.distributed._composable.fsdp.fully_shard import FSDPModule from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.tensor.parallel import loss_parallel @@ -32,7 +33,12 @@ from torchtitan.lr_scheduling import get_lr_scheduler from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config -from torchtitan.parallelisms import models_parallelize_fns, ParallelDims +from torchtitan.parallelisms import ( + models_parallelize_fns, + models_pipelining_fns, + ParallelDims, +) +from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule from torchtitan.profiling import maybe_enable_profiling from torchtitan.utils import ( Color, @@ -122,11 +128,12 @@ def main(job_config: JobConfig): parallel_dims = ParallelDims( dp=job_config.training.data_parallel_degree, tp=job_config.training.tensor_parallel_degree, - pp=job_config.training.pipeline_parallel_degree, + pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, ) - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + torch.cuda.set_device(device) init_distributed(job_config) world_mesh = parallel_dims.build_mesh(device_type="cuda") @@ -144,6 +151,10 @@ def main(job_config: JobConfig): dp_rank = dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 + + if parallel_dims.pp_enabled: + pp_mesh = world_mesh["pp"] + data_loader = build_hf_data_loader( job_config.training.dataset, job_config.training.dataset_path, @@ -201,13 +212,26 @@ def loss_fn(pred, labels): # obtain the peak flops of bf16 type for MFU calculation gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name) - # apply PT-D parallelisms and activation checkpointing + if parallel_dims.pp_enabled: + stage, model = models_pipelining_fns[model_name]( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + + # apply PT-D DP/TP parallelisms and activation checkpointing model = models_parallelize_fns[model_name]( model, world_mesh, parallel_dims, job_config ) - # allocate sharded model on GPU and initialize weights via DTensor + model.to_empty(device="cuda") - model.init_weights() + + if parallel_dims.pp_enabled: + pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn) + else: + # If PP is enabled, we can't rely on init_weights, because some layers are missing. + # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation. + + # allocate sharded model on GPU and initialize weights via DTensor + model.init_weights() gpu_mem_stats = gpu_memory_monitor.get_peak_stats() logger.info( @@ -216,6 +240,10 @@ def loss_fn(pred, labels): f"({gpu_mem_stats.max_reserved_pct:.2f}%)" ) + if isinstance(model, FSDPModule) and parallel_dims.pp_enabled: + # reshard now to counteract an issue where FSDP's states got advanced during PP stage shape inference + model.reshard() + # build optimizer after applying parallelisms to the model optimizer = build_optimizer(model, job_config) scheduler = get_lr_scheduler(optimizer, job_config) @@ -257,7 +285,13 @@ def loss_fn(pred, labels): logger.info("Created seed checkpoint") return - checkpoint.load() + checkpoint_loaded = checkpoint.load() + + if parallel_dims.pp_enabled and not checkpoint_loaded: + raise RuntimeError( + "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. " + "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`" + ) # plot losses loaded from checkpoint (if any) to TensorBoard # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. @@ -299,14 +333,33 @@ def loss_fn(pred, labels): input_ids = input_ids.cuda() labels = labels.cuda() - optimizer.zero_grad() - # forward / backward - with loss_parallel_ctx(): - pred = model(input_ids) - loss = loss_fn(pred, labels) - loss.backward() + if parallel_dims.pp_enabled: + # pipeline parallel forward / backward inside step() call + is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 + + with loss_parallel_ctx(): + if pp_mesh.get_local_rank() == 0: + pp_schedule.step(input_ids) + elif is_last_stage: + losses = [] + pp_schedule.step(target=labels, losses=losses) + else: + schedule.step() + + # accumulate losses across pipeline microbatches + loss = ( + torch.mean(torch.stack(losses)) + if is_last_stage + else torch.Tensor([-1.0]) + ) + else: + # Non-PP forward / backward + with loss_parallel_ctx(): + pred = model(input_ids) + loss = loss_fn(pred, labels) + loss.backward() # clip gradients torch.nn.utils.clip_grad_norm_( diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 4541fec7..009348b5 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -36,11 +36,13 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) +[experimental] +pipeline_parallel_degree = 1 + [checkpoint] enable_checkpoint = false folder = "checkpoint"