Add test for PP tracer frontend
- switch to using public PipelineStage API
- clean up some asserts in tracer codepath

ghstack-source-id: 2d069b7d45c4f3c788dec8fc85d8a7e83e463fcd
Pull Request resolved: #357
wconstab committed May 24, 2024
1 parent 02ae169 commit 1ceaa4e
Showing 2 changed files with 37 additions and 22 deletions.
13 changes: 13 additions & 0 deletions test_runner.py
@@ -113,6 +113,19 @@ def build_test_list(args):
"PP+TP 2D test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_tracer/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer
],
],
"PP tracer frontend test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
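As a reading aid for the new test entry above: the override flags correspond to JobConfig fields that the pipeline code in the next file reads. A rough, illustrative mapping, using a SimpleNamespace stand-in rather than torchtitan's real config parser (field shapes, e.g. whether split points parse into a list, are assumptions):

```python
from types import SimpleNamespace

# Illustrative stand-in for the parsed JobConfig fields referenced in
# parallelize_llama.py below; torchtitan's actual config manager does the real parsing.
job_config = SimpleNamespace(
    experimental=SimpleNamespace(
        pipeline_parallel_degree=2,                   # --experimental.pipeline_parallel_degree 2
        pipeline_parallel_split_points=["layers.1"],  # --experimental.pipeline_parallel_split_points layers.1
        pipeline_parallel_microbatches=None,          # unset: code below falls back to parallel_dims.pp
    ),
    model=SimpleNamespace(norm_type="rmsnorm"),       # fused_rmsnorm is not yet tracer-compatible
    checkpoint=SimpleNamespace(enable_checkpoint=True),
)
```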
46 changes: 24 additions & 22 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -18,10 +18,11 @@
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.pipelining import pipeline, SplitPoint
from torch.distributed.pipelining.PipelineStage import (
_PipelineStage,
from torch.distributed.pipelining import (
ManualPipelineStage,
pipeline,
PipelineStage,
SplitPoint,
)
from torch.distributed.tensor.parallel import (
ColwiseParallel,
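For readers skimming the import hunk: every pipelining symbol now comes from the public torch.distributed.pipelining namespace. Roughly how each one is used later in this file (comments are editorial, based on the diff below):

```python
from torch.distributed.pipelining import (
    ManualPipelineStage,  # manual frontend: builds a stage from a model chunk plus explicit I/O shapes
    pipeline,             # tracer frontend: traces the model into a pipe representation
    PipelineStage,        # public stage class built from the traced pipe (replaces the private _PipelineStage)
    SplitPoint,           # marks where the tracer should split the model (used in split_spec)
)
```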
@@ -159,6 +160,14 @@ def _llama_trace_input(job_config, model_config, device="meta"):
return (tokens,)


def _mixed_precision_dtype(
job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
) -> torch.dtype:
"""Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
mp_arg = job_config.training.mixed_precision_param
return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default


def pipeline_llama_manual(
model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
):
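The new `_mixed_precision_dtype` helper centralizes a lookup that previously lived inline in `pipeline_llama_manual`. A self-contained sketch of its behavior, with a stand-in `TORCH_DTYPE_MAP` and dummy configs (the real map and config classes live elsewhere in torchtitan):

```python
import torch
from types import SimpleNamespace

# Assumed stand-in for torchtitan's TORCH_DTYPE_MAP.
TORCH_DTYPE_MAP = {"float32": torch.float32, "bfloat16": torch.bfloat16}

def _mixed_precision_dtype(job_config, parallel_dims, default=torch.float32):
    mp_arg = job_config.training.mixed_precision_param
    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default

job_config = SimpleNamespace(training=SimpleNamespace(mixed_precision_param="bfloat16"))
parallel_dims = SimpleNamespace(dp_enabled=True)
assert _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16

parallel_dims.dp_enabled = False  # no FSDP -> mixed precision is ignored, default fp32 wins
assert _mixed_precision_dtype(job_config, parallel_dims) == torch.float32
```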
@@ -204,8 +213,7 @@ def pipeline_llama_manual(
# TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
# get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
# layers of the model that map to this stage, not the whole model.
mp_arg = job_config.training.mixed_precision_param
mp_dtype = TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else torch.float32
mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
batch_size = job_config.training.batch_size
local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
layers_io_shape = (batch_size, local_seq_len, model_config.dim)
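A quick worked example of the hardcoded stage I/O shape, with illustrative numbers rather than torchtitan defaults:

```python
# Illustrative numbers only.
batch_size, seq_len, tp_degree, model_dim = 8, 2048, 2, 4096
local_seq_len = seq_len // tp_degree  # sequence dim is sharded across the TP group
layers_io_shape = (batch_size, local_seq_len, model_dim)
assert layers_io_shape == (8, 1024, 4096)
```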
@@ -216,12 +224,7 @@
)
if pp_rank == 0:
# first layer
input = torch.randint(
model_config.vocab_size,
size=(batch_size, job_config.training.seq_len),
dtype=torch.int64,
device=device,
)
(input,) = _llama_trace_input(job_config, model_config, device=device)
else:
# later layers (assume all start w/ a transformer layer)
input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
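The change above reuses `_llama_trace_input` so the manual frontend and the tracer agree on the first stage's example input. Judging from the removed `torch.randint` call, that input is a batch of int64 token ids; a sketch of the two cases under that assumption (shapes and dtypes are illustrative):

```python
import torch

vocab_size, batch_size, seq_len = 32_000, 8, 2048  # illustrative numbers
model_dim, tp_degree = 4096, 2
mp_dtype = torch.bfloat16

# First pipeline stage: integer token ids over the full sequence length.
first_stage_input = torch.randint(vocab_size, size=(batch_size, seq_len), dtype=torch.int64)

# Later stages: hidden activations in the mixed-precision dtype, with the TP-sharded sequence.
later_stage_input = torch.rand((batch_size, seq_len // tp_degree, model_dim), dtype=mp_dtype)
```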
@@ -257,32 +260,31 @@ def pipeline_llama_tracer(
"fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
)

# TODO(whc) maybe we can just fix this by feeding bf16 into the tracer for its input shapes?
raise NotImplementedError(
"pipeline tracer doesn't work with fsdp mixed precision currently. "
"To work around, edit fsdp mixed precision config to use fp32."
)
if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
raise NotImplementedError(
"pipeline tracer doesn't work with fsdp mixed precision currently. "
"To work around, edit fsdp mixed precision config to use fp32."
)

pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
stage_idx = pp_mesh.get_local_rank()
stage_idx = pp_rank
layers_per_rank = len(model.layers) // parallel_dims.pp
split_spec = {
f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
for i in range(1, parallel_dims.pp)
}

# Create a pipeline representation from the model
pipe = pipeline(
model,
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
example_args=_llama_trace_input(job_config, model_config),
split_spec=split_spec,
)
model = pipe.get_stage_module(stage_idx)
stage = _PipelineStage(
stage_module=model,
stage_index=pp_rank,
pipe_info=pipe.pipe_info,
stage = PipelineStage(
pipe,
stage_index=stage_idx,
device=device,
group=pp_mesh.get_group(),
)
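Read as a whole, the tracer codepath after this change reduces to the flow below. This is a condensed paraphrase of the hunk above (error handling omitted), wrapped in a standalone function; it assumes `_llama_trace_input` is the helper shown earlier in this file and that the imports introduced by this PR are available.

```python
from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint

def build_tracer_stage(model, world_mesh, parallel_dims, job_config, model_config, device):
    """Condensed sketch of pipeline_llama_tracer after this PR (not torchtitan's exact code)."""
    pp_mesh = world_mesh["pp"]
    stage_idx = pp_mesh.get_local_rank()

    # Cut the model at transformer-layer boundaries, one chunk per pp rank.
    layers_per_rank = len(model.layers) // parallel_dims.pp
    split_spec = {
        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
        for i in range(1, parallel_dims.pp)
    }

    # Trace the model into a pipe, pull out this rank's submodule, and wrap it
    # in the public PipelineStage API (formerly the private _PipelineStage).
    # _llama_trace_input is the helper defined earlier in this file.
    pipe = pipeline(
        model,
        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
        example_args=_llama_trace_input(job_config, model_config),
        split_spec=split_spec,
    )
    stage_module = pipe.get_stage_module(stage_idx)
    stage = PipelineStage(
        pipe,
        stage_index=stage_idx,
        device=device,
        group=pp_mesh.get_group(),
    )
    return stage, stage_module
```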
