Commit: prune quant params
mailvijayasingh committed Oct 27, 2024
1 parent ed2d123 commit f84db7f
Showing 3 changed files with 67 additions and 25 deletions.
21 changes: 18 additions & 3 deletions MaxText/layers/linears.py
@@ -316,7 +316,12 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim):
         kernel_in_axis,
         kernel_out_axis,
     )
-    w0_kernel = jnp.asarray(w0_kernel, self.dtype)
+    if quantizations.in_serve_mode(self.quant):
+      # During the AQT convert state we delete the kernel weights from params to save memory.
+      # They are instead retrieved from the tensors stored in the 'aqt' collection.
+      w0_kernel = jnp.zeros((num_experts, emb_dim, mlp_dim))
+    else:
+      w0_kernel = jnp.asarray(w0_kernel, self.dtype)
     w1_kernel = self.param(
         "wi_1",
         nn.with_logical_partitioning(kernel_init, self.wi_kernel_axes),
@@ -325,7 +330,12 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim):
         kernel_in_axis,
         kernel_out_axis,
     )
-    w1_kernel = jnp.asarray(w1_kernel, self.dtype)
+    if quantizations.in_serve_mode(self.quant):
+      # During the AQT convert state we delete the kernel weights from params to save memory.
+      # They are instead retrieved from the tensors stored in the 'aqt' collection.
+      w1_kernel = jnp.zeros((num_experts, emb_dim, mlp_dim))
+    else:
+      w1_kernel = jnp.asarray(w1_kernel, self.dtype)
     wo_kernel = self.param(
         "wo",
         nn.with_logical_partitioning(kernel_init, self.wo_kernel_axes),
@@ -334,7 +344,12 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim):
         kernel_in_axis,
         kernel_out_axis,
     )
-    wo_kernel = jnp.asarray(wo_kernel, self.dtype)
+    if quantizations.in_serve_mode(self.quant):
+      # During the AQT convert state we delete the kernel weights from params to save memory.
+      # They are instead retrieved from the tensors stored in the 'aqt' collection.
+      wo_kernel = jnp.zeros((num_experts, mlp_dim, emb_dim))
+    else:
+      wo_kernel = jnp.asarray(wo_kernel, self.dtype)
     return w0_kernel, w1_kernel, wo_kernel
 
   def permute(self, inputs, gate_logits):
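The same serve-mode gating appears three times above, once per kernel. A condensed sketch of the pattern as a free function (an illustrative simplification, not the commit's code — only quantizations.in_serve_mode is the actual MaxText helper):

import jax.numpy as jnp
from layers import quantizations  # MaxText module; resolvable from the repo root

def materialize_kernel(kernel_param, shape, dtype, quant):
  # Hypothetical helper for illustration. In serve mode the float kernel was
  # pruned from 'params' after AQT conversion, so a zero placeholder of the
  # right shape stands in; the quantized weights are read from the 'aqt'
  # collection instead.
  if quantizations.in_serve_mode(quant):
    return jnp.zeros(shape)
  # Training / unquantized path: cast the stored kernel to the model dtype.
  return jnp.asarray(kernel_param, dtype)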
48 changes: 30 additions & 18 deletions MaxText/layers/quantizations.py
@@ -255,31 +255,43 @@ def configure_quantization(config: Config, quant_mode_str: str = "train"):
   return None
 
 
-def _get_aqt_key_paths(aqt_vars):
-  """Generate a list of paths which have aqt state"""
-  aqt_tree_flat, _ = jax.tree_util.tree_flatten_with_path(aqt_vars)
-  aqt_key_paths = []
-  for k, _ in aqt_tree_flat:
-    pruned_keys = []
-    for d in list(k):
-      if "AqtDotGeneral" in d.key:
-        pruned_keys.append(jax.tree_util.DictKey(key="kernel"))
+def match_aqt_and_unquantized_param(aqt_params, params):
+  aqt_param_flat, aqt_tree_def = jax.tree_util.tree_flatten_with_path(
+      aqt_params, is_leaf=lambda x: isinstance(x, aqt_tensor.QTensor)
+  )
+  param_tree_flat, _ = jax.tree_util.tree_flatten_with_path(params)
+  aqt_paths = []
+  # Original paths of the unquantized params that were quantized.
+  param_paths = []
+
+  for aqt_k, _ in aqt_param_flat:
+    for index, (k, _) in enumerate(param_tree_flat):
+      path_depth = len(k)
+      # Every quantized parameter has an Aqt* module as its leaf node:
+      # AqtDotGeneral and AqtEinsum replace the original leaf, so the
+      # leaf node is ignored for path matching.
+      if k[: path_depth - 1] == aqt_k[: path_depth - 1]:
+        aqt_paths.append(aqt_k)
+        param_paths.append(k)
+        break
-      elif "AqtEinsum" in d.key:
-        continue
-      else:
-        assert "Aqt" not in d.key, f"Unexpected Aqt op {d.key} in {k}."
-        pruned_keys.append(d)
-    aqt_key_paths.append(tuple(pruned_keys))
-  return aqt_key_paths
+    # Since the matched parameter is already recorded, drop it from further matching.
+    param_tree_flat.pop(index)
+  return jax.tree_util.tree_unflatten(aqt_tree_def, param_paths)
 
 
+def _get_aqt_key_paths(aqt_vars, params):
+  """Generate a list of paths which have aqt state"""
+  aqt_to_unquantized_key_path = match_aqt_and_unquantized_param(aqt_vars, params)
+  aqt_key_paths, _ = jax.tree_util.tree_flatten(aqt_to_unquantized_key_path, is_leaf=lambda x: isinstance(x, tuple))
+  return list(aqt_key_paths)
+
+
 def remove_quantized_params(params, aqt_vars):
   """Remove param values with aqt tensors to Null to optimize memory."""
-  aqt_paths = _get_aqt_key_paths(aqt_vars)
+  quantized_param_paths = _get_aqt_key_paths(aqt_vars, params)
   tree_flat, tree_struct = tree_flatten_with_path(params)
   for i, (k, v) in enumerate(tree_flat):
-    if k in aqt_paths:
+    if k in quantized_param_paths:
       v = {}
     tree_flat[i] = v
   return tree_unflatten(tree_struct, tree_flat)
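The core of this change is the prefix match: AQT replaces a param's leaf node ('kernel') with a deeper Aqt* subtree, so each quantized entry is matched to its float original by comparing key paths with the leaf dropped. A toy sketch of that idea on hand-built trees (assumed data; the real trees come from the Flax 'params' and 'aqt' variable collections):

import jax

params = {"mlp": {"wi_0": {"kernel": 1.0}, "gate": {"kernel": 3.0}}}
aqt_vars = {"mlp": {"wi_0": {"AqtDotGeneral_0": {"qrhs": {"frozen": "qtensor"}}}}}

param_flat, _ = jax.tree_util.tree_flatten_with_path(params)
aqt_flat, _ = jax.tree_util.tree_flatten_with_path(aqt_vars)

for aqt_path, _ in aqt_flat:
  for path, _ in param_flat:
    depth = len(path)
    # The AQT path is deeper than the param path, so compare only the
    # param path's prefix (everything except its 'kernel' leaf).
    if path[: depth - 1] == aqt_path[: depth - 1]:
      print("quantized:", jax.tree_util.keystr(path))  # ['mlp']['wi_0']['kernel']

remove_quantized_params then empties exactly the matched float entries, which is what the updated test below exercises with QTensor leaves.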
23 changes: 19 additions & 4 deletions MaxText/tests/quantizations_test.py
@@ -22,6 +22,7 @@
 import pyconfig
 from layers import quantizations
 import unittest
+from aqt.jax.v2 import aqt_tensor
 
 
 class QuantTestModule(nn.Module):
@@ -110,13 +111,27 @@ def test_remove_quantized_params(self):
"decoder": {
"layers": {
"mlp": {
"wi_0": {"AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}},
"wi_1": {"AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}},
"wo": {"AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}},
"wi_0": {
"AqtDotGeneral_0": {
"qrhs": {"frozen": aqt_tensor.QTensor(qvalue=[1.1, 1.0], scale=[1.0], scale_t=[1.0], bias=1.0)}
}
},
"wi_1": {
"AqtDotGeneral_0": {
"qrhs": {"frozen": aqt_tensor.QTensor(qvalue=[1.1, 1.0], scale=[1.0], scale_t=[1.0], bias=1.0)}
}
},
"wo": {
"AqtDotGeneral_0": {
"qrhs": {"frozen": aqt_tensor.QTensor(qvalue=[1.1, 1.0], scale=[1.0], scale_t=[1.0], bias=1.0)}
}
},
},
"self_attention": {
"key": {
"AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}},
"AqtDotGeneral_0": {
"qrhs": {"frozen": aqt_tensor.QTensor(qvalue=[1.1, 1.0], scale=[1.0], scale_t=[1.0], bias=1.0)}
}
}
},
}
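The fixture above is the aqt_vars side of the test; the matching params tree is outside the visible hunk. A hedged sketch of the end-to-end behavior the test checks, with an assumed minimal params tree and a plain string standing in for the frozen QTensor:

from layers import quantizations

# Assumed fixtures for illustration; not the test file's actual data.
params = {"decoder": {"layers": {"mlp": {"wi_0": {"kernel": 1.0}}}}}
aqt_vars = {"decoder": {"layers": {"mlp": {"wi_0": {"AqtDotGeneral_0": {"qrhs": {"frozen": "qtensor"}}}}}}}

pruned = quantizations.remove_quantized_params(params, aqt_vars)
# Each float param with a quantized counterpart is replaced by an empty dict.
assert pruned["decoder"]["layers"]["mlp"]["wi_0"]["kernel"] == {}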
