Improve based on review 3
wenscarl committed May 30, 2024
1 parent 858695d commit 2d3e9f5
Showing 2 changed files with 202 additions and 50 deletions.
flax/linen/fp8_ops.py: 216 changes (176 additions, 40 deletions)
@@ -107,7 +107,18 @@ def dequantize(x, dq_dtype, scale):
return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape)


def get_new_scale(amax, scale, fp8_max, margin=0):
def quantize_dequantize(x, q_dtype, scale, compute_dtype):
qx = quantize(x, q_dtype, scale, compute_dtype)
return dequantize(qx, x.dtype, scale)


def compute_amax_history(x, amax_history):
amax_update = jnp.max(jnp.abs(x)).astype(amax_history.dtype)
new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update)
return new_history


def compute_scale(amax, scale, fp8_max, margin=0):
# The algorithm for computing the new scale is sourced from
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas
# wherein the `original_scale` corresponds to the reciprocal of the `scale`
@@ -121,36 +132,60 @@ def get_new_scale(amax, scale, fp8_max, margin=0):
return 1.0 / sf
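
Aside (not part of this commit): a minimal numeric sketch of how the delayed-scaling helpers above behave, assuming get_fp8_max(jnp.float8_e4m3fn, jnp.float32) returns 448.0, the largest finite float8_e4m3fn value, and using placeholder shapes.

# Sketch only, not from the diff.
import jax.numpy as jnp

x = jnp.array([[0.5, -3.5], [1.0, 2.0]], jnp.float32)
amax_history = jnp.zeros((16,), jnp.float32)

amax_history = compute_amax_history(x, amax_history)   # history[0] becomes 3.5
amax = jnp.max(amax_history, axis=0)                    # 3.5
fp8_max = get_fp8_max(jnp.float8_e4m3fn, jnp.float32)   # 448.0 for e4m3
new_scale = compute_scale(amax, jnp.ones((1,), jnp.float32), fp8_max)
# new_scale is roughly 3.5 / 448, so x / new_scale fills the e4m3 range and
# dequantize later multiplies the scale back in.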


def compute_amax_history(x, amax_history):
amax_update = jnp.max(jnp.abs(x)).astype(amax_history.dtype)
new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update)
return new_history

def compute_scale(q_dtype, scale, amax_history):
dtype_max = get_fp8_max(q_dtype, jnp.float32)
amax_from_history = jnp.max(amax_history, axis=0)
new_scale = get_new_scale(amax_from_history, scale, dtype_max)
return new_scale


def _compute_new_meta(x, q_dtype, scale, amax_history, compute_dtype):
def quantize_and_update(
x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False
):
is_fm32 = scale.dtype == fm32 and amax_history.dtype == fm32
# convert fm32->f32 so we can do math
if is_fm32:
amax_history = lax.convert_element_type(amax_history, jnp.float32)
scale = lax.convert_element_type(scale, jnp.float32)

new_scale = compute_scale(q_dtype, scale, amax_history)
# Update the fp8 meta
dtype_max = get_fp8_max(q_dtype, jnp.float32)
amax_from_history = jnp.max(amax_history, axis=0)

new_scale = compute_scale(amax_from_history, scale, dtype_max)
new_history = compute_amax_history(x, amax_history)

# convert f32->fm32 so the autodiff system accumulates fp8 meta correctly
if is_fm32:
new_history = lax.convert_element_type(new_history, fm32)
new_scale = lax.convert_element_type(new_scale, fm32)

# Quantize the input
if not use_direct_quant:
qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
return qx, new_scale, new_history

return new_scale, new_history
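
Aside (not part of this commit): the two return conventions of quantize_and_update differ only in whether the fake-quantized tensor is produced; arrays below are placeholders and bfloat16 stands in for the compute dtype.

# Sketch only, not from the diff.
import jax.numpy as jnp

x = jnp.ones((4, 4), jnp.float32)
scale = jnp.ones((1,), jnp.float32)
amax_history = jnp.zeros((1024,), jnp.float32)

# Fake-quant path: also returns the quantize-dequantized tensor.
qx, new_scale, new_history = quantize_and_update(
    x, jnp.float8_e4m3fn, scale, amax_history, jnp.bfloat16
)
# Direct-quant path: metadata only; the caller quantizes x itself.
new_scale, new_history = quantize_and_update(
    x, jnp.float8_e4m3fn, scale, amax_history, jnp.bfloat16, use_direct_quant=True
)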

def _q_dot_dq_impl(

@partial(custom_vjp, nondiff_argnums=(0,))
def in_qdq(compute_dtype, inp, scale, amax_history):
qin, _, _ = quantize_and_update(
inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype
)
return qin


def in_qdq_fwd(compute_dtype, inp, scale, amax_history):
qin, new_scale, new_history = quantize_and_update(
inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype
)
return qin, (new_scale, new_history)


def in_qdq_bwd(compute_dtype, res, g):
new_scale, new_history = res
q_g = g
return q_g, new_scale, new_history


in_qdq.defvjp(in_qdq_fwd, in_qdq_bwd)
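
Aside (not part of this commit): the custom VJP above reports the refreshed scale and amax history as the cotangents of scale and amax_history, which the OVERWRITE_WITH_GRADIENT collection later writes back, while the input's cotangent passes through unchanged. A sketch with placeholder arrays:

# Sketch only, not from the diff.
import jax
import jax.numpy as jnp

x = jnp.ones((4, 4), jnp.float32)
scale = jnp.ones((1,), jnp.float32)
amax_history = jnp.zeros((1024,), jnp.float32)

def loss_fn(inp, s, h):
  return jnp.sum(in_qdq(jnp.float32, inp, s, h))

d_x, d_scale, d_history = jax.grad(loss_fn, argnums=(0, 1, 2))(x, scale, amax_history)
# d_x is the ordinary gradient of the sum; d_scale and d_history are the
# refreshed scale and amax history, not true derivatives.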


def q_dot_dq_impl(
lhs,
rhs,
lhs_scale,
@@ -163,19 +198,29 @@ def _q_dot_dq_impl(
dimension_numbers,
precision,
preferred_element_type,
is_training=False
is_training
):
if precision != None or preferred_element_type != None:
warnings.warn(
"The function dot_general_with_precision will set the "
"precision/preferred_element_type and disregard any provided "
"values."
)
new_lhs_scale, new_lhs_amax_history = _compute_new_meta(
lhs, jnp.float8_e4m3fn, lhs_scale, lhs_amax_history, compute_dtype
new_lhs_scale, new_lhs_amax_history = quantize_and_update(
lhs,
jnp.float8_e4m3fn,
lhs_scale,
lhs_amax_history,
compute_dtype,
use_direct_quant=True
)
new_rhs_scale, new_rhs_amax_history = _compute_new_meta(
rhs, jnp.float8_e4m3fn, rhs_scale, rhs_amax_history, compute_dtype
new_rhs_scale, new_rhs_amax_history = quantize_and_update(
rhs,
jnp.float8_e4m3fn,
rhs_scale,
rhs_amax_history,
compute_dtype,
use_direct_quant=True
)

q_lhs = quantize(lhs, jnp.float8_e4m3fn, new_lhs_scale, preferred_element_type)
@@ -209,7 +254,7 @@ def _q_dot_dq_impl(


@partial(custom_vjp, nondiff_argnums=(8, 9, 10, 11))
def _q_dot_dq(
def q_dot_dq(
lhs,
rhs,
lhs_scale,
@@ -223,7 +268,7 @@ def _q_dot_dq(
precision=None,
preferred_element_type=None
):
return _q_dot_dq_impl(
return q_dot_dq_impl(
lhs,
rhs,
lhs_scale,
@@ -235,11 +280,12 @@ def _q_dot_dq(
compute_dtype,
dimension_numbers,
precision,
preferred_element_type
preferred_element_type,
is_training=False,
)


def _q_dot_dq_fwd(
def q_dot_dq_fwd(
lhs,
rhs,
lhs_scale,
@@ -253,7 +299,7 @@ def _q_dot_dq_fwd(
precision,
preferred_element_type,
):
return _q_dot_dq_impl(
return q_dot_dq_impl(
lhs,
rhs,
lhs_scale,
@@ -270,7 +316,7 @@ def _q_dot_dq_fwd(
)


def _q_dot_dq_bwd(
def q_dot_dq_bwd(
compute_dtype,
dimension_numbers,
precision,
@@ -291,8 +337,13 @@ def _q_dot_dq_bwd(
out_grad_amax_history,
) = res

new_out_grad_scale, new_out_grad_amax_history = _compute_new_meta(
g, jnp.float8_e5m2, out_grad_scale, out_grad_amax_history, compute_dtype
new_out_grad_scale, new_out_grad_amax_history = quantize_and_update(
g,
jnp.float8_e5m2,
out_grad_scale,
out_grad_amax_history,
compute_dtype,
use_direct_quant=True
)

q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type)
@@ -333,7 +384,73 @@ def _q_dot_dq_bwd(
new_out_grad_amax_history,
)

_q_dot_dq.defvjp(_q_dot_dq_fwd, _q_dot_dq_bwd)
q_dot_dq.defvjp(q_dot_dq_fwd, q_dot_dq_bwd)
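
Aside (not part of this commit): a hedged sketch of calling the direct-quantization path. The collapsed middle of the signature is assumed to follow the ordering of q_dot_dq_impl above (lhs, rhs, then scale/amax-history pairs for the lhs, rhs and output gradient, then compute_dtype, dimension_numbers, precision, preferred_element_type); every array here is a placeholder.

# Sketch only, not from the diff.
import jax.numpy as jnp

x = jnp.ones((16, 32), jnp.float32)
k = jnp.ones((32, 64), jnp.float32)
scales = [jnp.ones((1,), jnp.float32) for _ in range(3)]
histories = [jnp.zeros((1024,), jnp.float32) for _ in range(3)]
dimension_numbers = (((1,), (0,)), ((), ()))

y = q_dot_dq(
    x, k,
    scales[0], histories[0],   # lhs (input) scale / amax history
    scales[1], histories[1],   # rhs (kernel) scale / amax history
    scales[2], histories[2],   # output-gradient scale / amax history
    jnp.float32,               # compute_dtype, a non-differentiable argument
    dimension_numbers,
    None,                      # precision
    jnp.float32,               # preferred_element_type
)
# Under jax.grad, the cotangents of the six metadata arguments are their updated
# values, which the OVERWRITE_WITH_GRADIENT machinery can then write back.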


@partial(custom_vjp, nondiff_argnums=(0,))
def out_qdq(compute_dtype, out, scale, amax_history):
return out


def out_qdq_fwd(compute_dtype, out, scale, amax_history):
return out, (scale, amax_history)


def out_qdq_bwd(compute_dtype, res, g):
scale, amax_history = res
q_g, new_scale, new_history = quantize_and_update(
g, jnp.float8_e5m2, scale, amax_history, compute_dtype
)
return q_g, new_scale, new_history


out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)
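
Aside (not part of this commit): out_qdq mirrors in_qdq on the output side; its forward pass is the identity, and only the backward pass fake-quantizes the cotangent in e5m2 and refreshes its metadata. A sketch with placeholder arrays:

# Sketch only, not from the diff.
import jax
import jax.numpy as jnp

out = jnp.ones((16, 64), jnp.float32)
g_scale = jnp.ones((1,), jnp.float32)
g_history = jnp.zeros((1024,), jnp.float32)

def loss_fn(o, s, h):
  return jnp.sum(out_qdq(jnp.float32, o, s, h))

d_out, d_scale, d_history = jax.grad(loss_fn, argnums=(0, 1, 2))(out, g_scale, g_history)
# d_out is the incoming gradient after an e5m2 quantize-dequantize; d_scale and
# d_history are the refreshed output-gradient scale and amax history.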


@partial(custom_jvp, nondiff_argnums=(2, 3, 4))
def dot_general_with_precision(
lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
):
if precision != None or preferred_element_type != None:
warnings.warn(
'The function dot_general_with_precision will set the '
'precision/preferred_element_type and disregard any provided '
'values.'
)
return lax.dot_general(
lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
)


@dot_general_with_precision.defjvp
def dot_general_with_precision_jvp(
dimension_numbers, precision, preferred_element_type, primals, tangents
):
lhs, rhs = primals
lhs_dot, rhs_dot = tangents

out = lax.dot_general(
lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
)
grad_out = lax.dot_general(
lhs_dot, rhs, dimension_numbers, precision=lax.Precision.HIGHEST
) + lax.dot_general(
lhs, rhs_dot, dimension_numbers, precision=lax.Precision.HIGHEST
)
return out, grad_out
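
Aside (not part of this commit): the custom JVP keeps the primal contraction at Precision.DEFAULT while evaluating tangents at Precision.HIGHEST; a usage sketch with placeholder operands.

# Sketch only, not from the diff.
import jax
import jax.numpy as jnp

a = jnp.ones((16, 32), jnp.float32)
b = jnp.ones((32, 64), jnp.float32)
dimension_numbers = (((1,), (0,)), ((), ()))

y, y_dot = jax.jvp(
    lambda lhs, rhs: dot_general_with_precision(lhs, rhs, dimension_numbers),
    (a, b),
    (jnp.ones_like(a), jnp.zeros_like(b)),
)
# y is computed at Precision.DEFAULT; y_dot uses Precision.HIGHEST for both
# tangent contractions.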


def _get_dot_general_inputs(*args, **kwargs):
assert len(args) == 3
x = args[0]
k = args[1]
dimension_numbers = args[2]

# Use the `k.dtype` since it aligns with the `dtype` of its layers,
# namely, the computation data type.
comp_dtype = k.dtype
x = jnp.asarray(x, comp_dtype)
return x, k, dimension_numbers, comp_dtype


class Fp8DotGeneralOp(module.Module):
@@ -374,16 +491,35 @@ def setup(self) -> None:
)

def __call__(self, *args, **kwargs):
assert len(args) == 3
x = args[0]
k = args[1]
dimension_numbers = args[2]

# Use the `k.dtype` since it aligns with the `dtype` of its layers,
# namely, the computation data type.
comp_dtype = k.dtype
x = jnp.asarray(x, comp_dtype)
y = _q_dot_dq(
x, k, dimension_numbers, comp_dtype = _get_dot_general_inputs(
*args, **kwargs
)
x_qdq = in_qdq(
comp_dtype, x, self.input_scale.value, self.input_amax_history.value
)
k_qdq = in_qdq(
comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
)

y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers)
y = out_qdq(
comp_dtype,
y_qdq,
self.output_grad_scale.value,
self.output_grad_amax_history.value,
)

return y # type: ignore


class Fp8DirectDotGeneralOp(Fp8DotGeneralOp):

def __call__(self, *args, **kwargs):
x, k, dimension_numbers, comp_dtype = _get_dot_general_inputs(
*args, **kwargs
)

y = q_dot_dq(
x,
k,
self.input_scale.value,
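
Aside (not part of this commit): as the updated test below shows, either op class is injected into a Linen layer through its dot_general_cls attribute; a minimal sketch with hypothetical shapes, where Fp8DirectDotGeneralOp can be swapped for Fp8DotGeneralOp.

# Sketch only, not from the diff.
import jax
import jax.numpy as jnp
from flax import linen as nn

model = nn.DenseGeneral(features=64, dot_general_cls=nn.Fp8DirectDotGeneralOp)
x = jnp.ones((16, 32), jnp.float32)
variables = model.init(jax.random.PRNGKey(0), x)
# 'variables' now carries an OVERWRITE_WITH_GRADIENT collection with the
# input/kernel/output-grad scales and amax histories alongside 'params'.
y = model.apply(variables, x)
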
tests/linen/linen_test.py: 36 changes (26 additions, 10 deletions)
@@ -1244,16 +1244,14 @@ def test_hashable(self):
self.assertNotEqual(hash(id1), hash(id1dc))


class Fp8Test(absltest.TestCase):
def test_fp8_dot_general_injection(self):
class Fp8Test(parameterized.TestCase):

@parameterized.parameters([True, False])
def test_fp8_dot_general_injection(self, use_direct_quant):
# Used to cast the inputs to be representable in FP8, so that the difference
# of the results from the original gemm and fp8 gemm is small.
def quantize_dequantize(x, q_dtype, scale, compute_dtype):
qx = fp8_ops.quantize(x, q_dtype, scale, compute_dtype)
return fp8_ops.dequantize(qx, x.dtype, scale)

cast_to_representable = functools.partial(
quantize_dequantize,
fp8_ops.quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=jnp.float32,
)
@@ -1269,7 +1267,10 @@ def quantize_dequantize(x, q_dtype, scale, compute_dtype):
def run(fp8_injection, expected_shapes):
p = nn.DenseGeneral(features=64, name='dense')
if fp8_injection:
p.dot_general_cls = nn.Fp8DotGeneralOp
p.dot_general_cls = (
nn.Fp8DirectDotGeneralOp
if use_direct_quant else nn.Fp8DotGeneralOp
)

init_fn = jax.jit(p.init_with_output)
y, initial_vars = init_fn(init_key, x)
@@ -1288,7 +1289,7 @@ def _train(variables, x):
expected_shapes_original = {
'params': {'kernel': (32, 64), 'bias': (64,)},
}
expected_shapes_new = {
expected_shapes_fake = {
'params': {'kernel': (32, 64), 'bias': (64,)},
fp8_ops.OVERWRITE_WITH_GRADIENT: {
'Fp8DotGeneralOp_0': {
@@ -1301,7 +1302,22 @@ def _train(variables, x):
}
},
}

expected_shapes_direct = {
fp8_ops.OVERWRITE_WITH_GRADIENT: {
'Fp8DirectDotGeneralOp_0': {
'input_amax_history': (1024,),
'input_scale': (1,),
'kernel_amax_history': (1024,),
'kernel_scale': (1,),
'output_grad_amax_history': (1024,),
'output_grad_scale': (1,),
}
},
'params': {'bias': (64,), 'kernel': (32, 64)},
}
expected_shapes_new = (
expected_shapes_direct if use_direct_quant else expected_shapes_fake
)
output1a, output1b = run(False, expected_shapes_original)
output2a, output2b = run(True, expected_shapes_new)
dw1, dw2 = output1b[0]['params']['kernel'], output2b[0]['params']['kernel']
