Ensure In-place correctness checks work properly #273

Open · wants to merge 11 commits into main
55 changes: 38 additions & 17 deletions src/liger_kernel/ops/cross_entropy.py
@@ -13,6 +13,8 @@ def liger_cross_entropy_kernel(
Y_stride,
loss_ptr,
loss_stride,
dX_ptr,
dX_stride,
n_cols,
n_non_ignore,
ignore_index,
@@ -49,12 +51,13 @@ def liger_cross_entropy_kernel(

# 2. locate the start index
X_ptr += program_id * X_stride
dX_ptr += program_id * dX_stride

if y == ignore_index:
# set all X_ptr as 0
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
dX_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(dX_ptr + dX_offsets, 0.0, mask=dX_offsets < n_cols)
return

loss_ptr += program_id * loss_stride
@@ -106,15 +109,15 @@ def liger_cross_entropy_kernel(

for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
dX_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
if reduction == "mean":
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
dX_block = (tl.exp(dX_block - m) / d - eps) / (n_non_ignore)
else:
X_block = tl.exp(X_block - m) / d - eps
dX_block = tl.exp(dX_block - m) / d - eps

tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
tl.store(dX_ptr + X_offsets, dX_block, mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -145,14 +148,14 @@ def liger_cross_entropy_kernel(
loss = loss / n_non_ignore

# 6. Specially handle the i == y case, where `dx_y = (softmax(x_y) - (1 - label_smoothing)) / N`
X_y = tl.load(X_ptr + y)
dX_y = tl.load(dX_ptr + y)
if reduction == "mean":
X_y += -(1 - label_smoothing) / (n_non_ignore)
dX_y += -(1 - label_smoothing) / (n_non_ignore)
else:
X_y += -(1 - label_smoothing)
dX_y += -(1 - label_smoothing)

tl.store(loss_ptr, loss)
tl.store(X_ptr + y, X_y)
tl.store(dX_ptr + y, dX_y)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
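For reference, the per-row quantity the hunks above now write to dX_ptr (rather than back into X_ptr) is unchanged. A plain PyTorch restatement of it, as a sketch only — the eps = label_smoothing / V definition sits in an unexpanded part of the kernel, so treat that detail as an assumption of this note:

import torch

# Reference value for one non-ignored row x with target index y; this mirrors
# the label-smoothed softmax gradient the kernel stores, not the kernel itself.
def reference_dx(x: torch.Tensor, y: int, label_smoothing: float,
                 n_non_ignore: int, reduction: str = "mean") -> torch.Tensor:
    V = x.shape[0]
    eps = label_smoothing / V            # assumed from the unexpanded kernel prologue
    dx = torch.softmax(x, dim=0) - eps   # dX_i = softmax(x)_i - eps
    dx[y] -= 1.0 - label_smoothing       # extra term for the target class
    if reduction == "mean":
        dx = dx / n_non_ignore           # scaled by the number of non-ignored rows
    return dx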
@@ -161,7 +164,9 @@ def liger_cross_entropy_kernel(
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning


def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
def cross_entropy_forward(
_input, target, ignore_index, label_smoothing, reduction, inplace
):
BT, V = _input.shape
n_rows = BT

@@ -178,14 +183,17 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
if target.stride(-1) != 1:
target = target.contiguous()

# Memory-saving trick: when inplace=True, reuse _input's storage for the gradient; otherwise write dX into a fresh buffer
dX = _input if inplace else torch.empty_like(_input)

liger_cross_entropy_kernel[(n_rows,)](
X_ptr=_input,
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
loss_ptr=loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
dX_ptr=dX,
dX_stride=dX.stride(-2),
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
@@ -198,7 +206,7 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
)

loss = torch.sum(loss_1d)
return loss, _input
return loss, dX


def cross_entropy_backward(_input, grad_output):
@@ -233,7 +241,13 @@ class LigerCrossEntropyFunction(torch.autograd.Function):

@staticmethod
def forward(
ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
ctx,
_input,
target,
ignore_index=-100,
label_smoothing=0.0,
reduction="mean",
inplace=True,
):
"""
The forward pass of the Liger Cross Entropy loss.
@@ -250,16 +264,21 @@ def forward(
tensor: The computed loss.
"""
loss, _input = cross_entropy_forward(
_input, target, ignore_index, label_smoothing, reduction
_input, target, ignore_index, label_smoothing, reduction, inplace
)
# TODO: investigate
# If we don't detach the _input tensor here, memory usage doubles.
# Not sure why, but it seems that for a while both the grad and the value exist, in different locations.
ctx.save_for_backward(_input.detach())
return loss

print(f"{inplace=}")
if inplace:
ctx.mark_dirty(_input)
ctx.mark_non_differentiable(_input)
return loss, _input

@staticmethod
def backward(ctx, grad_output):
def backward(ctx, grad_output, grad_output2):
"""
The backward pass of the Liger Cross Entropy loss.

@@ -270,6 +289,7 @@ def backward(ctx, grad_output):
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
del grad_output2
(_input,) = ctx.saved_tensors
_input = cross_entropy_backward(_input, grad_output)
return (
@@ -278,4 +298,5 @@ def backward(ctx, grad_output):
None,
None,
None,
None,
)
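The shape of these changes — forward() returning the mutated buffer, ctx.mark_dirty / ctx.mark_non_differentiable, and backward() accepting a second, unused gradient — follows the autograd contract for custom Functions that modify an input in place. A minimal standalone sketch of that contract (illustrative toy code, not taken from this PR):

import torch

class SquaredSumInplace(torch.autograd.Function):
    # Toy Function: computes a loss, then overwrites its input in place and
    # declares that mutation to autograd, like the kernel change above.
    @staticmethod
    def forward(ctx, x):
        loss = (x * x).sum()
        ctx.save_for_backward(x.clone())  # keep the pre-mutation values for backward
        x.mul_(2.0)                       # stand-in for "store dX in the input buffer"
        ctx.mark_dirty(x)                 # an input was modified in place
        ctx.mark_non_differentiable(x)    # the mutated buffer is not a differentiable output
        return loss, x

    @staticmethod
    def backward(ctx, grad_loss, grad_buffer_unused):
        # Two forward outputs -> two incoming gradients; the one for the
        # non-differentiable buffer is always None and is simply ignored.
        (x_orig,) = ctx.saved_tensors
        return grad_loss * 2.0 * x_orig   # d/dx of sum(x**2)

x = torch.randn(5, requires_grad=True)
# Mutating a leaf that requires grad is rejected by autograd, so the in-place
# path runs on a non-leaf copy.
loss, _ = SquaredSumInplace.apply(x * 1.0)
loss.backward()

Declaring the mutation via mark_dirty() is what keeps autograd's in-place correctness checks accurate for a buffer that a raw kernel has overwritten.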
2 changes: 2 additions & 0 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -87,6 +87,8 @@ def fused_linear_cross_entropy_forward(
Y_stride=target_chunk.stride(-1), # always 1
loss_ptr=loss_1d_slice,
loss_stride=loss_1d_slice.stride(-1), # always 1
dX_ptr=logits_chunk,
dX_stride=logits_chunk.stride(-2),
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
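Note that this fused path passes the same logits_chunk buffer as both X_ptr and dX_ptr, so its behavior is unchanged: gradients still overwrite the logits in place. A toy illustration of that read-then-overwrite aliasing (illustrative only, not the kernel):

import torch

def toy_grad_in_place(logits: torch.Tensor) -> torch.Tensor:
    # Read the values first, then reuse the same storage for the result,
    # which is what aliasing X_ptr and dX_ptr amounts to.
    probs = torch.softmax(logits, dim=-1)
    logits.copy_(probs)
    return logits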
29 changes: 19 additions & 10 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -1,21 +1,30 @@
from torch.nn import CrossEntropyLoss
import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction


class LigerCrossEntropyLoss(CrossEntropyLoss):
def __init__(self, *args, **kwargs):
super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
assert (self.label_smoothing >= 0) and (
self.label_smoothing <= 1
class LigerCrossEntropyLoss(torch.nn.Module):
def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="mean"):
super().__init__()
assert (label_smoothing >= 0) and (
label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
assert self.reduction in {
assert reduction in {
"mean",
"sum",
"none",
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction

def forward(self, _input, target):
return LigerCrossEntropyFunction.apply(
_input, target, self.ignore_index, self.label_smoothing, self.reduction
def forward(self, _input, target, inplace):
loss, _ = LigerCrossEntropyFunction.apply(
_input,
target,
ignore_index=self.ignore_index,
label_smoothing=self.label_smoothing,
reduction=self.reduction,
inplace=inplace,
)
return loss
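For completeness, a usage sketch against the module interface introduced above — the shapes, device, and hyperparameters are illustrative, and inplace is passed positionally because this version of forward() declares it without a default:

import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(ignore_index=-100, label_smoothing=0.1, reduction="mean")
logits = torch.randn(4 * 128, 32000, device="cuda", requires_grad=True)
targets = torch.randint(0, 32000, (4 * 128,), device="cuda")

# inplace=False leaves `logits` intact and writes dX into a fresh buffer;
# inplace=True reuses the logits storage to save memory (and then the input
# must not be a leaf tensor that requires grad).
loss = loss_fn(logits, targets, False)
loss.backward()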
4 changes: 2 additions & 2 deletions test/transformers/test_cross_entropy.py
@@ -149,8 +149,8 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):

target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)

y1 = liger_cross_entropy(x1, target, 0)
y2 = LigerCrossEntropyFunction.apply(x2, target, 0)
y1, _ = liger_cross_entropy(x1, target, 0)
y2, _ = LigerCrossEntropyFunction.apply(x2, target, 0)

assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
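A hypothetical follow-up check — not part of this diff — sketching how the in-place and out-of-place paths could be compared; the positional argument order follows LigerCrossEntropyFunction.forward as changed above:

import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

def _check_inplace_matches_out_of_place(B=2, T=8, V=32, atol=1e-6, rtol=1e-5):
    x = torch.randn(B * T, V, device="cuda")
    target = torch.randint(0, V, (B * T,), device="cuda")

    x_out = x.clone().requires_grad_(True)
    x_in = x.clone().requires_grad_(True)

    loss_out, _ = LigerCrossEntropyFunction.apply(x_out, target, -100, 0.0, "mean", False)
    # The in-place path mutates its input, so run it on a non-leaf copy.
    loss_in, _ = LigerCrossEntropyFunction.apply(x_in * 1.0, target, -100, 0.0, "mean", True)

    loss_out.backward()
    loss_in.backward()
    assert torch.allclose(loss_out, loss_in, atol=atol, rtol=rtol)
    assert torch.allclose(x_out.grad, x_in.grad, atol=atol, rtol=rtol)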
