From 6f5cee12ae00444d2341c7661944ac46353aef5f Mon Sep 17 00:00:00 2001
From: shuw
Date: Thu, 11 Apr 2024 19:48:51 +0000
Subject: [PATCH] Direct quantization for FP8 Dense Layer.

---
 flax/linen/__init__.py    |   7 +-
 flax/linen/fp8_ops.py     | 327 +++++++++++++++++++++++++++++++++++---
 tests/linen/linen_test.py |  52 +++---
 3 files changed, 341 insertions(+), 45 deletions(-)

diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py
index f01ed92880..24a33d8730 100644
--- a/flax/linen/__init__.py
+++ b/flax/linen/__init__.py
@@ -86,8 +86,11 @@
 )
 from .batch_apply import BatchApply as BatchApply
 from .combinators import Sequential as Sequential
-from .fp8_ops import Fp8DotGeneralOp as Fp8DotGeneralOp
-from .fp8_ops import NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp
+from .fp8_ops import (
+  Fp8DotGeneralOp as Fp8DotGeneralOp,
+  Fp8DirectDotGeneralOp as Fp8DirectDotGeneralOp,
+  NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp,
+)
 from .initializers import (
   ones_init as ones_init,
   ones as ones,
diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py
index 490647186e..24bf2acb38 100644
--- a/flax/linen/fp8_ops.py
+++ b/flax/linen/fp8_ops.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import dataclasses
+import itertools
 import numpy as np
 import warnings
 from functools import partial
@@ -25,6 +26,8 @@
 from jax import numpy as jnp
 from jax._src import core
 from jax._src import dtypes
+from jax._src.lax import lax
+from jax._src.typing import DTypeLike
 
 try:
   from jax._src import earray
@@ -163,39 +166,91 @@ def compute_amax_history(x, amax_history):
   return new_history
 
 
-def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):
+def quantize_and_update(
+  x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False
+):
   is_fmax32 = (scale.dtype == fp32_max_grad and
                amax_history.dtype == fp32_max_grad)
   # convert fmax32->f32 so we can do math
   if is_fmax32:
     amax_history = lax.convert_element_type(amax_history, jnp.float32)
     scale = lax.convert_element_type(scale, jnp.float32)
 
+  # Update the fp8 meta
   dtype_max = get_fp8_max(q_dtype, jnp.float32)
   amax_from_history = jnp.max(amax_history, axis=0)
   new_scale = compute_scale(amax_from_history, scale, dtype_max)
-
-  qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
 
   new_history = compute_amax_history(x, amax_history)
 
   # convert f32->fmax32 so the autodiff system accumulates fp8 meta correctly
   if is_fmax32:
     new_history = lax.convert_element_type(new_history, fp32_max_grad)
     new_scale = lax.convert_element_type(new_scale, fp32_max_grad)
 
-  return qx, new_scale, new_history
+  # Quantize the input
+  if not use_direct_quant:
+    qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
+    return qx, new_scale, new_history
+
+  return new_scale, new_history
+
+
+def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
+                              preferred_element_type: DTypeLike | None,
+                              swap_ans=False):
+  def _remaining(original, *removed_lists):
+    removed = set(itertools.chain(*removed_lists))
+    return [i for i in original if i not in removed]
+
+  def _ranges_like(*xs):
+    start = 0
+    for x in xs:
+      x_len = len(x)
+      yield range(start, start + x_len)
+      start += x_len
+
+  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
+  x_ndim = x.aval.ndim
+  x_kept = _remaining(range(x_ndim), x_contract, x_batch)
+  y_kept = _remaining(range(np.ndim(y)), y_contract, y_batch)
+  if swap_ans:
+    ans_batch, ans_y, _ = _ranges_like(x_batch, y_kept, x_kept)
+  else:
+    ans_batch, _, ans_y = _ranges_like(x_batch, x_kept, y_kept)
+  dims = ((ans_y, y_kept), (ans_batch, y_batch))
+  x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
+  out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
+  x_bar = lax.transpose(
+    lax.dot_general(
+      g, y, dims, precision=precision,
+      preferred_element_type=preferred_element_type
+    ),
+    tuple(out_axes)
+  )
+  return x_bar
+
+
+def dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
+                              preferred_element_type: DTypeLike | None):
+  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
+  swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
+  y_bar = dot_general_transpose_lhs(
+    g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
+    preferred_element_type=preferred_element_type,
+    swap_ans=True)
+  return y_bar
 
 
 @partial(custom_vjp, nondiff_argnums=(0, 1))
 def in_qdq(compute_dtype, q_dtype, inp, scale, amax_history):
-  qin, _, _ = qdq_and_return(
+  qin, _, _ = quantize_and_update(
     inp, q_dtype, scale, amax_history, compute_dtype
   )
   return qin
 
 
 def in_qdq_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
-  qin, new_scale, new_history = qdq_and_return(
+  qin, new_scale, new_history = quantize_and_update(
     inp, q_dtype, scale, amax_history, compute_dtype
   )
   return qin, (new_scale, new_history)
@@ -221,7 +276,7 @@ def out_qdq_fwd(compute_dtype, q_dtype, out, scale, amax_history):
 
 def out_qdq_bwd(compute_dtype, q_dtype, res, g):
   scale, amax_history = res
-  q_g, new_scale, new_history = qdq_and_return(
+  q_g, new_scale, new_history = quantize_and_update(
     g, q_dtype, scale, amax_history, compute_dtype
   )
   return q_g, new_scale, new_history
@@ -230,6 +285,203 @@ def out_qdq_bwd(compute_dtype, q_dtype, res, g):
 out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)
 
 
+def q_dot_dq_impl(
+  lhs,
+  rhs,
+  lhs_scale,
+  rhs_scale,
+  out_grad_scale,
+  lhs_amax_history,
+  rhs_amax_history,
+  out_grad_amax_history,
+  compute_dtype,
+  dimension_numbers,
+  precision,
+  preferred_element_type,
+  is_training
+):
+  if precision != None or preferred_element_type != None:
+    warnings.warn(
+      "The function q_dot_dq will set the "
+      "precision/preferred_element_type and disregard any provided "
+      "values."
+    )
+  new_lhs_scale, new_lhs_amax_history = quantize_and_update(
+    lhs,
+    jnp.float8_e4m3fn,
+    lhs_scale,
+    lhs_amax_history,
+    compute_dtype,
+    use_direct_quant=True
+  )
+  new_rhs_scale, new_rhs_amax_history = quantize_and_update(
+    rhs,
+    jnp.float8_e4m3fn,
+    rhs_scale,
+    rhs_amax_history,
+    compute_dtype,
+    use_direct_quant=True
+  )
+
+  q_lhs = quantize(lhs, jnp.float8_e4m3fn, new_lhs_scale, preferred_element_type)
+  q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type)
+
+  out = lax.dot_general(
+    q_lhs,
+    q_rhs,
+    dimension_numbers,
+    preferred_element_type=preferred_element_type,
+    precision=lax.Precision.DEFAULT,
+  )
+
+  out = dequantize(out, preferred_element_type, new_lhs_scale * new_rhs_scale)
+  if is_training:
+    res = (
+      lhs,
+      rhs,
+      q_lhs,
+      q_rhs,
+      new_lhs_scale,
+      new_rhs_scale,
+      out_grad_scale,
+      new_lhs_amax_history,
+      new_rhs_amax_history,
+      out_grad_amax_history,
+    )
+    return out, res
+  else:
+    return out
+
+
+@partial(custom_vjp, nondiff_argnums=(8, 9, 10, 11))
+def q_dot_dq(
+  lhs,
+  rhs,
+  lhs_scale,
+  rhs_scale,
+  out_grad_scale,
+  lhs_amax_history,
+  rhs_amax_history,
+  out_grad_amax_history,
+  compute_dtype,
+  dimension_numbers,
+  precision=None,
+  preferred_element_type=None
+):
+  return q_dot_dq_impl(
+    lhs,
+    rhs,
+    lhs_scale,
+    rhs_scale,
+    out_grad_scale,
+    lhs_amax_history,
+    rhs_amax_history,
+    out_grad_amax_history,
+    compute_dtype,
+    dimension_numbers,
+    precision,
+    preferred_element_type,
+    is_training=False,
+  )
+
+
+def q_dot_dq_fwd(
+  lhs,
+  rhs,
+  lhs_scale,
+  rhs_scale,
+  out_grad_scale,
+  lhs_amax_history,
+  rhs_amax_history,
+  out_grad_amax_history,
+  compute_dtype,
+  dimension_numbers,
+  precision,
+  preferred_element_type,
+):
+  return q_dot_dq_impl(
+    lhs,
+    rhs,
+    lhs_scale,
+    rhs_scale,
+    out_grad_scale,
+    lhs_amax_history,
+    rhs_amax_history,
+    out_grad_amax_history,
+    compute_dtype,
+    dimension_numbers,
+    precision,
+    preferred_element_type,
+    is_training=True
+  )
+
+
+def q_dot_dq_bwd(
+  compute_dtype,
+  dimension_numbers,
+  precision,
+  preferred_element_type,
+  res,
+  g
+):
+  (
+    lhs,
+    rhs,
+    q_lhs,
+    q_rhs,
+    new_lhs_scale,
+    new_rhs_scale,
+    out_grad_scale,
+    new_lhs_amax_history,
+    new_rhs_amax_history,
+    out_grad_amax_history,
+  ) = res
+
+  new_out_grad_scale, new_out_grad_amax_history = quantize_and_update(
+    g,
+    jnp.float8_e5m2,
+    out_grad_scale,
+    out_grad_amax_history,
+    compute_dtype,
+    use_direct_quant=True
+  )
+
+  q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type)
+
+  grad_lhs = dot_general_transpose_lhs(
+    q_g,
+    lhs,
+    q_rhs,
+    dimension_numbers=dimension_numbers,
+    precision=lax.Precision.HIGHEST,
+    preferred_element_type=preferred_element_type,
+  )
+  grad_lhs = dequantize(grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale)
+
+  grad_rhs = dot_general_transpose_rhs(
+    q_g,
+    q_lhs,
+    rhs,
+    dimension_numbers=dimension_numbers,
+    precision=lax.Precision.HIGHEST,
+    preferred_element_type=preferred_element_type,
+  )
+  grad_rhs = dequantize(grad_rhs, preferred_element_type, new_lhs_scale * new_out_grad_scale)
+
+  return (
+    grad_lhs,
+    grad_rhs,
+    new_lhs_scale,
+    new_rhs_scale,
+    new_out_grad_scale,
+    new_lhs_amax_history,
+    new_rhs_amax_history,
+    new_out_grad_amax_history,
+  )
+
+
+q_dot_dq.defvjp(q_dot_dq_fwd, q_dot_dq_bwd)
+
+
 @partial(custom_jvp, nondiff_argnums=(2, 3, 4))
 def dot_general_with_precision(
   lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
@@ -263,7 +515,20 @@ def dot_general_with_precision_jvp(
   return out, grad_out
 
 
-class Fp8DotGeneralOp(module.Module):
+def _parse_dot_inputs(*args, **kwargs):
+  assert len(args) == 3
+  x = args[0]
+  k = args[1]
+  dimension_numbers = args[2]
+
+  # Use the `k.dtype` since it aligns with the `dtype` of its layers,
+  # namely, the computation data type.
+  comp_dtype = k.dtype
+  x = jnp.asarray(x, comp_dtype)
+  return x, k, dimension_numbers, comp_dtype
+
+
+class Fp8DotGeneralBase(module.Module):
   amax_history_length: int = 1024
   e4m3_dtype: DType = jnp.float8_e4m3fn
   e5m2_dtype: DType = jnp.float8_e5m2
@@ -302,24 +567,20 @@ def setup(self) -> None:
       OVERWRITE_WITH_GRADIENT, 'output_grad_scale', *scale_args
     )
 
-  def __call__(self, *args, **kwargs):
-    assert len(args) == 3
-    x = args[0]
-    k = args[1]
-    dimension_numbers = args[2]
-
-    # Use the `k.dtype` since it aligns with the `dtype` of its layers,
-    # namely, the computation data type.
-    comp_dtype = k.dtype
-    x = jnp.asarray(x, comp_dtype)
 
+class Fp8DotGeneralOp(Fp8DotGeneralBase):
+  def __call__(self, *args, **kwargs):
+    x, k, dimension_numbers, comp_dtype = _parse_dot_inputs(
+      *args, **kwargs
+    )
     x_qdq = in_qdq(
       comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value
     )
     k_qdq = in_qdq(
       comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
     )
-    y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers)  # type: ignore
+
+    y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers)
     y = out_qdq(
       comp_dtype,
       self.e5m2_dtype,
@@ -330,6 +591,28 @@ def __call__(self, *args, **kwargs):
     return y  # type: ignore
 
 
+class Fp8DirectDotGeneralOp(Fp8DotGeneralBase):
+  def __call__(self, *args, **kwargs):
+    x, k, dimension_numbers, comp_dtype = _parse_dot_inputs(
+      *args, **kwargs
+    )
+
+    y = q_dot_dq(
+      x,
+      k,
+      self.input_scale.value,
+      self.kernel_scale.value,
+      self.output_grad_scale.value,
+      self.input_amax_history.value,
+      self.kernel_amax_history.value,
+      self.output_grad_amax_history.value,
+      comp_dtype,
+      dimension_numbers,
+      preferred_element_type=x.dtype
+    )
+
+    return y  # type: ignore
+
+
 class NANOOFp8DotGeneralOp(Fp8DotGeneralOp):
   e4m3_dtype: DType = jnp.float8_e4m3fnuz
   e5m2_dtype: DType = jnp.float8_e5m2fnuz
diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py
index 7fe0b8be95..8e2fd51cd6 100644
--- a/tests/linen/linen_test.py
+++ b/tests/linen/linen_test.py
@@ -1255,9 +1255,11 @@ def get_fp8_dtypes(fp8_genre):
 class Fp8Test(parameterized.TestCase):
 
   @parameterized.parameters(
-    {'fp8_genre': 'OCP'}, {'fp8_genre': 'NANOO'}
+    {'fp8_genre': 'OCP', 'use_direct_quant': True},
+    {'fp8_genre': 'OCP', 'use_direct_quant': False},
+    {'fp8_genre': 'NANOO', 'use_direct_quant': False}
   )
-  def test_fp8_dot_general_injection(self, fp8_genre):
+  def test_fp8_dot_general_injection(self, fp8_genre, use_direct_quant):
     # Used to cast the inputs to be representable in FP8, so that the difference
     # of the results from the original gemm and fp8 gemm is small.
     cast_to_representable = functools.partial(
@@ -1276,13 +1278,20 @@ def test_fp8_dot_general_injection(self, fp8_genre):
       random.uniform(random_key, (16, 64)), e5m2_dtype
     )
 
+    if fp8_genre == 'NANOO':
+      assert use_direct_quant == False
+      quant_cls = nn.NANOOFp8DotGeneralOp
+    else:
+      quant_cls = (
+        nn.Fp8DirectDotGeneralOp
+        if use_direct_quant else nn.Fp8DotGeneralOp
+      )
+
     def run(fp8_injection, expected_shapes):
       p = nn.DenseGeneral(features=64, name='dense')
+
       if fp8_injection:
-        if fp8_genre == 'OCP':
-          p.dot_general_cls = nn.Fp8DotGeneralOp
-        else:
-          p.dot_general_cls = nn.NANOOFp8DotGeneralOp
+        p.dot_general_cls = quant_cls
 
       init_fn = jax.jit(p.init_with_output)
       y, initial_vars = init_fn(init_key, x)
@@ -1301,14 +1310,11 @@ def _train(variables, x):
     expected_shapes_original = {
       'params': {'kernel': (32, 64), 'bias': (64,)},
     }
-    if fp8_genre == 'OCP':
-      fp8_op_name = 'Fp8DotGeneralOp_0'
-    else:
-      fp8_op_name = 'NANOOFp8DotGeneralOp_0'
+
     expected_shapes_new = {
       'params': {'kernel': (32, 64), 'bias': (64,)},
       fp8_ops.OVERWRITE_WITH_GRADIENT: {
-        fp8_op_name: {
+        f'{quant_cls.__name__}_0': {
           'input_amax_history': (1024,),
           'kernel_amax_history': (1024,),
           'output_grad_amax_history': (1024,),
@@ -1318,7 +1324,6 @@ def _train(variables, x):
         }
       },
     }
-
     output1a, output1b = run(False, expected_shapes_original)
     output2a, output2b = run(True, expected_shapes_new)
     dw1, dw2 = output1b[0]['params']['kernel'], output2b[0]['params']['kernel']
@@ -1329,19 +1334,24 @@ def _train(variables, x):
     np.testing.assert_allclose(dx1, dx2, atol=1e-04)
 
   @parameterized.parameters(
-    {'fp8_genre': 'OCP'}, {'fp8_genre': 'NANOO'}
+    {'fp8_genre': 'OCP', 'use_direct_quant': True},
+    {'fp8_genre': 'OCP', 'use_direct_quant': False},
+    {'fp8_genre': 'NANOO', 'use_direct_quant': False}
   )
-  def test_fp8_train_state(self, fp8_genre):
+  def test_fp8_train_state(self, fp8_genre, use_direct_quant):
     key, init_key, random_key = random.split(random.PRNGKey(seed=123), 3)
     x = random.uniform(random_key, (16, 16), dtype=jnp.float32)
-    if fp8_genre == 'OCP':
-      fp8_dot_op = nn.Fp8DotGeneralOp
-      fp8_op_name = 'Fp8DotGeneralOp_0'
+
+    if fp8_genre == 'NANOO':
+      assert use_direct_quant == False
+      quant_cls = nn.NANOOFp8DotGeneralOp
     else:
-      fp8_dot_op = nn.NANOOFp8DotGeneralOp
-      fp8_op_name = 'NANOOFp8DotGeneralOp_0'
+      quant_cls = (
+        nn.Fp8DirectDotGeneralOp
+        if use_direct_quant else nn.Fp8DotGeneralOp
+      )
     dense = nn.DenseGeneral(
-      features=32, use_bias=True, dot_general_cls=fp8_dot_op
+      features=32, use_bias=True, dot_general_cls=quant_cls
    )
 
     init_fn = jax.jit(dense.init)
@@ -1395,7 +1405,7 @@ def loss_fn(vars):
 
     rtol, atol = 0.001, 0.001
     fp8_vars = state.params[fp8_ops.OVERWRITE_WITH_GRADIENT][
-      fp8_op_name
+      f'{quant_cls.__name__}_0'
     ]
     np.testing.assert_allclose(
      fp8_vars['input_amax_history'],
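A minimal usage sketch, not part of the patch: it mirrors the test_fp8_train_state setup above, injecting the new Fp8DirectDotGeneralOp into a DenseGeneral layer so the matmul runs through q_dot_dq. The layer width, input shape, and PRNG seed below are illustrative only.

from jax import numpy as jnp, random
from flax import linen as nn

# Replace the layer's dot_general with the direct-quantization FP8 op
# added in this patch.
dense = nn.DenseGeneral(
  features=32, use_bias=True, dot_general_cls=nn.Fp8DirectDotGeneralOp
)

key, data_key = random.split(random.PRNGKey(0))
x = random.uniform(data_key, (16, 16), dtype=jnp.float32)

# init() creates 'params' (kernel/bias) plus the FP8 scales and amax
# histories in the OVERWRITE_WITH_GRADIENT collection set up by
# Fp8DotGeneralBase; those meta values are refreshed through the custom
# VJP of q_dot_dq during training.
variables = dense.init(key, x)
y = dense.apply(variables, x)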