Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf megablox #694

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ head_dim: 128
num_experts: 1
num_experts_per_tok: 1
moe_matmul: False
megablox: False
mlp_activations: ["silu", "linear"]
dropout_rate: 0
logits_via_embedding: False
Expand Down Expand Up @@ -300,4 +301,4 @@ enable_checkpoint_cloud_logger: False
enable_checkpoint_standard_logger: False

# Single-controller
enable_single_controller: False
enable_single_controller: False
29 changes: 29 additions & 0 deletions MaxText/configs/models/mixtral-8x7b-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for mixtral-8x7b

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_mlp_dim: 14336
base_num_decoder_layers: 32
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
num_experts_per_tok: 2
decoder_block: "mistral"
30 changes: 30 additions & 0 deletions MaxText/configs/models/mixtral-moe-1t.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for mixtral-8x7b

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_mlp_dim: 14336
base_num_decoder_layers: 32
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
num_experts: 12
num_experts_per_tok: 2
decoder_block: "mistral"
30 changes: 30 additions & 0 deletions MaxText/configs/models/mixtral-moe-single.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for mixtral-8x7b

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_mlp_dim: 14336
base_num_decoder_layers: 32
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
num_experts: 1
num_experts_per_tok: 1
decoder_block: "mistral"
30 changes: 30 additions & 0 deletions MaxText/configs/models/mixtral-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for mixtral-8x7b

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_mlp_dim: 14336
base_num_decoder_layers: 1
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
num_experts: 8
num_experts_per_tok: 2
decoder_block: "mistral"
9 changes: 7 additions & 2 deletions MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,8 +888,8 @@ class Attention(nn.Module):
max_target_length: int
mesh: Mesh
attention_kernel: str
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
dtype: DType = jnp.bfloat16
weight_dtype: DType = jnp.bfloat16
max_prefill_predict_length: int = -1
dropout_rate: float = 0.0
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
Expand Down Expand Up @@ -920,6 +920,7 @@ def query_init(*args):
# pylint: disable=no-value-for-parameter
return self.kernel_init(*args) / depth_scaling

print("query_projection")
query_proj = DenseGeneral(
features=(self.num_query_heads, self.head_dim),
axis=-1,
Expand Down Expand Up @@ -949,6 +950,7 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
if self.num_query_heads % self.num_kv_heads != 0:
raise ValueError("Invalid num_kv_heads for GQA.")

print("kv_proj")
kv_proj = DenseGeneral(
features=(self.num_kv_heads, self.head_dim),
axis=-1,
Expand Down Expand Up @@ -1038,8 +1040,11 @@ def __call__(
value = self.kv_projection(inputs_kv, proj_name="value")

# apply ROPE
# breakpoint()
query = RotaryEmbedding(embedding_dims=self.head_dim, name="query_rotary")(inputs=query, position=inputs_positions)
# breakpoint()
key = self.key_rotary(key, inputs_positions)
# breakpoint()

# annotate with sharding constraint.
query = nn.with_logical_constraint(query, self.query_axis_names)
Expand Down
2 changes: 2 additions & 0 deletions MaxText/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def __call__(self, inputs: Array) -> Array:
one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
else:
# breakpoint()
output = jnp.asarray(self.embedding, self.dtype)[inputs]
# breakpoint()
output = nn.with_logical_constraint(output, ("activation_batch", "activation_length", "activation_embed"))
return output

Expand Down
Loading
Loading