[PyTorch] Propagate fp8 scale-inverse modification to GroupedLinear
#1128
Conversation
fp8_meta["scaling_fwd"].scale_inv,
_GEMM_WEIGHT,
Hi @timmoon10, I want to confirm: the previous way of getting scale_inv from fp8_meta["scaling_fwd"] in the forward pass still works but is not encouraged, right?
I'm pretty sure this was a bug, and maybe I should create an issue to make users aware of it.
What is the bug you are noticing? You are changing the API for fp8_grouped_gemm, but other than that it looks like using fp8_meta would be correct (but not recommended).
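For readers following along, here is a minimal, self-contained mock (plain Python/PyTorch, not TransformerEngine code; the slot values and the _GEMM_WEIGHT index are illustrative, mirroring the diff hunk above) of what "using fp8_meta" refers to:

```python
import torch
from types import SimpleNamespace

# Mock of fp8_meta["scaling_fwd"]: one scale / scale_inv slot per GEMM tensor.
# The index name mirrors the diff hunk above; the values are made up.
_GEMM_WEIGHT = 1
scaling_fwd = SimpleNamespace(
    scale=torch.tensor([1.0, 2.0, 1.0]),
    scale_inv=torch.tensor([1.0, 0.5, 1.0]),
)
fp8_meta = {"scaling_fwd": scaling_fwd}

# "Correct but not recommended": index the shared scale_inv buffer directly.
w_scale_inv = fp8_meta["scaling_fwd"].scale_inv[_GEMM_WEIGHT]
print(w_scale_inv)  # tensor(0.5000)
```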
When casting weights to FP8 in get_fp8_workspace(), the scale_invs of the weight tensors are written to the private buffers inside the Float8Tensors, while the scale_invs in fp8_meta are only updated after the first forward step completes. So for the first micro-batch, the two don't match.
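To make the timing concrete, here is a small self-contained sketch (toy class and values, not TE's Float8Tensor or fp8_meta internals) of the state on the first micro-batch:

```python
import torch

# Toy illustration (not the actual TransformerEngine classes) of the mismatch
# described above: the FP8 weight workspace carries the scale_inv it was cast
# with, while fp8_meta's scale_inv is only refreshed after the first forward.

class ToyFloat8Weight:
    def __init__(self, weight: torch.Tensor, scale: torch.Tensor):
        self._scale_inv = 1.0 / scale                 # written at cast time
        self._data = (weight * scale).to(torch.int8)  # stand-in for FP8 payload

scale_at_cast_time = torch.tensor(2.0)
fp8_meta_scale_inv = torch.tensor(1.0)  # still the init value before the
                                        # post-forward scale_inv update

w_fp8 = ToyFloat8Weight(torch.randn(4, 4), scale_at_cast_time)

print(w_fp8._scale_inv)    # tensor(0.5000) -- matches the cast data
print(fp8_meta_scale_inv)  # tensor(1.)     -- stale on the first micro-batch
```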
Signed-off-by: Xin Yao <[email protected]>
Force-pushed from f7ed83f to 117aa98
Signed-off-by: Xin Yao <[email protected]>
Force-pushed from 4cd0fca to 8eba144
@timmoon10 Can I have your review?
Can we not add this functionality to the existing one?
Signed-off-by: Xin Yao <[email protected]>
Done.
Force-pushed from ac034a6 to 12fad7d
Signed-off-by: Xin Yao <[email protected]>
Force-pushed from 12fad7d to e42a6d4
@ksivaman Can you trigger the CI?
Overall looks reasonable to me. I just have a suggestion to make the internal API more consistent.
/te-ci pytorch
Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Xin Yao <[email protected]>
/te-ci pytorch
/te-ci pytorch
@timmoon10 @ksivaman Can you take another look at this PR? Thanks.
LGTM
Description
This PR propagates the fp8 scale-inverse modification to GroupedLinear, so that the correct scale_inv is used for weights in fp8_grouped_gemm.
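For reference, the practical effect of the mismatch discussed above: dequantizing with a stale scale_inv scales the whole weight matrix, and hence the GEMM output, by the wrong factor. A minimal numeric sketch in plain PyTorch (no TE calls, values illustrative):

```python
import torch

w = torch.randn(4, 4)
x = torch.randn(4, 4)
scale = torch.tensor(2.0)

w_q = w * scale                       # stand-in for the FP8-quantized weight
correct = (w_q * (1.0 / scale)) @ x   # dequantize with the matching scale_inv
stale = (w_q * 1.0) @ x               # dequantize with a stale scale_inv of 1.0

print(torch.allclose(correct, w @ x))  # True
print((stale - correct).abs().max())   # large: outputs are off by a factor of 2
```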