Fix mesh_axes and data_sharding for LLaMA 2 GPU configs.
PiperOrigin-RevId: 646795068
golechwierowicz authored and maxtext authors committed Jun 26, 2024
1 parent 5a215db commit 679ec8c
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions MaxText/configs/llama2_70b_gpu.yml
@@ -17,7 +17,7 @@ logits_dot_in_fp32: False
per_device_batch_size: 6
max_target_length: 4096

-mesh_axes: ['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
+mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
# For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
@@ -52,4 +52,4 @@ logical_axis_rules: [
['cache_sequence', []],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
4 changes: 2 additions & 2 deletions MaxText/configs/llama2_7b_gpu.yml
@@ -18,7 +18,7 @@ logits_dot_in_fp32: False
per_device_batch_size: 4
max_target_length: 4096

-mesh_axes: ['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
+mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
# For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
@@ -54,4 +54,4 @@ logical_axis_rules: [
]

# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
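
Note: both files change `mesh_axes` and `data_sharding` in lockstep, so the two lists name the same axes in the same order after this commit. A minimal sketch of a check that would flag a config where the two drift apart is given below. The helper `sharding_follows_mesh_order` is hypothetical and not part of MaxText, and whether MaxText itself enforces this ordering is an assumption, not something this diff confirms.

```python
# Hypothetical helper (not a MaxText API): returns True when every axis named
# in data_sharding exists in mesh_axes and appears in the same relative order.
def sharding_follows_mesh_order(mesh_axes, data_sharding):
    position = {axis: i for i, axis in enumerate(mesh_axes)}
    for group in data_sharding:
        indices = [position.get(axis) for axis in group]
        if None in indices:
            return False  # an axis in data_sharding is missing from mesh_axes
        if indices != sorted(indices):
            return False  # same axes, but listed in a different order
    return True


# Values copied from the diff above.
new_mesh_axes = ['data', 'stage', 'fsdp', 'fsdp_transpose',
                 'sequence', 'tensor', 'autoregressive']
new_data_sharding = [['data', 'stage', 'fsdp', 'fsdp_transpose',
                      'sequence', 'tensor', 'autoregressive']]
old_data_sharding = [['stage', 'data', 'fsdp', 'fsdp_transpose',
                      'sequence', 'tensor', 'autoregressive']]

consistent_after = sharding_follows_mesh_order(new_mesh_axes, new_data_sharding)
consistent_mixed = sharding_follows_mesh_order(new_mesh_axes, old_data_sharding)
```

Here `consistent_after` covers the state after this commit, while `consistent_mixed` covers a half-applied change in which `mesh_axes` is updated but `data_sharding` still lists `'stage'` first.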
