-
Notifications
You must be signed in to change notification settings - Fork 645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support direct quantization for FP8 matmul #3922
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, the placement of quant and dequant is a bit confusing and the q and dq ops seems to be included in our custom dot_general function. I am trying to summarize the rationale here:
Basically, q
means a pure quantize without amax logics but xxx_q
includes both quantization and amax math.
# Our original design:
x(in_qdq), k(in_qdq)->y
dy(out_qdq), x(in_qdq)->dk
dy(out_qdq), k(in_qdq)->dx
# New direct design:
x(in_q), k(in_q)->y(dq)
dy(out_q), x(in_q)->dk(dq??)
dy(out_q), k(in_q)->dx(dq??)
??
indicates the problem about where to place these dq ops. In the original design, we don't need to worry about where the dk and dx are defined, because we don't apply any qdq there. However, in the new design, we need to find them and apply the dq ops explicitly and because we are using jvp mode (forward autograd mode), where we express the grad like:
dy = dx@k + x@dk
So, it seems we have to include the dq ops inside the dot_general function.
So, if that is the case, should we move all the q and dq into the dot_general function, esp in the jvp:
in_q(x)
in_q(y)
y = x@y
dq(y)
dq(dx)
dq(dk)
dy = dx@k + x@dk
in_q(dy)
Also, by doing this, we don't need vjp on the in_q or out_q, since logics are already expressed into the custom jvp dot_general function.
@@ -142,7 +141,7 @@ def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype): | |||
amax_from_history = jnp.max(amax_history, axis=0) | |||
new_scale = compute_scale(amax_from_history, scale, dtype_max) | |||
|
|||
qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we also remove the quantize_dequantize
? I think it is no longer used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's used in the test file. Removed.
flax/linen/fp8_ops.py
Outdated
q_g, new_scale, new_history = qdq_and_return( | ||
g, jnp.float8_e5m2, scale, amax_history, compute_dtype | ||
q_g, new_scale, new_history = q_and_return( | ||
g, jnp.float8_e5m2, scale, amax_history, compute_dtype #elfie investigate |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the comment here still relevant? Or can it be more specific as a TODO note?
flax/linen/fp8_ops.py
Outdated
'The function dot_general_with_precision will set the ' | ||
'precision/preferred_element_type and disregard any provided ' | ||
'values.' | ||
if precision != None or preferred_element_type != None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you accidentally changed the indent here.
flax/linen/fp8_ops.py
Outdated
) | ||
|
||
lhs = quantize(lhs, jnp.float8_e4m3fn, lhs_scale, preferred_element_type) | ||
rhs = quantize(rhs, jnp.float8_e4m3fn, rhs_scale, preferred_element_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems in the forward pass, we directly call the quantize
over the lhs and rhs. But do we need the amax computation?
flax/linen/fp8_ops.py
Outdated
self.output_grad_scale.value, | ||
self.output_grad_amax_history.value, | ||
) | ||
y_q = dot_general_with_precision(x, k, dimension_numbers, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel it would be better to write the code like:
qx = in_quant(x, ...) # which also includes the amax math
qk = in_quant(k, ...)
y = dot_general_and_dequant(qx, qk)
y = grad_quant(y) # let's call it grad_q since it is to apply quantize over gradients
3a4a72d
to
1097287
Compare
I think the new design is much clearer of the idea of direct quantization. By the way, do you think we should create a new Fp8DotGeneral op for it and keep the existing fake quant Op untouched? And then we gradually change downstream uses to migrate to the new op? |
Praxis doesn't use |
5e372dd
to
2d3e9f5
Compare
A gentle reminder to @lukaszlew |
Sorry, I don't have cycles to review this PR. I'm focusing on AQT. |
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
7e3ee81
to
a25f7fc
Compare
Could we get @levskaya to help review since Lukasz is busy with something else these days? |
@levskaya could you take a look? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for the long delay - I was visiting rural family who don't have an internet connection when I was pinged here.
flax/linen/fp8_ops.py
Outdated
|
||
q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type) | ||
|
||
grad_lhs = _dot_general_transpose_lhs( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The JAX team really doesn't like us depending on their internal implementations. Could we inline this function logic here to make this free-standing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, that 's also our main concern back then. Do you mean we should reimplement the logic of the two _xxx functions here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, they're fairly small functions and you don't need all the generality of them - it's just that JAX may need to change things in the future and we don't want to add external dependencies on their internals.
flax/linen/fp8_ops.py
Outdated
grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale | ||
) | ||
|
||
grad_rhs = _dot_general_transpose_rhs( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment as above
flax/linen/fp8_ops.py
Outdated
@@ -25,6 +25,7 @@ | |||
from jax import numpy as jnp | |||
from jax._src import core | |||
from jax._src import dtypes | |||
from jax._src.lax.lax import _dot_general_transpose_lhs, _dot_general_transpose_rhs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better to inline - see below.
Sorry we block if trailing spaces are left in the file, there's some after the line |
@levskaya Just resolved some formatting issues. I think all the tests should pass now. Can you help review and merge? |
93b70ed
to
6f5cee1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one comment below
@wenscarl - I'm seeing a failed test?
|
@levskaya Thanks for reviewing. All checks are passed. |
Historically, FP8 matmul quantization followed the pattern of fake quantization, which involved a sequence of operations: quantization -> dequantization -> dot. Here, (de)quantization refers to type casting and the application of scaling factors. The XLA GemmWriter pass was designed to transform this pattern into a custom cublasLt call.
This PR proposes a departure from the historical approach by adopting direct quantization, which is quantization -> dot -> dequantization. This adjustment aligns better with mainstream quantization implementations for other data types. However, the success of this PR hinges on another PR in JAX (PR-21211) because of the mixed fp8 type matmul.
cc @lukaszlew @kaixih