Merge pull request #757 from google:moe_quantization
PiperOrigin-RevId: 651070063
maxtext authors committed Jul 10, 2024
2 parents b506266 + 6a0e570, commit 704ab1c
Showing 2 changed files with 5 additions and 0 deletions.
MaxText/layers/linears.py (3 additions, 0 deletions)
@@ -278,6 +278,7 @@ class MoeBlock(nn.Module):
     kernel_axes: Tuple with axes to apply kernel function.
     weight_dtype: Type for the weights.
     dtype: Type for the dense layer.
+    quant: Optional quantization config, no quantization if None.
   """

   config: Config
@@ -288,6 +289,7 @@ class MoeBlock(nn.Module):
   kernel_axes: Tuple[str, ...]
   weight_dtype: DType = jnp.float32
   dtype: DType = jnp.float32
+  quant: Optional[Quant] = None

   def generate_kernels(self, num_experts, emb_dim, mlp_dim):

@@ -411,6 +413,7 @@ def __call__(self, inputs):
         self.num_experts,
         dtype=self.dtype,
         weight_dtype=self.weight_dtype,
+        quant=self.quant,
         kernel_init=self.kernel_init,
         kernel_axes=self.kernel_axes,
         name="gate")(inputs)
MaxText/layers/mistral.py (2 additions, 0 deletions)
@@ -132,6 +132,7 @@ def __call__(
         kernel_axes=('embed', 'mlp'),
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
+        quant=self.quant,
     )(hidden_states)
     mlp_lnx = nn.with_logical_constraint(
         mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed')
@@ -145,6 +146,7 @@ def __call__(
         weight_dtype=cfg.weight_dtype,
         name="mlp",
         config=cfg,
+        quant=self.quant,
     )(hidden_states, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))

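For context, all five added lines apply one pattern: an optional quantization config is declared as a Flax module attribute (defaulting to None, so behavior is unchanged when it is unset) and explicitly forwarded to each submodule, since Flax attributes do not propagate to nested modules on their own. The sketch below is a minimal, self-contained illustration of that pattern, not MaxText's actual classes: Quant here is a plain stand-in for MaxText's AQT-based quantization type, and the class names are hypothetical.

# Minimal sketch of the pattern in this commit: thread an optional
# quantization config through nested Flax modules. All names below are
# illustrative; `Quant` stands in for MaxText's real quantization type.
from typing import Any, Optional

import flax.linen as nn
import jax
import jax.numpy as jnp

Quant = Any  # stand-in for the real quantization config type


class Gate(nn.Module):
  """Router projection; a real version would quantize its matmul."""
  num_experts: int
  quant: Optional[Quant] = None  # no quantization if None

  @nn.compact
  def __call__(self, x):
    # A real implementation would swap in a quantized dot_general when
    # self.quant is set; a plain Dense keeps this sketch runnable.
    return nn.Dense(self.num_experts, name="gate")(x)


class MoeBlockSketch(nn.Module):
  num_experts: int
  quant: Optional[Quant] = None  # the attribute this commit adds

  @nn.compact
  def __call__(self, x):
    # The substance of the change: forward self.quant to the submodule
    # instead of silently dropping it.
    return Gate(self.num_experts, quant=self.quant)(x)


x = jnp.ones((2, 16))
logits, _ = MoeBlockSketch(num_experts=8).init_with_output(
    jax.random.PRNGKey(0), x)
print(logits.shape)  # (2, 8)

Because quant defaults to None in both hunks, existing call sites that never pass a quantization config behave exactly as before; the change is purely additive.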
