
Commit

Merge branch 'main' into parambole/stable-stack
parambole committed Sep 19, 2024
2 parents caa4be4 + 48cb7b0 commit f605bc6
Showing 71 changed files with 3,045 additions and 2,729 deletions.
2 changes: 1 addition & 1 deletion MaxText/accelerator_to_spec_map.py
@@ -155,7 +155,7 @@ class SystemCharacteristics:
# across hosts will occur over DCN. This makes the "slice" topology of A3 fixed to a single host.
# To use AoT compilation with multihost, the `compile_topology_num_slices` flag should be
# specified to the number of hosts.
"a3": SystemCharacteristics("gpu", None, None, None, 8, None)
"a3": SystemCharacteristics("gpu", None, None, None, 8, None),
}


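To make the A3 comment above concrete, here is a minimal, self-contained sketch of what the "a3" row encodes for multihost AOT compilation. The field names in this dataclass are assumptions for illustration only; the real definition lives in MaxText/accelerator_to_spec_map.py, and only the positional values mirror the diff.

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class SystemCharacteristics:
  # Field names are assumed for this sketch; see the real dataclass in the repo.
  platform: str
  topology_name: Optional[str]
  chip_config_name: Optional[str]
  chips_per_host_bounds: Optional[Tuple[int, ...]]
  devices_per_slice: int
  wrap: Optional[bool]

# "a3" describes a single-host GPU "slice" of 8 devices; traffic between hosts
# goes over DCN, so an AOT compile for N hosts sets compile_topology_num_slices=N.
A3 = SystemCharacteristics("gpu", None, None, None, 8, None)

num_hosts = 4  # hypothetical multihost job
total_devices = A3.devices_per_slice * num_hosts
print(total_devices)  # 32 devices compiled for across 4 A3 hosts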
57 changes: 15 additions & 42 deletions MaxText/checkpointing.py
@@ -32,9 +32,7 @@
CheckpointManagerOptions = ocp.CheckpointManagerOptions
PyTreeCheckpointHandler = ocp.PyTreeCheckpointHandler
LocalCheckpointOptions = emergency_checkpoint_manager.LocalCheckpointOptions
PersistentCheckpointOptions = (
emergency_checkpoint_manager.PersistentCheckpointOptions
)
PersistentCheckpointOptions = emergency_checkpoint_manager.PersistentCheckpointOptions

abstract_logger = ocp.logging.abstract_logger
cloud_logger = ocp.logging.cloud_logger
@@ -76,7 +74,7 @@ def create_orbax_checkpoint_manager(
save_interval_steps=save_interval_steps,
enable_async_checkpointing=use_async,
),
logger=orbax_logger
logger=orbax_logger,
)
max_logging.log("Checkpoint manager created!")
return mngr
@@ -96,12 +94,8 @@ def create_orbax_emergency_checkpoint_manager(
max_logging.log("Creating emergency checkpoint manager...")

options = emergency_checkpoint_manager.CheckpointManagerOptions(
local=LocalCheckpointOptions(
save_interval_steps=local_save_interval_steps
),
persistent=PersistentCheckpointOptions(
save_interval_steps=persistent_save_interval_steps
),
local=LocalCheckpointOptions(save_interval_steps=local_save_interval_steps),
persistent=PersistentCheckpointOptions(save_interval_steps=persistent_save_interval_steps),
)
emergency_mngr = emergency_checkpoint_manager.CheckpointManager(
local_checkpoint_dir,
@@ -191,16 +185,13 @@ def map_to_pspec(data):
replica_axis_index = 0
replica_devices = _replica_devices(mesh.devices, replica_axis_index)
replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
single_replica_sharding = jax.sharding.NamedSharding(
replica_mesh, pspec)
single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)

array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
replica_axis_index=0,
broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit
)
ocp.type_handlers.register_type_handler(
jax.Array, array_handler, override=True
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
sharding=jax.sharding.NamedSharding(mesh, pspec),
@@ -218,9 +209,7 @@ def map_to_pspec(data):
return (
checkpoint_manager.restore(
latest_step,
args=ocp.args.PyTreeRestore(
item=abstract_unboxed_pre_state, restore_args=restore_args
),
args=ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args),
),
None,
)
@@ -234,9 +223,7 @@ def map_to_pspec(data):
item=abstract_unboxed_pre_state,
restore_args=restore_args,
),
iter=grain.PyGrainCheckpointRestore(
data_iterator.local_iterator
),
iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator),
),
),
None,
@@ -282,9 +269,7 @@ def setup_checkpoint_logger(config) -> composite_logger.CompositeLogger | None:
max_logging.log("Setting up checkpoint logger...")
if config.enable_checkpoint_cloud_logger:
logger_name = f"checkpoint_{config.run_name}"
options = cloud_logger.CloudLoggerOptions(
job_name=config.run_name, logger_name=logger_name
)
options = cloud_logger.CloudLoggerOptions(job_name=config.run_name, logger_name=logger_name)
orbax_cloud_logger = cloud_logger.CloudLogger(options=options)
max_logging.log("Successfully set up checkpoint cloud logger.")

@@ -294,9 +279,7 @@ def setup_checkpoint_logger(config) -> composite_logger.CompositeLogger | None:

orbax_logger = None
if orbax_cloud_logger is not None and orbax_standard_logger is not None:
orbax_logger = composite_logger.CompositeLogger(
orbax_cloud_logger, orbax_standard_logger
)
orbax_logger = composite_logger.CompositeLogger(orbax_cloud_logger, orbax_standard_logger)
max_logging.log("Successfully set up checkpoint composite logger.")

return orbax_logger
@@ -312,27 +295,17 @@ def load_params_from_path(load_parameters_from_path, abstract_unboxed_params):
# Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
# memory, we instead specify here that we are just restoring the params field of the checkpoint
# (which itself may be a dictionary containing a key named 'params').
restore_args = ocp.checkpoint_utils.construct_restore_args(
abstract_unboxed_params
)
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_unboxed_params)
restored = ckptr.restore(
ckpt,
item={"params": abstract_unboxed_params},
transforms={},
restore_args={"params": restore_args}
)
ckpt, item={"params": abstract_unboxed_params}, transforms={}, restore_args={"params": restore_args}
)
return restored["params"]


def save_params_to_path(checkpoint_dir, params):
"""Save decode params in checkpoint at specified path."""
assert checkpoint_dir, "checkpoint_dir is not defined."
orbax_checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target({"params":params})
orbax_checkpointer.save(
checkpoint_dir,
{"params":params},
save_args=save_args,
force=True
)
save_args = orbax_utils.save_args_from_target({"params": params})
orbax_checkpointer.save(checkpoint_dir, {"params": params}, save_args=save_args, force=True)
print(f"Quantized params checkpoint saved at: {checkpoint_dir}")
157 changes: 95 additions & 62 deletions MaxText/convert_gemma2_chkpt.py
@@ -10,6 +10,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

# pylint: disable=line-too-long
"""
Convert orbax Gemma checkpoint to MaxText compatible checkpoint.
@@ -36,6 +37,7 @@

Params = dict[str, Any]


def nest_params(params: Params) -> Params:
"""Nests params as a dict of dicts rather than a flat dict."""
nested_params = {}
@@ -75,7 +77,7 @@ def main(raw_args=None) -> None:
if args.model_size in ("2b", "9b"):
query_pre_attn_scalar = head_dim**-0.5
elif args.model_size in ("27b"):
query_pre_attn_scalar = (embed_dim // num_heads)**-0.5
query_pre_attn_scalar = (embed_dim // num_heads) ** -0.5

transpose_gating_einsum = True
if args.model_size in ("2b"):
@@ -89,44 +91,49 @@ def main(raw_args=None) -> None:
},
"token_embedder": {"embedding": params["transformer"]["embedder"]["input_embedding"] * jnp.sqrt(embed_dim)},
}
self_attention_local = dict({
"query": {"kernel": []},
"key": {"kernel": []},
"value": {"kernel": []},
"out": {"kernel": []},
})
self_attention_global = dict({
"query": {"kernel": []},
"key": {"kernel": []},
"value": {"kernel": []},
"out": {"kernel": []},
})

layer_weight = dict({
"mlp_local": {
"wi_0": {"kernel": []},
"wi_1": {"kernel": []},
"wo": {"kernel": []},
},
"mlp_global": {
"wi_0": {"kernel": []},
"wi_1": {"kernel": []},
"wo": {"kernel": []},
},
"pre_self_attention_norm_local": {"scale": []},
"pre_ffw_norm_local": {"scale": []},
"post_self_attention_norm_local": {"scale": []},
"post_ffw_norm_local": {"scale": []},
"pre_self_attention_norm_global": {"scale": []},
"pre_ffw_norm_global": {"scale": []},
"post_self_attention_norm_global": {"scale": []},
"post_ffw_norm_global": {"scale": []},
})
self_attention_local = dict(
{
"query": {"kernel": []},
"key": {"kernel": []},
"value": {"kernel": []},
"out": {"kernel": []},
}
)
self_attention_global = dict(
{
"query": {"kernel": []},
"key": {"kernel": []},
"value": {"kernel": []},
"out": {"kernel": []},
}
)

layer_weight = dict(
{
"mlp_local": {
"wi_0": {"kernel": []},
"wi_1": {"kernel": []},
"wo": {"kernel": []},
},
"mlp_global": {
"wi_0": {"kernel": []},
"wi_1": {"kernel": []},
"wo": {"kernel": []},
},
"pre_self_attention_norm_local": {"scale": []},
"pre_ffw_norm_local": {"scale": []},
"post_self_attention_norm_local": {"scale": []},
"post_ffw_norm_local": {"scale": []},
"pre_self_attention_norm_global": {"scale": []},
"pre_ffw_norm_global": {"scale": []},
"post_self_attention_norm_global": {"scale": []},
"post_ffw_norm_global": {"scale": []},
}
)

for layer_idx in range(0, num_layers, 2):
in_layer_name_local = "layer_" + str(layer_idx)
in_layer_name_global = "layer_" + str(layer_idx+1)
in_layer_name_global = "layer_" + str(layer_idx + 1)

######################## layer local attention ########################
self_attention_local["query"]["kernel"].append(
@@ -142,23 +149,35 @@ def main(raw_args=None) -> None:

# mlp
if transpose_gating_einsum:
layer_weight["mlp_local"]["wi_0"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][0]))
layer_weight["mlp_local"]["wi_1"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][1]))
layer_weight["mlp_local"]["wi_0"]["kernel"].append(
np.transpose(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][0])
)
layer_weight["mlp_local"]["wi_1"]["kernel"].append(
np.transpose(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][1])
)
else:
layer_weight["mlp_local"]["wi_0"]["kernel"].append(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][0])
layer_weight["mlp_local"]["wi_1"]["kernel"].append(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][1])
layer_weight["mlp_local"]["wi_0"]["kernel"].append(
params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][0]
)
layer_weight["mlp_local"]["wi_1"]["kernel"].append(
params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][1]
)

layer_weight["mlp_local"]["wo"]["kernel"].append(params["transformer"][in_layer_name_local]["mlp"]["linear"]["w"])

layer_weight["pre_self_attention_norm_local"]["scale"].append(
params["transformer"][in_layer_name_local]["pre_attention_norm"]["scale"] + 1
)
layer_weight["pre_ffw_norm_local"]["scale"].append(params["transformer"][in_layer_name_local]["pre_ffw_norm"]["scale"] + 1)
layer_weight["pre_ffw_norm_local"]["scale"].append(
params["transformer"][in_layer_name_local]["pre_ffw_norm"]["scale"] + 1
)

layer_weight["post_self_attention_norm_local"]["scale"].append(
params["transformer"][in_layer_name_local]["post_attention_norm"]["scale"] + 1
params["transformer"][in_layer_name_local]["post_attention_norm"]["scale"] + 1
)
layer_weight["post_ffw_norm_local"]["scale"].append(
params["transformer"][in_layer_name_local]["post_ffw_norm"]["scale"] + 1
)
layer_weight["post_ffw_norm_local"]["scale"].append(params["transformer"][in_layer_name_local]["post_ffw_norm"]["scale"] + 1)

######################## layer global attention ########################

@@ -171,27 +190,41 @@ def main(raw_args=None) -> None:
self_attention_global["value"]["kernel"].append(
params["transformer"][in_layer_name_global]["attn"]["kv_einsum"]["w"][1].transpose((1, 0, 2))
)
self_attention_global["out"]["kernel"].append(params["transformer"][in_layer_name_global]["attn"]["attn_vec_einsum"]["w"])
self_attention_global["out"]["kernel"].append(
params["transformer"][in_layer_name_global]["attn"]["attn_vec_einsum"]["w"]
)

# mlp
if transpose_gating_einsum:
layer_weight["mlp_global"]["wi_0"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][0]))
layer_weight["mlp_global"]["wi_1"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][1]))
layer_weight["mlp_global"]["wi_0"]["kernel"].append(
np.transpose(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][0])
)
layer_weight["mlp_global"]["wi_1"]["kernel"].append(
np.transpose(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][1])
)
else:
layer_weight["mlp_global"]["wi_0"]["kernel"].append(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][0])
layer_weight["mlp_global"]["wi_1"]["kernel"].append(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][1])
layer_weight["mlp_global"]["wi_0"]["kernel"].append(
params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][0]
)
layer_weight["mlp_global"]["wi_1"]["kernel"].append(
params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][1]
)

layer_weight["mlp_global"]["wo"]["kernel"].append(params["transformer"][in_layer_name_global]["mlp"]["linear"]["w"])

layer_weight["pre_self_attention_norm_global"]["scale"].append(
params["transformer"][in_layer_name_global]["pre_attention_norm"]["scale"] + 1
)
layer_weight["pre_ffw_norm_global"]["scale"].append(params["transformer"][in_layer_name_global]["pre_ffw_norm"]["scale"] + 1)
layer_weight["pre_ffw_norm_global"]["scale"].append(
params["transformer"][in_layer_name_global]["pre_ffw_norm"]["scale"] + 1
)

layer_weight["post_self_attention_norm_global"]["scale"].append(
params["transformer"][in_layer_name_global]["post_attention_norm"]["scale"] + 1
params["transformer"][in_layer_name_global]["post_attention_norm"]["scale"] + 1
)
layer_weight["post_ffw_norm_global"]["scale"].append(
params["transformer"][in_layer_name_global]["post_ffw_norm"]["scale"] + 1
)
layer_weight["post_ffw_norm_global"]["scale"].append(params["transformer"][in_layer_name_global]["post_ffw_norm"]["scale"] + 1)

self_attention_local["query"]["kernel"] = np.array(self_attention_local["query"]["kernel"]).transpose((1, 0, 2, 3))
self_attention_local["key"]["kernel"] = np.array(self_attention_local["key"]["kernel"]).transpose((1, 0, 2, 3))
@@ -211,22 +244,22 @@ def main(raw_args=None) -> None:
layer_weight["mlp_global"]["wi_1"]["kernel"] = np.array(layer_weight["mlp_global"]["wi_1"]["kernel"]).transpose((1, 0, 2))
layer_weight["mlp_global"]["wo"]["kernel"] = np.array(layer_weight["mlp_global"]["wo"]["kernel"]).transpose((1, 0, 2))

layer_weight["pre_self_attention_norm_local"]["scale"] = np.array(layer_weight["pre_self_attention_norm_local"]["scale"]).transpose(
(1, 0)
)
layer_weight["pre_self_attention_norm_local"]["scale"] = np.array(
layer_weight["pre_self_attention_norm_local"]["scale"]
).transpose((1, 0))
layer_weight["pre_ffw_norm_local"]["scale"] = np.array(layer_weight["pre_ffw_norm_local"]["scale"]).transpose((1, 0))
layer_weight["post_self_attention_norm_local"]["scale"] = np.array(layer_weight["post_self_attention_norm_local"]["scale"]).transpose(
(1, 0)
)
layer_weight["post_self_attention_norm_local"]["scale"] = np.array(
layer_weight["post_self_attention_norm_local"]["scale"]
).transpose((1, 0))
layer_weight["post_ffw_norm_local"]["scale"] = np.array(layer_weight["post_ffw_norm_local"]["scale"]).transpose((1, 0))

layer_weight["pre_self_attention_norm_global"]["scale"] = np.array(layer_weight["pre_self_attention_norm_global"]["scale"]).transpose(
(1, 0)
)
layer_weight["pre_self_attention_norm_global"]["scale"] = np.array(
layer_weight["pre_self_attention_norm_global"]["scale"]
).transpose((1, 0))
layer_weight["pre_ffw_norm_global"]["scale"] = np.array(layer_weight["pre_ffw_norm_global"]["scale"]).transpose((1, 0))
layer_weight["post_self_attention_norm_global"]["scale"] = np.array(layer_weight["post_self_attention_norm_global"]["scale"]).transpose(
(1, 0)
)
layer_weight["post_self_attention_norm_global"]["scale"] = np.array(
layer_weight["post_self_attention_norm_global"]["scale"]
).transpose((1, 0))
layer_weight["post_ffw_norm_global"]["scale"] = np.array(layer_weight["post_ffw_norm_global"]["scale"]).transpose((1, 0))

layer_weight["self_attention_local"] = copy.deepcopy(self_attention_local)
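Two details of the conversion above are worth spelling out: the loop walks Gemma 2 layers two at a time because even-numbered layers feed the *_local (sliding-window attention) collections and odd-numbered layers the *_global ones, and the query pre-attention scale depends on model size. A small sketch of that scaling rule follows; the 27B dimensions in the example call are commonly reported values and an assumption here, not taken from this diff.

def query_pre_attn_scalar(model_size: str, embed_dim: int, num_heads: int, head_dim: int) -> float:
  """Mirrors the size-dependent query scaling in convert_gemma2_chkpt.py."""
  if model_size in ("2b", "9b"):
    return head_dim**-0.5
  # 27b: scale by 1/sqrt(embed_dim / num_heads) rather than 1/sqrt(head_dim).
  return (embed_dim // num_heads) ** -0.5

# Assumed Gemma 2 27B dimensions for illustration: embed_dim=4608, num_heads=32, head_dim=128.
print(query_pre_attn_scalar("27b", embed_dim=4608, num_heads=32, head_dim=128))  # 1/12 ~= 0.0833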
(Diffs for the remaining 68 changed files are not shown.)
