Gpu inference #1240

Open · wants to merge 19 commits into main

Changes from all commits
4 changes: 4 additions & 0 deletions MaxText/layers/attentions.py
@@ -559,6 +559,8 @@ def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: s

einsum = jnp.einsum
if self.kv_quant:
+      if isinstance(value, KVTensor) and self.kv_quant.dtype == jnp.float8_e4m3fn:
+        value.qvalue = value.qvalue.astype(jnp.bfloat16)
einsum = self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(value)
if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0, 1, 2, 3):
out = einsum("bkgts,bskd->btkgd", attn_weights, value)
@@ -896,6 +898,8 @@ def get_cached_values(self, cache_vars, target_dtype, cache_axis_order) -> jax.A
scale_value /= quantizations.MAX_INT8
elif dtype == jnp.int4:
scale_value /= quantizations.MAX_INT4
+      elif dtype == jnp.float8_e4m3fn:
+        scale_value /= quantizations.E4M3_MAX

cache_value = KVTensor(qvalue=cache_value, scale=[scale_value], scale_t=None, dequant_dtype=target_dtype, bias=[])
cache_value_in_logical_shape = jax.tree.map(lambda x: self.reverse_transepose(x, cache_axis_order), cache_value)
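For illustration, a minimal standalone sketch of the fp8 workaround above (toy shapes and random data; scale handling is omitted, and in the real code the contraction goes through the AQT einsum returned by kv_quant):

import jax
import jax.numpy as jnp

# Toy shapes: batch=1, kv_heads=2, groups=1, q_len=3, kv_len=4, head_dim=8.
attn_weights = jnp.full((1, 2, 1, 3, 4), 0.25, dtype=jnp.bfloat16)
value_fp8 = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 2, 8)).astype(jnp.float8_e4m3fn)

# Mirrors the two added lines: upcast the fp8 payload to bfloat16 before the
# contraction, since the einsum path apparently cannot consume float8 operands directly.
value_bf16 = value_fp8.astype(jnp.bfloat16)
out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value_bf16)
print(out.shape)  # (1, 3, 2, 1, 8)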
78 changes: 57 additions & 21 deletions MaxText/layers/quantizations.py
@@ -42,6 +42,7 @@

MAX_INT8 = 127.5
MAX_INT4 = 7.5
+E4M3_MAX = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)

Array = common_types.Array
Config = common_types.Config
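For reference, E4M3_MAX resolves to the largest finite float8_e4m3fn value (the "fn" variant has no infinities):

import jax.numpy as jnp
print(float(jnp.finfo(jnp.float8_e4m3fn).max))  # 448.0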
@@ -222,6 +223,8 @@ def _get_int8_quant_config(config):
drhs_accumulator_dtype=drhs_accumulator_dtype,
)

+def _get_aqt_fp8_quant_config(config):
+  return aqt_config.config_fwd_fp8()

def _dot_general_make(quant_cfg):
lhs_bits = quant_cfg[_A_BITS]
@@ -265,6 +268,8 @@ def _get_quant_config(config):
return None
if config.quantization == "int8":
return _get_int8_quant_config(config)
if config.quantization == "aqt_fp8":
return _get_aqt_fp8_quant_config(config)
if config.quantization == "intmp":
assert config.quant_cfg_path, "Must specify quant_cfg for mixed precision quantization"
with open(config.quant_cfg_path, "r") as config_file:
@@ -368,6 +373,8 @@ def _get_dtype(self, dtype_cfg: str):
return jnp.int4
if dtype_cfg == "int8":
return jnp.int8
if dtype_cfg == "fp8":
return jnp.float8_e4m3fn
raise ValueError(f"Invalid kv_quant_dtype: {dtype_cfg}")

def _get_max_axis(self, axis_names: AxisNames):
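With these hooks in place, the fp8 path would presumably be enabled the same way the int8 path is today, via config overrides; quantization and kv_quant_dtype are read in this file, while quantize_kvcache and the exact invocation below are assumptions about the usual MaxText command line:

python3 MaxText/decode.py MaxText/configs/base.yml quantization=aqt_fp8 quantize_kvcache=True kv_quant_dtype=fp8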
@@ -388,42 +395,71 @@ def quantize(self, kv: Array, axis_names: AxisNames):
if self.dtype == jnp.int4:
value = jnp.int4(jnp.rint(kv * (MAX_INT4 / scale)))
return value, scale
+    if self.dtype == jnp.float8_e4m3fn:
+      value = jnp.float8_e4m3fn(kv * (E4M3_MAX / scale))
+      return value, scale
raise ValueError(f"Invalid KV quant dtype:{self.dtype}.")

def dequant(self, kv: KVTensor):
assert isinstance(kv, KVTensor)
qvalue = kv.qvalue
scale = kv.scale[0]
dequant_kv = jnp.bfloat16(qvalue) * jnp.bfloat16(scale)
return dequant_kv

  def einsum_fn_with_rhs_qtensor(
-    self,
-    kv: Array | aqt_tensor.QTensor,
-    rhs_dequant_mode=None,
-    rhs_calibration_mode=None,
-  ):
+      self,
+      kv: Array| aqt_tensor.QTensor,
+      rhs_dequant_mode=None,
+      rhs_calibration_mode=None,
+      lhs_dequant_mode=None,
+      lhs_calibration_mode=None,
+  ):
    # Assumes kv is already quantized.
    einsum = jnp.einsum
    if isinstance(kv, aqt_tensor.QTensor):
-      num_bits = 4 if kv.qvalue.dtype == jnp.int4 else 8
-      kv_cfg = aqt_config.dot_general_make(
-        lhs_bits=None,
-        rhs_bits=num_bits,
+      if kv.qvalue.dtype != jnp.float8_e4m3fn:
+        lhs_bits = None
+        rhs_bits = 4 if kv.qvalue.dtype == jnp.int4 else 8
+        kv_cfg = aqt_config.dot_general_make(
+          lhs_bits=lhs_bits,
+          rhs_bits=rhs_bits,
        bwd_bits=None,
        use_fwd_quant=False,
-      )
+        )
+      else:
+        kv_cfg = aqt_config.config_fwd_fp8()

      if rhs_dequant_mode:
        aqt_config.set_fwd_dequant_mode(kv_cfg, rhs_dequant_mode=rhs_dequant_mode)
      if rhs_calibration_mode:
        aqt_config.set_fwd_calibration_mode(
-          kv_cfg,
-          rhs_calibration_mode=rhs_calibration_mode,
+            kv_cfg,
+            rhs_calibration_mode=rhs_calibration_mode,
        )
+      if lhs_dequant_mode:
+        aqt_config.set_fwd_dequant_mode(
+            kv_cfg, lhs_dequant_mode=lhs_dequant_mode
+        )
+      if lhs_calibration_mode:
+        aqt_config.set_fwd_calibration_mode(
+            kv_cfg,
+            lhs_calibration_mode=lhs_calibration_mode,
+        )
      einsum = aqt_flax.AqtEinsum(
-        rhs_quant_mode=aqt_flax.QuantMode.TRAIN,
-        lhs_freeze_mode=aqt_flax.FreezerMode.NONE,
-        rhs_freeze_mode=aqt_flax.FreezerMode.NONE,
-        cfg=kv_cfg,
-      )
+          rhs_quant_mode=aqt_flax.QuantMode.TRAIN,
+          lhs_freeze_mode=aqt_flax.FreezerMode.NONE,
+          rhs_freeze_mode=aqt_flax.FreezerMode.NONE,
+          cfg=kv_cfg
+      )

    return einsum

  def einsum_fn_with_rhs_qtensor_and_dequant(self, value):
    return self.einsum_fn_with_rhs_qtensor(
-      value,
-      rhs_dequant_mode=aqt_config.DequantMode.OTHER_INPUT,
-      rhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS,
-    )
+        value,
+        lhs_dequant_mode=aqt_config.DequantMode.THIS_INPUT,
+        lhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS,
+        rhs_dequant_mode=aqt_config.DequantMode.OTHER_INPUT,
+        rhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS
+    )
3 changes: 2 additions & 1 deletion MaxText/maxengine.py
@@ -815,7 +815,8 @@ def set_engine_vars_from_base_engine(
"""Set internal vars from base_engine, which has already loaded the checkpoint and has sharding,
mesh, and kv cache related vars set.
"""
-  engine.model.quant.quant_mode = base_engine.model.quant.quant_mode
+  if base_engine.model.quant is not None:
+    engine.model.quant.quant_mode = base_engine.model.quant.quant_mode
engine.state_mesh_annotations = base_engine.state_mesh_annotations
engine.abstract_params = base_engine.abstract_params
engine.kv_cache_annotations = max_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine.mesh) # pylint: disable=protected-access
4 changes: 4 additions & 0 deletions MaxText/profiler.py
@@ -57,6 +57,10 @@ def activate(self, blocking_object=None, optional_postfix=""):
return
self.libcudart.cudaProfilerStart()
elif self.mode == "xplane":
+      # When running on GPU, uploading the trace to the gs:// bucket appears to fail, so write the profile locally instead.
+      if self.output_path.startswith("gs://"):
+        self.output_path = "/tmp/jax-profiles/" + self.output_path[4:]
+      print("Start profiling {}".format(self.output_path))
jax.profiler.start_trace(self.output_path)

def deactivate(self, blocking_object=None):
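For clarity, a tiny sketch of what the rewrite above does to a hypothetical destination (the bucket name is made up); the local trace would then have to be copied to the bucket separately, e.g. with gsutil, if it is needed there:

output_path = "gs://my-bucket/profiles/run1"
if output_path.startswith("gs://"):
  # [4:] keeps the slash after "gs:/", so the joined path contains a double slash,
  # which local path handling collapses harmlessly.
  output_path = "/tmp/jax-profiles/" + output_path[4:]
print(output_path)  # /tmp/jax-profiles//my-bucket/profiles/run1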