Support direct quantization for FP8 matmul #3922

Merged
merged 4 commits into google:main from direct_quant on Sep 4, 2024

Conversation

wenscarl
Collaborator

Historically, FP8 matmul quantization followed the fake-quantization pattern, which involved a sequence of operations: quantization -> dequantization -> dot. Here, (de)quantization refers to type casting and the application of scaling factors. The XLA GemmRewriter pass was designed to transform this pattern into a custom cublasLt call.

This PR departs from the historical approach by adopting direct quantization: quantization -> dot -> dequantization. This aligns better with mainstream quantization implementations for other data types. However, this PR hinges on another PR in JAX (PR-21211) because of the mixed-FP8-type matmul.
cc @lukaszlew @kaixih
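For context, a minimal sketch of the two patterns for a 2-D matmul (not the code in this PR; the scale handling is simplified, clipping and amax bookkeeping are omitted, and scale_x / scale_k are illustrative names):

import jax
import jax.numpy as jnp

def fake_quant_matmul(x, k, scale_x, scale_k):
  # quantize -> dequantize -> dot: the dot itself runs on dequantized (wide) inputs;
  # the XLA rewriter pattern-matches the q/dq pairs into an FP8 cublasLt call.
  qx = (x / scale_x).astype(jnp.float8_e4m3fn).astype(x.dtype) * scale_x
  qk = (k / scale_k).astype(jnp.float8_e4m3fn).astype(k.dtype) * scale_k
  return qx @ qk

def direct_quant_matmul(x, k, scale_x, scale_k):
  # quantize -> dot -> dequantize: the dot consumes FP8 operands directly and the
  # scaling factors are applied once to the accumulated output.
  qx = (x / scale_x).astype(jnp.float8_e4m3fn)
  qk = (k / scale_k).astype(jnp.float8_e4m3fn)
  y = jax.lax.dot_general(qx, qk, (((1,), (0,)), ((), ())),
                          preferred_element_type=x.dtype)
  return y * (scale_x * scale_k)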

@wenscarl wenscarl changed the title Support direct quantization for FP8 matmul [draft]Support direct quantization for FP8 matmul May 14, 2024
Contributor

@kaixih kaixih left a comment


Overall, the placement of quant and dequant is a bit confusing, and the q and dq ops seem to be included in our custom dot_general function. Let me try to summarize the rationale here:
Basically, q means a pure quantize without the amax logic, while xxx_q includes both quantization and the amax math.

# Our original design:
x(in_qdq), k(in_qdq)->y
dy(out_qdq), x(in_qdq)->dk
dy(out_qdq), k(in_qdq)->dx

# New direct design:
x(in_q), k(in_q)->y(dq)
dy(out_q), x(in_q)->dk(dq??)
dy(out_q), k(in_q)->dx(dq??)

?? marks the open question of where to place these dq ops. In the original design we don't need to worry about where dk and dx are defined, because we don't apply any qdq there. In the new design, however, we need to find them and apply the dq ops explicitly. And because we are using jvp mode (forward autograd mode), the gradient is expressed as:

dy = dx@k + x@dk

So, it seems we have to include the dq ops inside the dot_general function.

So, if that is the case, should we move all the q and dq into the dot_general function, especially in the jvp:

in_q(x)
in_q(k)
y = x@k
dq(y)
dq(dx)
dq(dk)
dy = dx@k + x@dk
in_q(dy)

Also, by doing this, we don't need a vjp on in_q or out_q, since the logic is already expressed in the custom-jvp dot_general function.
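A minimal self-contained sketch of this shape (illustrative only: fixed scales, no amax bookkeeping, and q_dot_dq / scale_x / scale_k are made-up names, not the fp8_ops names):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def q_dot_dq(x, k, scale_x, scale_k):
  # q on both inputs, dot in FP8, dq on the output.
  qx = (x / scale_x).astype(jnp.float8_e4m3fn)
  qk = (k / scale_k).astype(jnp.float8_e4m3fn)
  y = jax.lax.dot_general(qx, qk, (((1,), (0,)), ((), ())),
                          preferred_element_type=jnp.float32)
  return y * (scale_x * scale_k)

@q_dot_dq.defjvp
def _q_dot_dq_jvp(primals, tangents):
  x, k, scale_x, scale_k = primals
  dx, dk, _, _ = tangents
  y = q_dot_dq(x, k, scale_x, scale_k)
  # dy = dx@k + x@dk; in the design above these two dots would also be
  # quantized (e5m2 on the gradient side) and dequantized here.
  dy = dx @ k + x @ dk
  return y, dy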

@@ -142,7 +141,7 @@ def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):
amax_from_history = jnp.max(amax_history, axis=0)
new_scale = compute_scale(amax_from_history, scale, dtype_max)

qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
Contributor


Can we also remove the quantize_dequantize? I think it is no longer used.

Collaborator Author


It's used in the test file. Removed.

q_g, new_scale, new_history = qdq_and_return(
g, jnp.float8_e5m2, scale, amax_history, compute_dtype
q_g, new_scale, new_history = q_and_return(
g, jnp.float8_e5m2, scale, amax_history, compute_dtype #elfie investigate
Contributor


Is the comment here still relevant? Or can it be more specific as a TODO note?

'The function dot_general_with_precision will set the '
'precision/preferred_element_type and disregard any provided '
'values.'
if precision != None or preferred_element_type != None:
Contributor


I think you accidentally changed the indent here.

)

lhs = quantize(lhs, jnp.float8_e4m3fn, lhs_scale, preferred_element_type)
rhs = quantize(rhs, jnp.float8_e4m3fn, rhs_scale, preferred_element_type)
Contributor


It seems that in the forward pass we directly call quantize on the lhs and rhs. But do we need the amax computation?
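For reference, a simplified sketch of a quantize that also carries the amax bookkeeping, loosely following the qdq_and_return shape quoted earlier (compute_scale is approximated inline and the names are illustrative, not the fp8_ops names):

import jax.numpy as jnp

def q_with_amax(x, q_dtype, scale, amax_history):
  # `scale` (the previous scale) is unused in this simplified version.
  dtype_max = float(jnp.finfo(q_dtype).max)
  # Derive the new scale from the running amax history (simplified compute_scale).
  amax_from_history = jnp.max(amax_history, axis=0)
  new_scale = jnp.maximum(amax_from_history / dtype_max, 1e-12)
  # Pure quantize: scale, clip to the representable range, cast to FP8.
  qx = jnp.clip(x / new_scale, -dtype_max, dtype_max).astype(q_dtype)
  # Push the current tensor's amax into the history for the next step.
  new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(jnp.max(jnp.abs(x)))
  return qx, new_scale, new_history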

self.output_grad_scale.value,
self.output_grad_amax_history.value,
)
y_q = dot_general_with_precision(x, k, dimension_numbers,
Contributor


I feel it would be better to write the code like:

qx = in_quant(x, ...)  # which also includes the amax math
qk = in_quant(k, ...)
y = dot_general_and_dequant(qx, qk)
y = grad_q(y)  # let's call it grad_q since it applies quantize to the gradients

@kaixih
Contributor

kaixih commented May 23, 2024

I think the new design makes the idea of direct quantization much clearer. By the way, do you think we should create a new Fp8DotGeneral op for it and keep the existing fake-quant op untouched, and then gradually migrate downstream uses to the new op?

@wenscarl
Collaborator Author

I think the new design makes the idea of direct quantization much clearer. By the way, do you think we should create a new Fp8DotGeneral op for it and keep the existing fake-quant op untouched, and then gradually migrate downstream uses to the new op?

Praxis doesn't use Fp8DotGeneralOp directly; it uses dot_general_with_precision, and most PAXML models access fp8 through Praxis. So, instead, I think it makes more sense to keep both the fake-quant and the direct-quant op there.

@wenscarl wenscarl requested a review from kaixih May 24, 2024 21:40
@wenscarl wenscarl requested a review from kaixih May 30, 2024 02:47
@wenscarl wenscarl force-pushed the direct_quant branch 2 times, most recently from 5e372dd to 2d3e9f5 Compare May 30, 2024 02:52
@wenscarl wenscarl changed the title [draft]Support direct quantization for FP8 matmul Support direct quantization for FP8 matmul May 30, 2024
@wenscarl wenscarl requested a review from kaixih May 30, 2024 18:57
@wenscarl
Collaborator Author

A gentle reminder to @lukaszlew

@lukaszlew
Contributor

Sorry, I don't have cycles to review this PR. I'm focusing on AQT.


@wenscarl wenscarl force-pushed the direct_quant branch 2 times, most recently from 7e3ee81 to a25f7fc Compare July 11, 2024 19:21
@zhangqiaorjc
Member

Could we get @levskaya to help review since Lukasz is busy with something else these days?

@wenscarl
Collaborator Author

@levskaya could you take a look?

Collaborator

@levskaya levskaya left a comment


Apologies for the long delay - I was visiting rural family who don't have an internet connection when I was pinged here.


q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type)

grad_lhs = _dot_general_transpose_lhs(
Collaborator


The JAX team really doesn't like us depending on their internal implementations. Could we inline this function logic here to make this free-standing?

Contributor


Right, that was also our main concern back then. Do you mean we should reimplement the logic of the two _xxx functions here?

Collaborator


Yeah, they're fairly small functions and you don't need all their generality - it's just that JAX may need to change things in the future and we don't want to add external dependencies on their internals.
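For the plain 2-D matmul case, free-standing versions could be as small as the following sketch (the real JAX helpers handle arbitrary dimension_numbers, batch dims, and precision, which this ignores):

from jax import lax

def matmul_transpose_lhs(g, rhs):
  # For y = lhs @ rhs with g = dL/dy: dL/d(lhs) = g @ rhs.T
  return lax.dot_general(g, rhs, (((1,), (1,)), ((), ())))

def matmul_transpose_rhs(g, lhs):
  # For y = lhs @ rhs with g = dL/dy: dL/d(rhs) = lhs.T @ g
  return lax.dot_general(lhs, g, (((0,), (0,)), ((), ())))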

grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale
)

grad_rhs = _dot_general_transpose_rhs(
Collaborator


same comment as above

@@ -25,6 +25,7 @@
from jax import numpy as jnp
from jax._src import core
from jax._src import dtypes
from jax._src.lax.lax import _dot_general_transpose_lhs, _dot_general_transpose_rhs
Collaborator


better to inline - see below.

@levskaya
Collaborator

Sorry, we block merges if trailing spaces are left in the file; there are some after the line class Fp8Test(parameterized.TestCase): in the tests (you can see the failed precommit). Could you fix that so we can run the tests? I think things look OK.

@wenscarl wenscarl requested a review from kaixih August 22, 2024 18:37
@kaixih
Contributor

kaixih commented Aug 29, 2024

@levskaya I think all the issues have been resolved by @wenscarl. Can you help review and merge?

@kaixih
Contributor

kaixih commented Aug 30, 2024

@levskaya Just resolved some formatting issues. I think all the tests should pass now. Can you help review and merge?

Collaborator

@levskaya levskaya left a comment


one comment below

@levskaya
Collaborator

levskaya commented Sep 4, 2024

@wenscarl @kaixih - hey, also: I invited you as collaborators to this repo; apologies that I hadn't done that earlier. Presubmit tests should trigger immediately once you join.

@levskaya
Collaborator

levskaya commented Sep 4, 2024

@wenscarl - I'm seeing a failed test?

E       UserWarning: The function dot_general_with_precision will set the precision/preferred_element_type and disregard any provided values.

@wenscarl
Collaborator Author

wenscarl commented Sep 4, 2024

@levskaya Thanks for reviewing. All checks have passed.

@copybara-service copybara-service bot merged commit c44b916 into google:main Sep 4, 2024
16 checks passed