Prototype direct quantization
Use only 1 output scale.

Improve based on review 1

update to new_scale

jvp impl

vjp impl

Add fm32 convert for new_scale

Improve

Use impl

Improve based on review 3

Use base class

Fix indent

Improve based on review 4

quote
wenscarl committed Jul 11, 2024
1 parent 239d4e6 commit a25f7fc
Showing 3 changed files with 290 additions and 36 deletions.
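For context on how the new export is meant to be used: Fp8DirectDotGeneralOp is a drop-in replacement for Fp8DotGeneralOp that quantizes the operands to fp8 and dequantizes the dot_general output directly, instead of the quantize-dequantize (QDQ) pattern. A minimal usage sketch, assuming the `dot_general_cls` hook on `flax.linen.Dense`; the model, shapes, and dtypes below are illustrative and not part of this commit:

import jax
import jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Route the layer's dot_general through the direct-quantization op.
    return nn.Dense(16, dot_general_cls=nn.Fp8DirectDotGeneralOp)(x)

model = MLP()
x = jnp.ones((4, 8), jnp.float32)
# init() creates the fp8 scales and amax histories in the
# OVERWRITE_WITH_GRADIENT collection alongside the regular params.
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)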
7 changes: 5 additions & 2 deletions flax/linen/__init__.py
@@ -86,8 +86,11 @@
)
from .batch_apply import BatchApply as BatchApply
from .combinators import Sequential as Sequential
-from .fp8_ops import Fp8DotGeneralOp as Fp8DotGeneralOp
-from .fp8_ops import NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp
from .fp8_ops import (
Fp8DotGeneralOp as Fp8DotGeneralOp,
Fp8DirectDotGeneralOp as Fp8DirectDotGeneralOp,
NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp,
)
from .initializers import (
ones_init as ones_init,
ones as ones,
289 changes: 267 additions & 22 deletions flax/linen/fp8_ops.py
@@ -25,6 +25,7 @@
from jax import numpy as jnp
from jax._src import core
from jax._src import dtypes
from jax._src.lax.lax import _dot_general_transpose_lhs, _dot_general_transpose_rhs

try:
from jax._src import earray
@@ -163,39 +164,48 @@ def compute_amax_history(x, amax_history):
return new_history


-def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):
-is_fmax32 = (scale.dtype == fp32_max_grad and
-amax_history.dtype == fp32_max_grad)
-# convert fmax32->f32 so we can do math
-if is_fmax32:
def quantize_and_update(
x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False
):
is_fm32 = scale.dtype == fm32 and amax_history.dtype == fm32
# convert fm32->f32 so we can do math
if is_fm32:
amax_history = lax.convert_element_type(amax_history, jnp.float32)
scale = lax.convert_element_type(scale, jnp.float32)

# Update the fp8 meta
dtype_max = get_fp8_max(q_dtype, jnp.float32)
amax_from_history = jnp.max(amax_history, axis=0)
-new_scale = compute_scale(amax_from_history, scale, dtype_max)

-qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)

new_scale = compute_scale(amax_from_history, scale, dtype_max)
new_history = compute_amax_history(x, amax_history)

# convert f32->fmax32 so the autodiff system accumulates fp8 meta correctly
if is_fmax32:
new_history = lax.convert_element_type(new_history, fp32_max_grad)
new_scale = lax.convert_element_type(new_scale, fp32_max_grad)

# Quantize the input
if not use_direct_quant:
qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
return qx, new_scale, new_history

return new_scale, new_history


return qx, new_scale, new_history


@partial(custom_vjp, nondiff_argnums=(0, 1))
def in_qdq(compute_dtype, q_dtype, inp, scale, amax_history):
-qin, _, _ = qdq_and_return(
qin, _, _ = quantize_and_update(
inp, q_dtype, scale, amax_history, compute_dtype
)
return qin


def in_qdq_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
-qin, new_scale, new_history = qdq_and_return(
qin, new_scale, new_history = quantize_and_update(
inp, q_dtype, scale, amax_history, compute_dtype
)
return qin, (new_scale, new_history)
@@ -221,7 +231,7 @@ def out_qdq_fwd(compute_dtype, q_dtype, out, scale, amax_history):

def out_qdq_bwd(compute_dtype, q_dtype, res, g):
scale, amax_history = res
-q_g, new_scale, new_history = qdq_and_return(
q_g, new_scale, new_history = quantize_and_update(
g, q_dtype, scale, amax_history, compute_dtype
)
return q_g, new_scale, new_history
@@ -230,6 +240,208 @@ def out_qdq_bwd(compute_dtype, q_dtype, res, g):
out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)


def q_dot_dq_impl(
lhs,
rhs,
lhs_scale,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
is_training
):
if precision != None or preferred_element_type != None:
warnings.warn(
"The function dot_general_with_precision will set the "
"precision/preferred_element_type and disregard any provided "
"values."
)
new_lhs_scale, new_lhs_amax_history = quantize_and_update(
lhs,
jnp.float8_e4m3fn,
lhs_scale,
lhs_amax_history,
compute_dtype,
use_direct_quant=True
)
new_rhs_scale, new_rhs_amax_history = quantize_and_update(
rhs,
jnp.float8_e4m3fn,
rhs_scale,
rhs_amax_history,
compute_dtype,
use_direct_quant=True
)

q_lhs = quantize(lhs, jnp.float8_e4m3fn, new_lhs_scale, preferred_element_type)
q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type)

out = lax.dot_general(
q_lhs,
q_rhs,
dimension_numbers,
preferred_element_type=preferred_element_type,
precision=lax.Precision.DEFAULT,
)

out = dequantize(out, preferred_element_type, new_lhs_scale * new_rhs_scale)
if is_training:
res = (
lhs,
rhs,
q_lhs,
q_rhs,
new_lhs_scale,
new_rhs_scale,
out_grad_scale,
new_lhs_amax_history,
new_rhs_amax_history,
out_grad_amax_history,
)
return out, res
else:
return out


@partial(custom_vjp, nondiff_argnums=(8, 9, 10, 11))
def q_dot_dq(
lhs,
rhs,
lhs_scale,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision=None,
preferred_element_type=None
):
return q_dot_dq_impl(
lhs,
rhs,
lhs_scale,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
is_training=False,
)


def q_dot_dq_fwd(
lhs,
rhs,
lhs_scale,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
):
return q_dot_dq_impl(
lhs,
rhs,
lhs_scale,
rhs_scale,
out_grad_scale,
lhs_amax_history,
rhs_amax_history,
out_grad_amax_history,
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
is_training=True
)


def q_dot_dq_bwd(
compute_dtype,
dimension_numbers,
precision,
preferred_element_type,
res,
g
):
(
lhs,
rhs,
q_lhs,
q_rhs,
new_lhs_scale,
new_rhs_scale,
out_grad_scale,
new_lhs_amax_history,
new_rhs_amax_history,
out_grad_amax_history,
) = res

new_out_grad_scale, new_out_grad_amax_history = quantize_and_update(
g,
jnp.float8_e5m2,
out_grad_scale,
out_grad_amax_history,
compute_dtype,
use_direct_quant=True
)

q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type)

grad_lhs = _dot_general_transpose_lhs(
q_g,
lhs,
q_rhs,
dimension_numbers=dimension_numbers,
precision=lax.Precision.HIGHEST,
preferred_element_type=preferred_element_type,
)
grad_lhs = dequantize(
grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale
)

grad_rhs = _dot_general_transpose_rhs(
q_g,
q_lhs,
rhs,
dimension_numbers=dimension_numbers,
precision=lax.Precision.HIGHEST,
preferred_element_type=preferred_element_type,
)

grad_rhs = dequantize(
grad_rhs, preferred_element_type, new_lhs_scale * new_out_grad_scale
)

return (
grad_lhs,
grad_rhs,
new_lhs_scale,
new_rhs_scale,
new_out_grad_scale,
new_lhs_amax_history,
new_rhs_amax_history,
new_out_grad_amax_history,
)

q_dot_dq.defvjp(q_dot_dq_fwd, q_dot_dq_bwd)


@partial(custom_jvp, nondiff_argnums=(2, 3, 4))
def dot_general_with_precision(
lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
@@ -263,7 +475,20 @@ def dot_general_with_precision_jvp(
return out, grad_out


-class Fp8DotGeneralOp(module.Module):
def _parse_dot_inputs(*args, **kwargs):
assert len(args) == 3
x = args[0]
k = args[1]
dimension_numbers = args[2]

# Use the `k.dtype` since it aligns with the `dtype` of its layers,
# namely, the computation data type.
comp_dtype = k.dtype
x = jnp.asarray(x, comp_dtype)
return x, k, dimension_numbers, comp_dtype


class Fp8DotGeneralBase(module.Module):
amax_history_length: int = 1024
e4m3_dtype: DType = jnp.float8_e4m3fn
e5m2_dtype: DType = jnp.float8_e5m2
@@ -302,24 +527,21 @@ def setup(self) -> None:
OVERWRITE_WITH_GRADIENT, 'output_grad_scale', *scale_args
)

-def __call__(self, *args, **kwargs):
-assert len(args) == 3
-x = args[0]
-k = args[1]
-dimension_numbers = args[2]
-
-# Use the `k.dtype` since it aligns with the `dtype` of its layers,
-# namely, the computation data type.
-comp_dtype = k.dtype
-x = jnp.asarray(x, comp_dtype)
class Fp8DotGeneralOp(Fp8DotGeneralBase):

def __call__(self, *args, **kwargs):
x, k, dimension_numbers, comp_dtype = _parse_dot_inputs(
*args, **kwargs
)
x_qdq = in_qdq(
comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value
)
k_qdq = in_qdq(
comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
)
-y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore

y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers)
y = out_qdq(
comp_dtype,
self.e5m2_dtype,
@@ -330,6 +552,29 @@ def __call__(self, *args, **kwargs):

return y # type: ignore

class Fp8DirectDotGeneralOp(Fp8DotGeneralBase):

def __call__(self, *args, **kwargs):
x, k, dimension_numbers, comp_dtype = _parse_dot_inputs(
*args, **kwargs
)

y = q_dot_dq(
x,
k,
self.input_scale.value,
self.kernel_scale.value,
self.output_grad_scale.value,
self.input_amax_history.value,
self.kernel_amax_history.value,
self.output_grad_amax_history.value,
comp_dtype,
dimension_numbers,
preferred_element_type=x.dtype
)

return y # type: ignore

class NANOOFp8DotGeneralOp(Fp8DotGeneralOp):
e4m3_dtype: DType = jnp.float8_e4m3fnuz
e5m2_dtype: DType = jnp.float8_e5m2fnuz
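The training-time behavior of the direct path hinges on the custom VJP above: q_dot_dq_bwd returns the freshly computed scales and amax histories as the "gradients" of the corresponding OVERWRITE_WITH_GRADIENT variables, so a train step is expected to overwrite those variables with the returned values rather than apply an optimizer update. A rough sketch of that flow, with an illustrative loss and data (not part of this diff):

import jax
import jax.numpy as jnp
from flax import linen as nn

model = nn.Dense(16, dot_general_cls=nn.Fp8DirectDotGeneralOp)
x = jnp.ones((4, 8), jnp.float32)
variables = model.init(jax.random.PRNGKey(0), x)

def loss_fn(v):
  y = model.apply(v, x)
  return jnp.mean(y ** 2)

# Differentiate w.r.t. all variables, including the fp8 meta stored in the
# OVERWRITE_WITH_GRADIENT collection.
grads = jax.grad(loss_fn)(variables)

# The OVERWRITE_WITH_GRADIENT entries of `grads` now carry the new scales and
# amax histories produced by q_dot_dq_bwd; the train step should copy them
# into the variables (flax's TrainState.apply_gradients is expected to handle
# this collection specially) instead of applying a gradient-descent step.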
