Update gemma_backbone.py for sharding config. (keras-team#1491)
* Update gemma_backbone.py for sharding config.

* Update unit test and fix format.

* Update sharding spec for gemma based on gemma training.
qlzh727 authored Mar 14, 2024
1 parent 673c63b commit 4511580
Showing 2 changed files with 39 additions and 16 deletions.
31 changes: 23 additions & 8 deletions keras_nlp/models/gemma/gemma_backbone.py
@@ -194,7 +194,11 @@ def presets(cls):
         return copy.deepcopy(backbone_presets)
 
     @staticmethod
-    def get_layout_map(device_mesh, model_parallel_dim_name="model"):
+    def get_layout_map(
+        device_mesh,
+        model_parallel_dim_name="model",
+        data_parallel_dim_name="batch",
+    ):
         """Get a `keras.distribution.LayoutMap` for model parallel distribution.
 
         The returned `LayoutMap` contains the sharding spec for the gemma
@@ -221,6 +225,8 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
                 distribution.
             model_parallel_dim_name: The axis name of the device mesh, where
                 the weights should be partition on.
+            data_parallel_dim_name: The axis name of the device mesh, where
+                the data should be partition on.
         Return:
             `keras.distribution.LayoutMap` that contains the sharding spec
             of all the model weights.
@@ -248,21 +254,30 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
                 f"{model_parallel_dim_name} is not found in the "
                 f"device_mesh.axis_names. {device_mesh.axis_name=}"
             )
+        if data_parallel_dim_name not in device_mesh.axis_names:
+            raise ValueError(
+                f"{data_parallel_dim_name} is not found in the "
+                f"device_mesh.axis_names. {device_mesh.axis_name=}"
+            )
+        # Note that it is possible to further config the mesh to be 3D, eg
+        # (data, seq, model). We leave it as 2D for now for simplicity.
+        data_dim = data_parallel_dim_name
         model_dim = model_parallel_dim_name
-        # The sharding is partition for the hidden_dim of the model.
+        # The sharding config is based on the Gemma team training config.
+        # See https://arxiv.org/abs/2403.08295
         layout_map = keras.distribution.LayoutMap(device_mesh)
-        layout_map["token_embedding/embeddings"] = (None, model_dim)
+        layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
         layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
-            None,
             model_dim,
+            data_dim,
             None,
         )
         layout_map["decoder_block.*attention_output.*kernel"] = (
-            None,
-            None,
             model_dim,
+            None,
+            data_dim,
         )
-        layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
-        layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
+        layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim)
+        layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim)
 
         return layout_map
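
For context, a minimal usage sketch of the updated signature, assuming the Keras 3 distribution API available around this release; the 2x4 mesh shape and the gemma_2b_en preset are illustrative assumptions rather than part of this commit:

import keras
import keras_nlp

# Assumed setup: 8 accelerators arranged as a 2x4 mesh whose axis names
# match the new defaults ("batch" = data parallel, "model" = model parallel).
device_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(
    device_mesh,
    model_parallel_dim_name="model",
    data_parallel_dim_name="batch",
)
distribution = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)
# Weights created after this point follow the sharding spec above, e.g. the
# token embeddings are laid out ("model", "batch") across the mesh.
model = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en")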
24 changes: 16 additions & 8 deletions keras_nlp/models/gemma/gemma_backbone_test.py
@@ -106,26 +106,34 @@ def test_distribution(self):
 
         for w in model.weights:
             if "token_embedding/embeddings" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("model", "batch")
+                )
             if "attention/query/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/key/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/value/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/attention_output/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, None, "model")
+                    tuple(w.value.sharding.spec), ("model", None, "batch")
                 )
             if "ffw_gating/kernel" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("batch", "model")
+                )
             if "ffw_gating_2/kernel" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("batch", "model")
+                )
             if "ffw_linearl" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("model", "batch")
+                )
