Skip to content

Commit

Permalink
Printing string aqt config
Browse files Browse the repository at this point in the history
  • Loading branch information
gobbleturk committed Nov 2, 2023
1 parent a93bf0c commit c016a2a
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion MaxText/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax
from jax.sharding import PartitionSpec as P
from jax.experimental.serialize_executable import deserialize_and_load

from aqt.jax.v2.google import aqt_config

import pickle
import functools
Expand Down Expand Up @@ -81,3 +81,15 @@ def get_train_input_output_trees(func, input_args, input_kwargs):
in_tree, out_tree = get_train_input_output_trees(partial_train, shaped_input_args, shaped_input_kwargs)
p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree)
return p_train_step

def quanization_config(config):
""" Make the quantization config """
aqt_cfg = aqt_config.quantization_config(
config.fwd_int8,
config.dlhs_int8,
config.drhs_int8,
use_dummy_static_bound=config.aqt_use_dummy_static_bound,
rng_type=config.aqt_rng_type,
use_fwd_quant=config.aqt_use_fwd_quant,
)
return aqt_cfg

0 comments on commit c016a2a

Please sign in to comment.