add Gemma2 support to MaxText #814

Merged
merged 1 commit on Aug 10, 2024
9 changes: 6 additions & 3 deletions MaxText/configs/base.yml
@@ -136,12 +136,15 @@ scan_layers: True
param_scan_axis: 1

# The attention parameter dictates the specific algorithm/methodology used to compute the attention scores
attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te

# The attention_type parameter determines the variants of attention, e.g. global or local_sliding
attention_type: 'global' # Supported attention_type: global, local_sliding
sliding_window_size: 0
attn_logits_soft_cap: 0.0
final_logits_soft_cap: 0.0
use_post_attn_norm: False
use_post_ffw_norm: False


# Combine matmuls for QKV and MLP
fused_qkv: False
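The two new soft-cap parameters bound the attention scores and the final logits by squashing them through a tanh, which is how Gemma 2 keeps its logits in a fixed range; a value of 0.0 leaves capping disabled. A minimal JAX sketch of the transform (the soft_cap helper is illustrative, not MaxText's internal API):

import jax.numpy as jnp

def soft_cap(logits, cap):
  # Squash values into (-cap, cap); the Gemma 2 configs below use 50.0 for attention
  # logits and 30.0 for final logits, while cap = 0.0 in base.yml means "no capping".
  return cap * jnp.tanh(logits / cap)

capped = soft_cap(jnp.array([10.0, 80.0, -120.0]), 50.0)  # roughly [9.9, 46.1, -49.2]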
33 changes: 33 additions & 0 deletions MaxText/configs/models/gemma2-27b.yml
@@ -0,0 +1,33 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for gemma2-27B

base_emb_dim: 4608
base_num_query_heads: 32
base_num_kv_heads: 16
base_mlp_dim: 36864
base_num_decoder_layers: 23 # half of the real number of layers because we merge [local_attention, global_attention] into one layer
head_dim: 128
mlp_activations: ["gelu","linear"]
vocab_size: 256128
decoder_block: "gemma2"
normalization_layer_epsilon: 1.e-06
logits_via_embedding: True
attention: "dot_product"
final_logits_soft_cap: 30.0
attn_logits_soft_cap: 50.0
sliding_window_size: 4096
use_post_attn_norm: True
use_post_ffw_norm: True
33 changes: 33 additions & 0 deletions MaxText/configs/models/gemma2-2b.yml
@@ -0,0 +1,33 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for gemma2-2B

base_emb_dim: 2304
base_num_query_heads: 8
base_num_kv_heads: 4
base_mlp_dim: 9216
base_num_decoder_layers: 13 # half of the real number of layers because we merge [local_attention, global_attention] into one layer
head_dim: 256
mlp_activations: ["gelu","linear"]
vocab_size: 256128
decoder_block: "gemma2"
normalization_layer_epsilon: 1.e-06
logits_via_embedding: True
attention: "dot_product"
final_logits_soft_cap: 30.0
attn_logits_soft_cap: 50.0
sliding_window_size: 4096
use_post_attn_norm: True
use_post_ffw_norm: True
13 changes: 7 additions & 6 deletions MaxText/configs/models/gemma2-9b.yml
@@ -1,4 +1,4 @@
# Copyright 2023 Google LLC
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,15 +18,16 @@ base_emb_dim: 3584
base_num_query_heads: 16
base_num_kv_heads: 8
base_mlp_dim: 14336
base_num_decoder_layers: 42
base_num_decoder_layers: 21 # half of the real number of layers because we merge [local_attention, global_attention] into one layer
head_dim: 256
mlp_activations: ["gelu","linear"]
vocab_size: 256128
decoder_block: "gemma"
decoder_block: "gemma2"
normalization_layer_epsilon: 1.e-06
logits_via_embedding: True
# final_logit_softcap: 30.0
attention: "dot_product"
final_logits_soft_cap: 30.0
attn_logits_soft_cap: 50.0
sliding_window_size: 4096
# use_post_attn_norm: True
# use_post_ffw_norm: True
use_post_attn_norm: True
use_post_ffw_norm: True
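All three Gemma 2 configs halve the published layer counts (26 -> 13 for 2b, 42 -> 21 for 9b, 46 -> 23 for 27b) because each scanned "gemma2" decoder layer wraps one local sliding-window attention block followed by one global attention block. A conceptual sketch of that pairing (merged_gemma2_layer is illustrative, not MaxText's DecoderLayer API):

def merged_gemma2_layer(x, local_block, global_block):
  # One MaxText "gemma2" layer = a local sub-layer followed by a global sub-layer.
  x = local_block(x)   # attention limited to sliding_window_size tokens
  x = global_block(x)  # full-context attention
  return x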
268 changes: 268 additions & 0 deletions MaxText/convert_gemma2_chkpt.py
@@ -0,0 +1,268 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# pylint: disable=line-too-long
"""
Convert an Orbax Gemma 2 checkpoint to a MaxText-compatible checkpoint.
"""

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_platform_name", "cpu")
import argparse
import copy
from flax.training import train_state

from typing import Any
import sys
import max_logging


import orbax.checkpoint

import checkpointing
from train import save_checkpoint

Params = dict[str, Any]

def nest_params(params: Params) -> Params:
"""Nests params as a dict of dicts rather than a flat dict."""
nested_params = {}
for path, param in params.items():
*path, leaf = path.split("/")
subdict = nested_params
for key in path:
subdict = subdict.setdefault(key, {})
subdict[leaf] = param
return nested_params
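# For illustration, with keys shaped like the Orbax layout restored below (values are hypothetical):
#   nest_params({"transformer/layer_0/mlp/linear/w": w, "transformer/final_norm/scale": s})
#   == {"transformer": {"layer_0": {"mlp": {"linear": {"w": w}}}, "final_norm": {"scale": s}}}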


def main(raw_args=None) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_path", type=str, required=True)
parser.add_argument("--maxtext_model_path", type=str, required=True)
parser.add_argument("--model_size", type=str, required=True)
args = parser.parse_args(raw_args)
if args.model_size not in ("2b", "9b", "27b"):
raise NotImplementedError("only implemented for gemma 2 classes")

print("Loading checkpoint")
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
params = checkpointer.restore(args.base_model_path)
params = nest_params(params)
num_layers = max((int(k.split("_")[1]) for k in params["transformer"].keys() if "layer_" in k)) + 1
hidden_dim, embed_dim = params["transformer"]["layer_0"]["mlp"]["linear"]["w"].shape
num_heads, head_dim, _ = params["transformer"]["layer_0"]["attn"]["attn_vec_einsum"]["w"].shape
print("Model configurations from checkpoint")
print(f"num_layers: {num_layers}")
print(f"hidden_dim: {hidden_dim}")
print(f"embed_dim: {embed_dim}")
print(f"num_heads: {num_heads}")
print(f"head_dim: {head_dim}")

query_pre_attn_scalar = None
if args.model_size in ("2b", "9b"):
query_pre_attn_scalar = head_dim**-0.5
elif args.model_size in ("27b"):
query_pre_attn_scalar = (embed_dim // num_heads)**-0.5
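  # Gemma 2 folds the query scaling into the converted weights: 2b/9b scale queries by
  # head_dim**-0.5, while 27b divides by sqrt(embed_dim // num_heads) (144 instead of 128).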

transpose_gating_einsum = True
if args.model_size in ("2b"):
transpose_gating_einsum = False
elif args.model_size in ("9b", "27b"):
transpose_gating_einsum = True

jax_weights = {
"decoder": {
"decoder_norm": {"scale": params["transformer"]["final_norm"]["scale"] + 1},
},
"token_embedder": {"embedding": params["transformer"]["embedder"]["input_embedding"] * jnp.sqrt(embed_dim)},
}
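  # Gemma checkpoints store RMSNorm scales as offsets from 1 (hence the "+ 1") and expect token
  # embeddings to be scaled by sqrt(embed_dim); both adjustments are baked in during conversion.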
self_attention_local = dict({
"query": {"kernel": []},
"key": {"kernel": []},
"value": {"kernel": []},
"out": {"kernel": []},
})
self_attention_global = dict({
"query": {"kernel": []},
"key": {"kernel": []},
"value": {"kernel": []},
"out": {"kernel": []},
})

layer_weight = dict({
"mlp_local": {
"wi_0": {"kernel": []},
"wi_1": {"kernel": []},
"wo": {"kernel": []},
},
"mlp_global": {
"wi_0": {"kernel": []},
"wi_1": {"kernel": []},
"wo": {"kernel": []},
},
"pre_self_attention_norm_local": {"scale": []},
"pre_ffw_norm_local": {"scale": []},
"post_self_attention_norm_local": {"scale": []},
"post_ffw_norm_local": {"scale": []},
"pre_self_attention_norm_global": {"scale": []},
"pre_ffw_norm_global": {"scale": []},
"post_self_attention_norm_global": {"scale": []},
"post_ffw_norm_global": {"scale": []},
})
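  # The loop below stacks Gemma 2's interleaved layers: even-numbered Orbax layers use local
  # sliding-window attention, odd-numbered layers use global attention, and each pair becomes
  # a single scanned MaxText layer.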


for layer_idx in range(0, num_layers, 2):
in_layer_name_local = "layer_" + str(layer_idx)
in_layer_name_global = "layer_" + str(layer_idx+1)

######################## layer local attention ########################
self_attention_local["query"]["kernel"].append(
params["transformer"][in_layer_name_local]["attn"]["q_einsum"]["w"].transpose((1, 0, 2)) * query_pre_attn_scalar
)
self_attention_local["key"]["kernel"].append(
params["transformer"][in_layer_name_local]["attn"]["kv_einsum"]["w"][0].transpose((1, 0, 2))
)
self_attention_local["value"]["kernel"].append(
params["transformer"][in_layer_name_local]["attn"]["kv_einsum"]["w"][1].transpose((1, 0, 2))
)
self_attention_local["out"]["kernel"].append(params["transformer"][in_layer_name_local]["attn"]["attn_vec_einsum"]["w"])

# mlp
if transpose_gating_einsum:
layer_weight["mlp_local"]["wi_0"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][0]))
layer_weight["mlp_local"]["wi_1"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][1]))
else:
layer_weight["mlp_local"]["wi_0"]["kernel"].append(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][0])
layer_weight["mlp_local"]["wi_1"]["kernel"].append(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][1])

layer_weight["mlp_local"]["wo"]["kernel"].append(params["transformer"][in_layer_name_local]["mlp"]["linear"]["w"])

layer_weight["pre_self_attention_norm_local"]["scale"].append(
params["transformer"][in_layer_name_local]["pre_attention_norm"]["scale"] + 1
)
layer_weight["pre_ffw_norm_local"]["scale"].append(params["transformer"][in_layer_name_local]["pre_ffw_norm"]["scale"] + 1)

layer_weight["post_self_attention_norm_local"]["scale"].append(
params["transformer"][in_layer_name_local]["post_attention_norm"]["scale"] + 1
)
layer_weight["post_ffw_norm_local"]["scale"].append(params["transformer"][in_layer_name_local]["post_ffw_norm"]["scale"] + 1)

######################## layer global attention ########################

self_attention_global["query"]["kernel"].append(
params["transformer"][in_layer_name_global]["attn"]["q_einsum"]["w"].transpose((1, 0, 2)) * query_pre_attn_scalar
)
self_attention_global["key"]["kernel"].append(
params["transformer"][in_layer_name_global]["attn"]["kv_einsum"]["w"][0].transpose((1, 0, 2))
)
self_attention_global["value"]["kernel"].append(
params["transformer"][in_layer_name_global]["attn"]["kv_einsum"]["w"][1].transpose((1, 0, 2))
)
self_attention_global["out"]["kernel"].append(params["transformer"][in_layer_name_global]["attn"]["attn_vec_einsum"]["w"])

# mlp
if transpose_gating_einsum:
layer_weight["mlp_global"]["wi_0"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][0]))
layer_weight["mlp_global"]["wi_1"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][1]))
else:
layer_weight["mlp_global"]["wi_0"]["kernel"].append(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][0])
layer_weight["mlp_global"]["wi_1"]["kernel"].append(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][1])

layer_weight["mlp_global"]["wo"]["kernel"].append(params["transformer"][in_layer_name_global]["mlp"]["linear"]["w"])

layer_weight["pre_self_attention_norm_global"]["scale"].append(
params["transformer"][in_layer_name_global]["pre_attention_norm"]["scale"] + 1
)
layer_weight["pre_ffw_norm_global"]["scale"].append(params["transformer"][in_layer_name_global]["pre_ffw_norm"]["scale"] + 1)

layer_weight["post_self_attention_norm_global"]["scale"].append(
params["transformer"][in_layer_name_global]["post_attention_norm"]["scale"] + 1
)
layer_weight["post_ffw_norm_global"]["scale"].append(params["transformer"][in_layer_name_global]["post_ffw_norm"]["scale"] + 1)

self_attention_local["query"]["kernel"] = np.array(self_attention_local["query"]["kernel"]).transpose((1, 0, 2, 3))
self_attention_local["key"]["kernel"] = np.array(self_attention_local["key"]["kernel"]).transpose((1, 0, 2, 3))
self_attention_local["value"]["kernel"] = np.array(self_attention_local["value"]["kernel"]).transpose((1, 0, 2, 3))
self_attention_local["out"]["kernel"] = np.array(self_attention_local["out"]["kernel"]).transpose((1, 0, 2, 3))

self_attention_global["query"]["kernel"] = np.array(self_attention_global["query"]["kernel"]).transpose((1, 0, 2, 3))
self_attention_global["key"]["kernel"] = np.array(self_attention_global["key"]["kernel"]).transpose((1, 0, 2, 3))
self_attention_global["value"]["kernel"] = np.array(self_attention_global["value"]["kernel"]).transpose((1, 0, 2, 3))
self_attention_global["out"]["kernel"] = np.array(self_attention_global["out"]["kernel"]).transpose((1, 0, 2, 3))

layer_weight["mlp_local"]["wi_0"]["kernel"] = np.array(layer_weight["mlp_local"]["wi_0"]["kernel"]).transpose((1, 0, 2))
layer_weight["mlp_local"]["wi_1"]["kernel"] = np.array(layer_weight["mlp_local"]["wi_1"]["kernel"]).transpose((1, 0, 2))
layer_weight["mlp_local"]["wo"]["kernel"] = np.array(layer_weight["mlp_local"]["wo"]["kernel"]).transpose((1, 0, 2))

layer_weight["mlp_global"]["wi_0"]["kernel"] = np.array(layer_weight["mlp_global"]["wi_0"]["kernel"]).transpose((1, 0, 2))
layer_weight["mlp_global"]["wi_1"]["kernel"] = np.array(layer_weight["mlp_global"]["wi_1"]["kernel"]).transpose((1, 0, 2))
layer_weight["mlp_global"]["wo"]["kernel"] = np.array(layer_weight["mlp_global"]["wo"]["kernel"]).transpose((1, 0, 2))

layer_weight["pre_self_attention_norm_local"]["scale"] = np.array(layer_weight["pre_self_attention_norm_local"]["scale"]).transpose(
(1, 0)
)
layer_weight["pre_ffw_norm_local"]["scale"] = np.array(layer_weight["pre_ffw_norm_local"]["scale"]).transpose((1, 0))
layer_weight["post_self_attention_norm_local"]["scale"] = np.array(layer_weight["post_self_attention_norm_local"]["scale"]).transpose(
(1, 0)
)
layer_weight["post_ffw_norm_local"]["scale"] = np.array(layer_weight["post_ffw_norm_local"]["scale"]).transpose((1, 0))

layer_weight["pre_self_attention_norm_global"]["scale"] = np.array(layer_weight["pre_self_attention_norm_global"]["scale"]).transpose(
(1, 0)
)
layer_weight["pre_ffw_norm_global"]["scale"] = np.array(layer_weight["pre_ffw_norm_global"]["scale"]).transpose((1, 0))
layer_weight["post_self_attention_norm_global"]["scale"] = np.array(layer_weight["post_self_attention_norm_global"]["scale"]).transpose(
(1, 0)
)
layer_weight["post_ffw_norm_global"]["scale"] = np.array(layer_weight["post_ffw_norm_global"]["scale"]).transpose((1, 0))

layer_weight["self_attention_local"] = copy.deepcopy(self_attention_local)
layer_weight["self_attention_global"] = copy.deepcopy(self_attention_global)

jax_weights["decoder"]["layers"] = copy.deepcopy(layer_weight)
jax_weights = jax.tree_util.tree_map(jnp.array, jax_weights)

def astype_fn(x):
if isinstance(x, jnp.ndarray):
return x.astype(jnp.bfloat16)
else:
return x

jax_weights = jax.tree_util.tree_map(astype_fn, jax_weights)

enable_checkpointing = True
async_checkpointing = False
save_interval_steps = 1

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
args.maxtext_model_path, enable_checkpointing, async_checkpointing, save_interval_steps
)

state_new = train_state.TrainState(
step=0, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore
)

if checkpoint_manager is not None:
if save_checkpoint(checkpoint_manager, 0, state_new):
max_logging.log("saved a checkpoint at step 0")
# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(0):
checkpoint_manager.wait_until_finished()
sys.exit()


if __name__ == "__main__":
main()
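For reference, the converter is normally run as python3 MaxText/convert_gemma2_chkpt.py with the three flags defined above; the checkpoint paths here are placeholders. Because main() accepts raw_args, the same conversion can also be driven from Python:

main([
    "--base_model_path", "gs://<bucket>/gemma2-9b/orbax",
    "--maxtext_model_path", "gs://<bucket>/gemma2-9b/maxtext",
    "--model_size", "9b",
])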
2 changes: 1 addition & 1 deletion MaxText/layers/attentions.py
@@ -187,7 +187,7 @@ def generate_attention_mask(self, query, key, decoder_segment_ids: Array | None,
all_ones, -1 * self.sliding_window_size + 1
) * jnp.tril(all_ones, self.sliding_window_size - 1)
output_mask = sliding_mask * output_mask

return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) if output_mask is not None else None

def apply_attention(self, query: Array, key: Array| KVTensor, value: Array| KVTensor, decoder_segment_ids: Array | None, model_mode: str):
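The sliding_mask shown above keeps only key positions within sliding_window_size tokens of each query before the causal mask is applied. A standalone sketch of the same band construction (sliding_window_band is a simplified stand-in for generate_attention_mask):

import jax.numpy as jnp

def sliding_window_band(seq_len, window):
  # 1 where |query_pos - key_pos| < window, 0 elsewhere; causality is applied separately.
  all_ones = jnp.ones((seq_len, seq_len))
  return jnp.triu(all_ones, -window + 1) * jnp.tril(all_ones, window - 1)

band = sliding_window_band(5, 2)  # keeps a band of width 3 around the main diagonal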