[optimize] fuse celoss kernels into one
StrongSpoon committed Oct 31, 2024
1 parent 7ecdd12 commit b1a3b7e
Showing 1 changed file with 65 additions and 34 deletions.
99 changes: 65 additions & 34 deletions src/flag_gems/ops/cross_entropy_loss.py
@@ -5,7 +5,6 @@
import triton.language as tl

from ..utils import libentry
-from .sum import sum


@libentry()
@@ -16,13 +15,15 @@
for d in [1, 4, 16]
],
key=["C", "D"],
+reset_to_zero=["out_sum_ptr", "w_tgt_ptr"],
)
@triton.jit(do_not_specialize=["ignore_index"])
def celoss_indice_kernel(
inp_ptr,
tgt_ptr,
w_ptr,
out_ptr,
+out_sum_ptr,
w_tgt_ptr,
ignore_index,
N,
@@ -39,12 +40,11 @@ def celoss_indice_kernel(
tgt_mask = offset_d < D
tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

-ignore_mask = not (tgt == ignore_index)
+ignore_mask = not (tgt == ignore_index) and tgt_mask

-w_ptrs = w_ptr + tgt
-w_tgt = tl.load(w_ptrs, mask=tgt_mask, other=0).to(tl.float32)
-w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
-tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask and ignore_mask)
+w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)
+w_tgt_sum = tl.sum(w_tgt)
+tl.atomic_add(w_tgt_ptr, w_tgt_sum)

tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
@@ -67,7 +67,9 @@

out = (final_sum + final_max - inp_tgt) * w_tgt
out_ptrs = out_ptr + pid_n * D + offset_d
-tl.store(out_ptrs, out, mask=tgt_mask and ignore_mask)
+tl.store(out_ptrs, out)
+out_sum = tl.sum(out)
+tl.atomic_add(out_sum_ptr, out_sum)
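
This is the fusion at the heart of the commit: instead of storing per-element losses and reducing them with the separate `sum` op (whose import is dropped above), each program reduces its own tile with `tl.sum` and folds the partial result into a single scalar via `tl.atomic_add`. The `reset_to_zero` entries in the autotune decorators make the autotuner re-zero those accumulators between benchmark runs, so candidate configs do not pile stale partials on top of each other. A minimal, self-contained sketch of this block-reduce-plus-atomic pattern (illustrative kernel and names, not from this file):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def block_sum_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # out-of-bounds lanes load 0.0, so they do not perturb the reduction
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    part = tl.sum(x)              # per-block partial sum
    tl.atomic_add(out_ptr, part)  # fold into one scalar; out must start at 0


x = torch.randn(10000, device="cuda")
out = torch.zeros([], dtype=torch.float32, device="cuda")  # pre-zeroed, like out_sum above
grid = (triton.cdiv(x.numel(), 1024),)
block_sum_kernel[grid](x, out, x.numel(), BLOCK=1024)
assert torch.allclose(out, x.sum(), atol=1e-3)
```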


@libentry()
@@ -78,13 +80,17 @@ def celoss_indice_kernel(
for d in [1, 4, 16]
],
key=["C", "D"],
+reset_to_zero=[
+    "out_sum_ptr",
+],
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
inp_ptr,
tgt_ptr,
w_ptr,
out_ptr,
+out_sum_ptr,
label_smoothing,
N,
C,
@@ -121,15 +127,18 @@ def celoss_probability_kernel(
w_ptrs = w_ptr + offset_c
w_mask = offset_c < C
inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
-tgt = tl.load(tgt_ptrs, mask, other=1).to(tl.float32)
+tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)[:, None]
log = final_sum + final_max - inp
_sum += w * log * tgt

out = tl.sum(_sum, axis=0)
+out = tl.where(offset_d < D, out, 0)
out_ptrs = out_ptr + pid_n * D + offset_d
tl.store(out_ptrs, out, mask=offset_d < D)
+out_sum = tl.sum(out)
+tl.atomic_add(out_sum_ptr, out_sum)
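
The probability-target path smooths each target row as `tgt * (1 - label_smoothing) + label_smoothing / C` before weighting, and the added `tl.where` zeroes the lanes past `D` so the block-wide `tl.sum` feeding the atomic add only sees valid data. A quick torch-only check (hypothetical shapes, independent of the kernel) that the smoothing transform preserves a valid distribution:

```python
import torch

C, s = 5, 0.1
tgt = torch.softmax(torch.randn(3, C), dim=1)  # each row sums to 1
smoothed = tgt * (1.0 - s) + s / C
# (1 - s) * 1 + C * (s / C) = 1, so every row still sums to 1
print(smoothed.sum(dim=1))
```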


@libentry()
@@ -140,13 +149,15 @@ def celoss_probability_kernel(
for d in [1, 4, 16]
],
key=["C", "D"],
+reset_to_zero=["out_sum_ptr", "w_tgt_ptr"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indice_smooth_kernel(
inp_ptr,
tgt_ptr,
w_ptr,
out_ptr,
+out_sum_ptr,
w_tgt_ptr,
ignore_index,
label_smoothing,
@@ -164,11 +175,11 @@ def celoss_indice_smooth_kernel(
tgt_mask = offset_d < D
tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

-ignore_mask = not (tgt == ignore_index)
+ignore_mask = not (tgt == ignore_index) and tgt_mask

-w_tgt = tl.load(w_ptr + tgt, mask=tgt_mask, other=0).to(tl.float32)
-w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
-tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask and ignore_mask)
+w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)
+w_tgt_sum = tl.sum(w_tgt)
+tl.atomic_add(w_tgt_ptr, w_tgt_sum)

tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
@@ -208,8 +219,11 @@ def celoss_indice_smooth_kernel(
_sum += log * smooth * w[:, None]

out = tl.sum(_sum, axis=0)
+out = tl.where(ignore_mask, out, 0)
out_ptrs = out_ptr + pid_n * D + offset_d
-tl.store(out_ptrs, out, mask=tgt_mask and ignore_mask)
+tl.store(out_ptrs, out)
+out_sum = tl.sum(out)
+tl.atomic_add(out_sum_ptr, out_sum)


@libentry()
@@ -466,7 +480,6 @@ class CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
logging.debug("GEMS CrossEntropyLoss")
-# label_smoothing not supported

shape = list(inp.shape)
dim = inp.ndim
@@ -488,47 +501,65 @@ def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
inp = inp.contiguous()
tgt = target.contiguous()
weight = weight.contiguous()
-out = torch.zeros(shape, dtype=torch.float32, device=inp.device)
+out = torch.empty(shape, dtype=torch.float32, device=inp.device)
+out_sum = torch.zeros([], dtype=torch.float32, device=inp.device)
grid = lambda meta: (N, triton.cdiv(D, meta["BLOCK_D"]))

if tgt.ndim == dim:
# target probabilities
with torch.cuda.device(inp.device):
celoss_probability_kernel[grid](
-inp, tgt, weight, out, label_smoothing, N, C, D
+inp, tgt, weight, out, out_sum, label_smoothing, N, C, D
)
elif label_smoothing == 0:
# target indices
-w_tgt = torch.zeros(shape, dtype=torch.float32, device=inp.device)
+w_tgt = torch.zeros([], dtype=torch.float32, device=inp.device)
with torch.cuda.device(inp.device):
celoss_indice_kernel[grid](
-inp, tgt, weight, out, w_tgt, ignore_index, N, C, D
+inp, tgt, weight, out, out_sum, w_tgt, ignore_index, N, C, D
)
else:
-w_tgt = torch.zeros(shape, dtype=torch.float32, device=inp.device)
+w_tgt = torch.zeros([], dtype=torch.float32, device=inp.device)
with torch.cuda.device(inp.device):
celoss_indice_smooth_kernel[grid](
-inp, tgt, weight, out, w_tgt, ignore_index, label_smoothing, N, C, D
+inp,
+tgt,
+weight,
+out,
+out_sum,
+w_tgt,
+ignore_index,
+label_smoothing,
+N,
+C,
+D,
-ctx.save_for_backward(inp, tgt, weight)
-ctx.N = N
-ctx.C = C
-ctx.D = D
-ctx.ignore_index = ignore_index
-ctx.label_smoothing = label_smoothing
-ctx.mean_num = 1
-ctx.shape = shape

if reduction == 0: # NONE
-return out.to(inp.dtype)
+ctx.mean_num = 1
+out = out.to(inp.dtype)
elif reduction == 1: # MEAN
if tgt.ndim == dim:
-ctx.mean_num = 1 / (N * D)
+ctx.mean_num = N * D
else:
-ctx.mean_num = 1 / sum(w_tgt).item()
-return (sum(out) * ctx.mean_num).to(inp.dtype)
+ctx.mean_num = w_tgt.item()
+out = torch.tensor(
+    out_sum.item() / ctx.mean_num, dtype=inp.dtype, device=inp.device
+)
else: # SUM
-return sum(out).to(inp.dtype)
+ctx.mean_num = 1
+out = out_sum.to(inp.dtype)

+if inp.requires_grad:
+    ctx.save_for_backward(inp, tgt, weight)
+    ctx.N = N
+    ctx.C = C
+    ctx.D = D
+    ctx.ignore_index = ignore_index
+    ctx.label_smoothing = label_smoothing
+    ctx.shape = shape
+
+return out
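
The `mean_num` bookkeeping follows torch's reduction semantics: with index targets and `reduction="mean"`, the weighted losses are divided by the sum of the selected class weights (the scalar now accumulated in `w_tgt`), not by the element count, while probability targets divide by `N * D`. A torch-only reference check of the weighted case (hypothetical shapes, independent of the kernels):

```python
import torch
import torch.nn.functional as F

inp = torch.randn(8, 4)
tgt = torch.randint(0, 4, (8,))
w = torch.rand(4)

per_elem = F.cross_entropy(inp, tgt, weight=w, reduction="none")
mean = F.cross_entropy(inp, tgt, weight=w, reduction="mean")
# mean == sum(per_elem) / sum(w[tgt]): the same w_tgt denominator as above
assert torch.allclose(mean, per_elem.sum() / w[tgt].sum())
```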

@staticmethod
def backward(ctx, out_grad):
@@ -540,7 +571,7 @@ def backward(ctx, out_grad):
D = ctx.D
ignore_index = ctx.ignore_index
label_smoothing = ctx.label_smoothing
-mean_num = ctx.mean_num
+mean_num = 1 / ctx.mean_num
shape = ctx.shape

out_grad = out_grad.broadcast_to(shape).contiguous()
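
Since `forward` now stores the divisor itself in `ctx.mean_num` rather than its reciprocal, `backward` inverts it once before scaling the incoming gradient. Equivalently, the mean-reduced gradient equals the sum-reduced gradient divided by the weight sum; a small torch-only sanity check of that identity (hypothetical shapes):

```python
import torch
import torch.nn.functional as F

inp = torch.randn(8, 4, requires_grad=True)
tgt = torch.randint(0, 4, (8,))
w = torch.rand(4)

F.cross_entropy(inp, tgt, weight=w, reduction="mean").backward()
inp2 = inp.detach().clone().requires_grad_(True)
F.cross_entropy(inp2, tgt, weight=w, reduction="sum").backward()
# mean-loss gradient = sum-loss gradient * (1 / sum(w[tgt]))
assert torch.allclose(inp.grad, inp2.grad / w[tgt].sum())
```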
