Ensure In-place correctness checks work properly #273
base: main
Conversation
```
@@ -213,6 +213,8 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
    target = target.contiguous()

    # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
    # explicitly declare in-place operation is performed
    _input.add_(0)
```
Can you explain a bit about what this is doing exactly?
Torch checks whether an in-place op on a tensor would result in an incorrect gradient calculation by implicitly tracking the "version" of the tensor: it compares the version saved for backward against the version seen afterward.
When an in-place operation is performed on a tensor, its version is incremented by 1 via an internal C function, bump().
Since bump() is only called by "torch" in-place operations, an in-place write done inside a "triton kernel" never triggers it, so torch loses track of the version and cannot raise an error.
This approach is a hint for torch: we manually perform a torch in-place op whenever we dirty that tensor in the triton kernel, so the version counter still gets bumped.
Reference:
- torch in-place correctness checks: https://pytorch.org/docs/stable/autograd.html#in-place-correctness-checks
- version tracking, bump() (there's a note above the function block): https://github.com/pytorch/pytorch/blob/190e09d8b6a13f789b143f0fbd1325f924550967/c10/core/TensorImpl.h#L382
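A minimal sketch (not from this PR) of the mechanism described above; it uses the tensor's `_version` attribute, which is how the version counter is exposed in Python:

```python
import torch

x = torch.randn(4, requires_grad=True)
y = x.sigmoid()        # sigmoid saves its output for the backward pass
print(y._version)      # 0: version recorded when the output is saved

y.add_(1)              # a torch in-place op calls bump() -> version becomes 1
print(y._version)      # 1

try:
    y.sum().backward() # saved version (0) != current version (1)
except RuntimeError as e:
    print(e)           # "... has been modified by an inplace operation ..."
```

An in-place write from a triton kernel would leave `_version` at 0, so this check would silently pass even though the saved tensor was clobbered.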
Does this operation have any performance impact? cc @ByronHsu
The impact is huge for add_(0); I will try some other in-place ops.
Here's the bench, so I guess it's not worth it?
Yeah, I was looking into whether we can call bump() from Python... a 50% cost is not worth it.
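For reference, a rough way to gauge that overhead (my own micro-benchmark sketch, not the bench used in this thread; the tensor shape is just an assumed logits size):

```python
import torch

# Assumed logits-sized tensor (batch*seq x vocab); not the shape used in the actual bench.
x = torch.randn(4096, 32000, device="cuda")

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# Warm up so the first kernel launch does not skew the timing.
for _ in range(10):
    x.add_(0)
torch.cuda.synchronize()

start.record()
for _ in range(100):
    x.add_(0)          # the version-bumping hint added in this PR
end.record()
torch.cuda.synchronize()
print(f"add_(0): {start.elapsed_time(end) / 100:.3f} ms per call")
```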
I am wondering why the error does not happen in the normal case?
I left an explanation in the issue.
With @mgrabban's suggestion in #343, I made another implementation with mark_dirty(). Note: I haven't benchmarked this new approach against the current liger_ce; will do it in a few days. (draft gist for benchmark)

Result with new approach:

```python
import torch
import torch.nn.functional as F

from liger_kernel.transformers.functional import liger_cross_entropy


def run_inplace_experiment(
    logits_p, logits_q, cross_entropy_fn, is_liger=False, use_inplace=False
):
    _p = logits_p.clone().detach().requires_grad_(True)
    _p.retain_grad()
    softmax = torch.nn.Softmax(dim=-1)
    p = softmax(_p)
    p.retain_grad()
    try:
        if is_liger:
            loss, _ = cross_entropy_fn(p, logits_q, -100, 0.0, "mean", use_inplace)
        else:
            loss = cross_entropy_fn(p, logits_q)
        loss.backward(retain_graph=True)
        print(f"Cross Entropy Loss: {loss.item()}")
        print(f"Input _p: {_p}")
        print(f"Input logits_q: {logits_q}")
        print(f"Gradients of p (batch item 0): {p.grad[0]}")
        print(f"Gradients of _p (batch item 0): {_p.grad[0]}")
    except Exception as e:
        print(e)


torch.manual_seed(0)
logits_p = torch.randn(8, 8, requires_grad=True, device="cuda")
logits_q = torch.randint(0, 8, (8,), device="cuda", dtype=torch.long)

run_inplace_experiment(
    logits_p, logits_q, cross_entropy_fn=F.cross_entropy, is_liger=False
)

print()
print("LIGER use_inplace=True:")
run_inplace_experiment(
    logits_p,
    logits_q,
    cross_entropy_fn=liger_cross_entropy,
    is_liger=True,
    use_inplace=True,
)

print()
print("LIGER use_inplace=False:")
run_inplace_experiment(
    logits_p,
    logits_q,
    cross_entropy_fn=liger_cross_entropy,
    is_liger=True,
    use_inplace=False,
)
```
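For readers unfamiliar with mark_dirty(): it is the sanctioned way to tell autograd that a custom Function modified an input in place, so the version-counter bookkeeping is handled without the extra add_(0). A minimal sketch of the pattern (a toy function, not the actual liger implementation):

```python
import torch

class InplaceScale(torch.autograd.Function):
    # Toy example: scales its input in place and reports the mutation to autograd.
    @staticmethod
    def forward(ctx, x, scale):
        x.mul_(scale)        # in-place modification of the input
        ctx.mark_dirty(x)    # declare x as modified in place
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.scale, None

x = torch.randn(4, requires_grad=True)
y = x.clone()                # clone first: in-place ops on a leaf that requires grad are not allowed
out = InplaceScale.apply(y, 2.0)
out.sum().backward()
print(x.grad)                # tensor([2., 2., 2., 2.])
```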
Update: Here's the benchmark against liger's CE.
Summary
Fix #272
This is a showcase of how to trigger the error properly.
I only applied it to cross_entropy for demonstration; it can be applied to the other kernels if we want.

Testing Done
Same gist as the issue's; the error is now properly raised.
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence