Skip to content

Commit

Permalink
pipeline first
Browse files Browse the repository at this point in the history
  • Loading branch information
gobbleturk committed Sep 12, 2024
1 parent c7c3f4e commit 6c8b42d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ jax_cache_dir: "~/jax_cache"
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

# Parallelism
- mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']
+ mesh_axes: ['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
Expand Down Expand Up @@ -221,7 +221,7 @@ logical_axis_rules: [
['exp', 'expert'],
]
# Axes used for DCN must appear earlier in this list than axes used for ICI; see b/339009148 for details.
- data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']]
+ data_sharding: [['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
Expand Down
4 changes: 2 additions & 2 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ def create_device_mesh(config, devices=None):
multi_slice_env = num_slices > 1

dcn_parallelism = [
- config.dcn_data_parallelism,
  config.dcn_pipeline_parallelism,
+ config.dcn_data_parallelism,
config.dcn_fsdp_parallelism,
config.dcn_fsdp_transpose_parallelism,
config.dcn_sequence_parallelism,
Expand All @@ -395,8 +395,8 @@ def create_device_mesh(config, devices=None):
config.dcn_autoregressive_parallelism,
]
ici_parallelism = [
- config.ici_data_parallelism,
  config.ici_pipeline_parallelism,
+ config.ici_data_parallelism,
config.ici_fsdp_parallelism,
config.ici_fsdp_transpose_parallelism,
config.ici_sequence_parallelism,
Expand Down

0 comments on commit 6c8b42d

Please sign in to comment.