diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index a50e71b3ba95..9f8971a4faf5 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -94,8 +94,8 @@ def get_stage_index( stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]]) stage_indices.append( [ - num_layers_per_stage_accumulated[num_stages - stage - 1], - num_layers_per_stage_accumulated[num_stages - stage], + num_layers_per_stage_accumulated[2 * num_stages - stage - 1], + num_layers_per_stage_accumulated[2 * num_stages - stage], ] ) return stage_indices