Skip to content

Commit

Permalink
fix num_microbatches input for PP
Browse files Browse the repository at this point in the history
ghstack-source-id: 459e8bee48fd77b027fecdc9e5f78ec375b87cb6
Pull Request resolved: #781
  • Loading branch information
H-Huang committed Jan 7, 2025
1 parent a85a44c commit 90567fc
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
6 changes: 2 additions & 4 deletions torchtitan/parallelisms/pipeline_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def pipeline_llama_manual_split(
"""
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
microbatches = (
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
)

splits = (
job_config.experimental.pipeline_parallel_split_points
or generate_split_points(job_config, parallel_dims.pp, model_config)
Expand Down Expand Up @@ -117,7 +115,7 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
)
logger.info(
f"PP rank {pp_rank} is building stage_idx {stage_idx}"
f" with start_layer {start_layer}, stop_layer {stop_layer}: model chunk \n{model_chunk}"
f" with start_layer {start_layer}, stop_layer {stop_layer}"
)
stages.append(stage)
models.append(model_chunk)
Expand Down
31 changes: 24 additions & 7 deletions torchtitan/parallelisms/pipelining_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,27 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Tuple
from typing import List, Tuple

from torch.distributed.pipelining.schedules import (
_PipelineScheduleRuntime,
get_schedule_class,
PipelineScheduleMulti,
PipelineScheduleSingle,
)
from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.models.llama.model import ModelArgs


def generate_split_points(job_config, pp_dim, model_config):
def generate_split_points(
job_config: JobConfig, pp_dim: int, model_config: ModelArgs
) -> List[str]:
"""
Generate a default split point based on the number of layers and
pipeline parallel dimension.
"""

schedule_class = get_schedule_class(
job_config.experimental.pipeline_parallel_schedule
)
Expand Down Expand Up @@ -51,7 +60,7 @@ def generate_split_points(job_config, pp_dim, model_config):
current_layer += base_interval
splits.append("layers." + str(current_layer))
logger.info(
f"No 'pipeline_parallel_split_points' so the generated splits are: {splits} \
f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} \
This may be sub-optimal as the number of layers per stage may be unbalanced."
)
return splits
Expand All @@ -73,18 +82,26 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
)

looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
logger.info(
f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}"
)
n_microbatches = job_config.experimental.pipeline_parallel_microbatches
# We expect that the number of local stages (`len(stages)`) is the same across all ranks
num_total_stages = job_config.experimental.pipeline_parallel_degree * len(stages)
if n_microbatches is None:
n_microbatches = job_config.experimental.pipeline_parallel_degree
n_microbatches = num_total_stages
elif n_microbatches < num_total_stages:
logger.warning(
f"Number of microbatches ({n_microbatches}) is less than the total number \
of stages ({num_total_stages}) which may result in a bubble in the pipeline."
)

schedule = schedule_class(
stages if looped_schedule else stages[0],
n_microbatches=n_microbatches,
loss_fn=loss_fn,
)
logger.info(
f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} \
with {n_microbatches} and {num_total_stages} stages."
)

if pp_schedule_csv:
assert schedule_class in [
Expand Down

0 comments on commit 90567fc

Please sign in to comment.