diff --git a/.github/workflows/unit_test_4gpu.yaml b/.github/workflows/unit_test_4gpu.yaml
index b62dbe73..6fbbcc2c 100644
--- a/.github/workflows/unit_test_4gpu.yaml
+++ b/.github/workflows/unit_test_4gpu.yaml
@@ -25,7 +25,7 @@ jobs:
           python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
           python -m pip install -r requirements.txt
           python -m pip install -r dev-requirements.txt
-          mkdir artifacts-to-be-uploaded
+          python -m pip install git+https://github.com/pytorch/pippy
           python ./test_runner.py artifacts-to-be-uploaded
 #      upload-coverage:
 #        - name: Upload Coverage to Codecov
diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh
index 38bab219..1abc77ec 100755
--- a/create_seed_checkpoint.sh
+++ b/create_seed_checkpoint.sh
@@ -25,7 +25,7 @@ LOG_RANK=0
 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
 
 seed_checkpoint="--checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint"
-force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --training.pipeline_parallel_degree 1"
+force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1"
 overrides=""
 if [ $# -ne 0 ]; then
     overrides="$*"
diff --git a/test_runner.py b/test_runner.py
index 03b0a485..c1b729d6 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -30,6 +30,8 @@ class OverrideDefinitions:
 
     override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
     test_descr: str = "default"
+    requires_seed_checkpoint: bool = False
+    ngpu: int = 4
 
 
 CONFIG_DIR = "./train_configs"
@@ -88,13 +90,83 @@ class OverrideDefinitions:
         ],
         "Checkpoint Integration Test - Save Model Weights Only bf16",
     ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--job.dump_folder {args.output_dir}/pp/",
+                "--experimental.pipeline_parallel_degree 2",
+                "--experimental.pipeline_parallel_split_points layers.1",
+                "--training.data_parallel_degree 1",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP 1D test",
+        requires_seed_checkpoint=True,
+        ngpu=2,
+    ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--job.dump_folder {args.output_dir}/pp_dp/",
+                "--experimental.pipeline_parallel_degree 2",
+                "--experimental.pipeline_parallel_split_points layers.1",
+                "--training.data_parallel_degree 2",
+                "--model.norm_type fused_rmsnorm",
+            ],
+        ],
+        "PP+DP 2D test",
+        requires_seed_checkpoint=True,
+    ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--job.dump_folder {args.output_dir}/pp_tp/",
+                "--experimental.pipeline_parallel_degree 2",
+                "--experimental.pipeline_parallel_split_points layers.1",
+                "--training.tensor_parallel_degree 2",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP+TP 2D test",
+        requires_seed_checkpoint=True,
+    ),
+    # Disabled for now: PP=2 * DP=2 * TP=2 needs 8 GPUs, more than the 4 available in CI.
+    # OverrideDefinitions(
+    #     [
+    #         [
+    #             "--checkpoint.enable_checkpoint",
+    #             f"--job.dump_folder {args.output_dir}/pp_dp_tp/",
+    #             "--experimental.pipeline_parallel_degree 2",
+    #             "--experimental.pipeline_parallel_split_points layers.1",
+    #             "--training.data_parallel_degree 2",
+    #             "--training.tensor_parallel_degree 2",
+    #             "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+    #         ],
+    #     ],
+    #     "PP+DP+TP 3D test",
+    #     requires_seed_checkpoint=True,
+    # ),
 ]
 
 
+def _run_cmd(cmd):
+    return subprocess.run(
+        [cmd],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        shell=True,
+    )
+
+
 def run_test(test_flavor: OverrideDefinitions, full_path: str):
     # run_test supports sequence of tests.
     for override_arg in test_flavor.override_args:
-        cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
+
+        cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
         if override_arg:
             cmd += (
                 " " + " ".join(override_arg) + f" --job.dump_folder {args.output_dir}"
@@ -102,13 +174,22 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):
         print(
             f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
         )
-        result = subprocess.run(
-            [cmd],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            text=True,
-            shell=True,
-        )
+
+        if test_flavor.requires_seed_checkpoint:
+            dump_folder_arg = None
+            for arg in override_arg:
+                if "--job.dump_folder" in arg:
+                    dump_folder_arg = arg
+            assert (
+                dump_folder_arg is not None
+            ), "Can't use seed checkpoint if folder is not specified"
+            print("Creating seed checkpoint")
+            result = _run_cmd(
+                f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
+            )
+            print(result.stdout)
+
+        result = _run_cmd(cmd)
         print(result.stdout)
         if result.returncode != 0:
             raise Exception(
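# Illustrative sketch (not part of the patch): how run_test above expands the new "PP 1D test"
# flavor into shell commands. The config path and output_dir value ("outputs") are hypothetical.
config_file = "./train_configs/debug_model.toml"
override_arg = [
    "--checkpoint.enable_checkpoint",
    "--job.dump_folder outputs/pp/",
    "--experimental.pipeline_parallel_degree 2",
    "--experimental.pipeline_parallel_split_points layers.1",
    "--training.data_parallel_degree 1",
    "--model.norm_type rmsnorm",
]
# Because requires_seed_checkpoint=True, a seed checkpoint is created first:
print(f"CONFIG_FILE={config_file} ./create_seed_checkpoint.sh --job.dump_folder outputs/pp/")
# Then the training command itself, using ngpu=2 for this flavor:
cmd = f"CONFIG_FILE={config_file} NGPU=2 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
cmd += " " + " ".join(override_arg) + " --job.dump_folder outputs"
print(cmd)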
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 1de3c82c..cbcccde2 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -17,6 +17,10 @@
 from torchtitan.logging_utils import logger
 
 
+def string_list(raw_arg):
+    return raw_arg.split(",")
+
+
 class JobConfig:
     """
     A helper class to manage the train configuration.
@@ -202,10 +206,56 @@ def __init__(self):
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
         self.parser.add_argument(
-            "--training.pipeline_parallel_degree",
+            "--experimental.pipeline_parallel_degree",
             type=int,
             default=1,
-            help="Pipeline Parallelism degree. 1 means disabled.",
+            help="""
+                Pipeline Parallelism degree, or number of ranks. 1 means disabled.
+                If using looped schedules, this still specifies the number of physical ranks, not the number
+                of stages. Stages per rank are inferred from split points, degree, and schedule.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_split_points",
+            type=string_list,
+            nargs="+",
+            default=[],
+            help="""
+                Specify comma-separated names of modules to use as the beginning of a split point.
+
+                e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
+                the first containing all the layers up to layers.0,
+                the second containing layers.0 and up to layers.2,
+                the third containing layers.2 and all the remaining layers.
+
+                Note: fully-automated splitting may be enabled in the future,
+                but currently the split points must be specified manually for both the manual and tracer frontends.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_schedule",
+            type=str,
+            choices=["1f1b", "gpipe"],
+            default="1f1b",
+            help="""
+                Specify the Pipeline Parallel schedule to use.
+
+                The schedule must be compatible with the split points and stages_per_rank.
+
+                Looped schedules are not yet supported in torchtitan.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_split_mode",
+            type=str,
+            choices=["manual", "tracer"],
+            default="manual",
+            help="""
+                Specify the split method (i.e. the Pipeline Parallelism frontend).
+
+                "manual" means each rank will construct an nn.Module with the appropriate layers and .forward
+                implementation manually, and then wrap it in a PipelineStage.
+
+                "tracer" means the full model will be initialized (via meta device) and then traced into a graph,
+                split via the provided split points, unflattened into an nn.Module,
+                and finally wrapped in a PipelineStage. The tracer frontend is currently more experimental.""",
         )
         self.parser.add_argument(
             "--training.compile",
@@ -408,6 +458,10 @@ def parse_args_from_command_line(
                 aux_parser.add_argument(
                     "--" + arg, action="store_true" if val else "store_false"
                 )
+            elif arg == "experimental.pipeline_parallel_split_points":
+                # type inference breaks here, since the type is just 'list' and it ends up flattening
+                # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
+                aux_parser.add_argument("--" + arg, type=string_list)
             else:
                 aux_parser.add_argument("--" + arg, type=type(val))
diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py
index e791b832..7e1b21c7 100644
--- a/torchtitan/parallelisms/__init__.py
+++ b/torchtitan/parallelisms/__init__.py
@@ -9,12 +9,16 @@
 from torch.distributed.device_mesh import init_device_mesh
 from torchtitan.logging_utils import logger
-from torchtitan.parallelisms.parallelize_llama import parallelize_llama
+from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama
 
 models_parallelize_fns = {
     "llama2": parallelize_llama,
     "llama3": parallelize_llama,
 }
+models_pipelining_fns = {
+    "llama2": pipeline_llama,
+    "llama3": pipeline_llama,
+}
 
 
 @dataclass
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 9c8d0a29..4b4ce0ba 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -8,10 +8,20 @@
 # llama model, i.e. activation checkpointing, etc.
 
 from collections import defaultdict
-from typing import Tuple
+from typing import Dict, Tuple
 
 import torch
 
+try:
+    from pippy import ManualPipelineStage, pipeline, SplitPoint
+    from pippy._PipelineStage import _PipelineStage
+except ImportError as exc:
+    raise ImportError(
+        "pippy is not installed. Please install it to use pipeline parallelism. "
+        "`pip install git+https://github.com/pytorch/pippy`"
+    ) from exc
+
+
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -30,7 +40,7 @@
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
-
+from torchtitan.parallelisms.pipelining_utils import split_stage_fqns
 
 # for selective AC
 no_recompute_list = {
@@ -129,15 +139,179 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel
 
 
+def _llama_fqns(num_layers):
+    return (
+        [
+            "tok_embeddings",
+        ]
+        + [f"layers.{i}" for i in range(num_layers)]
+        + [
+            "norm",
+            "output",
+        ]
+    )
+
+
+def pipeline_llama(
+    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+):
+    if job_config.experimental.pipeline_parallel_split_mode == "manual":
+        return pipeline_llama_manual(
+            model, world_mesh, parallel_dims, job_config, device, model_config
+        )
+    elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+        return pipeline_llama_tracer(
+            model, world_mesh, parallel_dims, job_config, device, model_config
+        )
+    else:
+        raise NotImplementedError(
+            f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
+        )
+
+
+def _llama_trace_input(job_config, model_config, device="meta"):
+    """Get meta tensors with the right input shapes used for tracing"""
+    tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
+    tokens = torch.randint(
+        model_config.vocab_size, tokens_shape, dtype=torch.int64, device=device
+    )
+    return (tokens,)
+
+
+def pipeline_llama_manual(
+    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+):
+    """
+    This API gets individual torch.nn.Module objects for each pipeline stage (including virtual stages).
+
+    The SPMD parallelisms should be applied to the module returned by this function.
+    """
+    pp_mesh = world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    pp_size = pp_mesh.size()
+    # heuristically == PP dim but should be a config
+    microbatches = parallel_dims.pp
+    stage_idx = pp_rank  # TODO support virtual stages
+    this_stage_layer_names = split_stage_fqns(
+        _llama_fqns(len(model.layers)),
+        job_config.experimental.pipeline_parallel_split_points,
+        pp_rank,
+    )
+
+    if pp_rank == 0:
+        model.norm = None
+        model.output = None
+    elif pp_rank == pp_size - 1:
+        model.tok_embeddings = None
+    names = list(model.layers.keys())
+    for name in names:
+        if f"layers.{name}" not in this_stage_layer_names:
+            del model.layers[name]
+
+    logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}")
+
+    # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave the model on meta device longer
+    # and get rid of the input shapes hardcoded here. For now, it should not be a big deal since we only materialize
+    # the layers of the model that map to this stage, not the whole model.
+
+    # Get example input
+    if pp_rank == 0:
+        input_shape = (job_config.training.batch_size, job_config.training.seq_len)
+        input = torch.randint(
+            model_config.vocab_size, input_shape, dtype=torch.int64, device=device
+        )
+
+        # HACK: we can't use shape inference via execution of the PP stage inside the ManualPipelineStage API, because
+        # the real output shapes will change after applying TP. So we hardcode output shapes here, and thus bypass
+        # doing shape inference.
+        # The real fix is to use lazy shape inference during the first PP forward, so nothing needs to be specified here.
+        output_shape = (
+            job_config.training.batch_size,
+            int(job_config.training.seq_len // parallel_dims.tp),
+            model_config.dim,
+        )
+        output = torch.empty(output_shape, dtype=torch.float32, device=device)
+    else:
+        # TODO(whc) can we rely on shape inference so that user doesn't have to compute TP impact on seq_len
+        input_shape = (
+            job_config.training.batch_size,
+            int(job_config.training.seq_len // parallel_dims.tp),
+            model_config.dim,
+        )
+        input = torch.randint(
+            model_config.vocab_size, input_shape, dtype=torch.float32, device=device
+        )
+        # TODO wrong shape, need to consider output layer
+        output_shape = (
+            job_config.training.batch_size,
+            int(job_config.training.seq_len // parallel_dims.tp),
+            model_config.dim,
+        )
+        output = torch.empty(output_shape, dtype=torch.float32, device=device)
+
+    model.to_empty(device=device)
+    stage = ManualPipelineStage(
+        model,
+        pp_rank,
+        pp_size,
+        device,
+        microbatches,
+        input_args=input.chunk(microbatches)[0],
+        output_args=output.chunk(microbatches)[0],
+        group=pp_mesh.get_group("pp"),
+    )
+    return (stage, model)
+
+
+def pipeline_llama_tracer(
+    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+):
+    if job_config.model.norm_type == "fused_rmsnorm":
+        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
+        # coming from `if dy.stride(-1) != 1:` in fused_rmsnorm
+        raise NotImplementedError(
+            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+        )
+
+    # TODO(whc) maybe we can just fix this by feeding bf16 into the tracer for its input shapes?
+    raise NotImplementedError(
+        "pipeline tracer doesn't work with fsdp mixed precision currently. "
+        "To work around, edit fsdp mixed precision config to use fp32."
+    )
+    pp_mesh = world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    stage_idx = pp_mesh.get_local_rank()
+    layers_per_rank = len(model.layers) // parallel_dims.pp
+    split_spec = {
+        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, parallel_dims.pp)
+    }
+
+    # Create a pipeline representation from the model
+    pipe = pipeline(
+        model,
+        parallel_dims.pp,
+        example_args=_llama_trace_input(job_config, model_config),
+        split_spec=split_spec,
+    )
+    model = pipe.get_stage_module(stage_idx)
+    stage = _PipelineStage(
+        stage_module=model,
+        stage_index=pp_rank,
+        pipe_info=pipe.pipe_info,
+        device=device,
+        group=pp_mesh.get_group(),
+    )
+    return (stage, model)
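# Illustrative sketch (not part of the patch): the split_spec that the tracer frontend above builds,
# shown for a hypothetical 8-layer model with pipeline_parallel_degree 4 (so layers_per_rank = 2).
# SplitPoint comes from the pre-release pippy package imported at the top of this file.
from pippy import SplitPoint

num_layers, pp_degree = 8, 4
layers_per_rank = num_layers // pp_degree
split_spec = {
    f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING for i in range(1, pp_degree)
}
print(sorted(split_spec))  # ['layers.2', 'layers.4', 'layers.6'], each mapped to SplitPoint.BEGINNING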
+
+
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
-    Apply parallelisms and activation checkpointing to the model.
+    Apply SPMD parallelisms and activation checkpointing to the model.
 
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
-    if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
 
     if parallel_dims.tp_enabled:
         if job_config.model.norm_type == "fused_rmsnorm":
@@ -211,7 +385,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         # TODO: Expose `reduce_dtype` as a config option.
         mp_policy = MixedPrecisionPolicy(
-            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+            param_dtype=torch.bfloat16,
+            reduce_dtype=torch.float32,
         )
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
@@ -221,15 +396,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     transformer_block, job_config.activation_checkpoint
                 )
             # As an optimization, do not reshard after forward for the last
-            # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = int(layer_id) < len(model.layers) - 1
+            # transformer block since FSDP would prefetch it immediately.
+            # When using Pipeline Parallelism, ZeRO-2-style sharding (no reshard after forward) is generally
+            # best, so as to avoid repeated reshardings per microbatch.
+            reshard_after_forward = (
+                int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
+            )
             fully_shard(
                 transformer_block,
                 **fsdp_config,
                 reshard_after_forward=reshard_after_forward,
             )
             model.layers[layer_id] = transformer_block
-        model = fully_shard(model, **fsdp_config)
+
+        model = fully_shard(
+            model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
+        )
         if ac_mode in ("full", "selective"):
             logger.info(f"Applied {ac_mode} activation checkpointing to the model")
         logger.info("Applied FSDP to the model")
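# Illustrative sketch (not part of the patch): how the reshard_after_forward flags above work out
# for a hypothetical 3-layer model. Without PP, every block except the last reshards after forward
# (ZeRO-3 style); with PP enabled, blocks stay unsharded across microbatches (ZeRO-2 style),
# trading memory for fewer all-gathers per microbatch.
num_layers = 3
for pp_enabled in (False, True):
    flags = [layer_id < num_layers - 1 and not pp_enabled for layer_id in range(num_layers)]
    print(f"pp_enabled={pp_enabled}: {flags}")
# pp_enabled=False: [True, True, False]
# pp_enabled=True: [False, False, False]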
diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py
new file mode 100644
index 00000000..5c2ecff3
--- /dev/null
+++ b/torchtitan/parallelisms/pipelining_utils.py
@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
+try:
+    from pippy import Schedule1F1B, ScheduleGPipe
+except ImportError as exc:
+    raise ImportError(
+        "pippy is not installed. Please install it to use pipeline parallelism. "
+        "`pip install git+https://github.com/pytorch/pippy`"
+    ) from exc
+
+
+def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn):
+    if job_config.experimental.pipeline_parallel_schedule == "1f1b":
+        schedule_class = Schedule1F1B
+    elif job_config.experimental.pipeline_parallel_schedule == "gpipe":
+        schedule_class = ScheduleGPipe
+    else:
+        raise NotImplementedError(
+            f"{job_config.experimental.pipeline_parallel_schedule} is not implemented"
+        )
+    return schedule_class(
+        stage,
+        n_microbatches=parallel_dims.pp,
+        loss_fn=loss_fn,
+    )
+
+
+def split_stage_fqns(fqns, split_points, stage_id):
+    """Helper for splitting an ordered list of layer names into layers per stage.
+
+    split_points is a list of layer names; each named layer will become the first layer of a new stage.
+    """
+    stages = []
+    cur = []
+
+    for name in fqns:
+        if name in split_points:
+            assert len(
+                cur
+            ), f"{name} is not a valid split point, do not specify the first layer of stage 0"
+            stages.append(cur)
+            cur = []
+        cur.append(name)
+
+    stages.append(cur)
+    return stages[stage_id]
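# Illustrative sketch (not part of the patch): expected output of split_stage_fqns for a hypothetical
# 4-layer llama FQN list with a single split point, i.e. a 2-stage pipeline.
from torchtitan.parallelisms.pipelining_utils import split_stage_fqns

fqns = ["tok_embeddings", "layers.0", "layers.1", "layers.2", "layers.3", "norm", "output"]
print(split_stage_fqns(fqns, ["layers.2"], stage_id=0))  # ['tok_embeddings', 'layers.0', 'layers.1']
print(split_stage_fqns(fqns, ["layers.2"], stage_id=1))  # ['layers.2', 'layers.3', 'norm', 'output']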
diff --git a/train.py b/train.py
index 318c7174..f6844fe7 100644
--- a/train.py
+++ b/train.py
@@ -20,6 +20,7 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed import destroy_process_group
+from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
@@ -32,7 +33,12 @@
 from torchtitan.lr_scheduling import get_lr_scheduler
 from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
-from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
+from torchtitan.parallelisms import (
+    models_parallelize_fns,
+    models_pipelining_fns,
+    ParallelDims,
+)
+from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule
 from torchtitan.profiling import maybe_enable_profiling
 from torchtitan.utils import (
     Color,
@@ -122,11 +128,12 @@ def main(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
-        pp=job_config.training.pipeline_parallel_degree,
+        pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    torch.cuda.set_device(device)
     init_distributed(job_config)
 
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -144,6 +151,10 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+
     data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -201,13 +212,26 @@ def loss_fn(pred, labels):
     # obtain the peak flops of bf16 type for MFU calculation
     gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name)
 
-    # apply PT-D parallelisms and activation checkpointing
+    if parallel_dims.pp_enabled:
+        stage, model = models_pipelining_fns[model_name](
+            model, world_mesh, parallel_dims, job_config, device, model_config
+        )
+
+    # apply PT-D DP/TP parallelisms and activation checkpointing
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    if parallel_dims.pp_enabled:
+        pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
+    else:
+        # If PP is enabled, we can't rely on init_weights, because some layers are missing.
+        # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
@@ -216,6 +240,10 @@ def loss_fn(pred, labels):
         f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
     )
 
+    if isinstance(model, FSDPModule) and parallel_dims.pp_enabled:
+        # reshard now to counteract an issue where FSDP's states got advanced during PP stage shape inference
+        model.reshard()
+
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
@@ -257,7 +285,13 @@ def loss_fn(pred, labels):
         logger.info("Created seed checkpoint")
         return
 
-    checkpoint.load()
+    checkpoint_loaded = checkpoint.load()
+
+    if parallel_dims.pp_enabled and not checkpoint_loaded:
+        raise RuntimeError(
+            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
+            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
+        )
 
     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be ploted.
@@ -299,14 +333,33 @@ def loss_fn(pred, labels):
             input_ids = input_ids.cuda()
             labels = labels.cuda()
-            optimizer.zero_grad()
 
-            # forward / backward
-            with loss_parallel_ctx():
-                pred = model(input_ids)
-                loss = loss_fn(pred, labels)
-                loss.backward()
+            if parallel_dims.pp_enabled:
+                # pipeline parallel forward / backward inside step() call
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                with loss_parallel_ctx():
+                    if pp_mesh.get_local_rank() == 0:
+                        pp_schedule.step(input_ids)
+                    elif is_last_stage:
+                        losses = []
+                        pp_schedule.step(target=labels, losses=losses)
+                    else:
+                        pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                loss = (
+                    torch.mean(torch.stack(losses))
+                    if is_last_stage
+                    else torch.Tensor([-1.0])
+                )
+            else:
+                # Non-PP forward / backward
+                with loss_parallel_ctx():
+                    pred = model(input_ids)
+                    loss = loss_fn(pred, labels)
+                    loss.backward()
 
             # clip gradients
             torch.nn.utils.clip_grad_norm_(
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 4541fec7..009348b5 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -36,11 +36,13 @@ max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-pipeline_parallel_degree = 1
 fp8_linear = ""
 compile = false
 dataset = "c4_mini"  # supported datasets: c4_mini (45K), c4 (177M)
 
+[experimental]
+pipeline_parallel_degree = 1
+
 [checkpoint]
 enable_checkpoint = false
 folder = "checkpoint"
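# Illustrative sketch (not part of the patch): the per-rank stepping pattern used in the training loop
# above, pulled out as a standalone helper. `pp_schedule` is the schedule returned by
# build_pipeline_schedule and `pp_mesh` is the "pp" sub-mesh; both follow the names used in train.py.
import torch


def pp_step(pp_schedule, pp_mesh, input_ids, labels):
    """Run one pipeline-parallel train step; only the last stage computes a real loss."""
    is_first_stage = pp_mesh.get_local_rank() == 0
    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

    if is_first_stage:
        # the first stage feeds the tokens in; backward still runs inside step()
        pp_schedule.step(input_ids)
        return torch.tensor([-1.0])
    elif is_last_stage:
        # the last stage supplies the targets and collects per-microbatch losses
        losses = []
        pp_schedule.step(target=labels, losses=losses)
        return torch.mean(torch.stack(losses))
    else:
        # middle stages only relay activations and gradients
        pp_schedule.step()
        return torch.tensor([-1.0])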