add reduction of sum and none for CrossEntropyLoss (#41)
* modify name && add reduce function

* add reduce none

* add test

* clean code

* add reduce enum

* Replace the enum interface with IntEnum & add detection of illegal reduction values

---------

Co-authored-by: Jiang Bin <[email protected]>
FatJhon and Jiang Bin committed May 31, 2024
1 parent bbd5386 commit 679ce8b
Showing 5 changed files with 122 additions and 28 deletions.
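
For orientation: the new `Reduction` IntEnum added below (NONE = 0, MEAN = 1, SUM = 2) matches the integer codes PyTorch uses for the `reduction` argument, and the loss is still driven through `torch.nn.CrossEntropyLoss` under `flag_gems.use_gems()`, as in the updated test. A minimal sketch of exercising all three modes (shapes, dtype, and class count are illustrative assumptions, not taken from the diff):

import torch
import flag_gems

inp = torch.randn(128, 10, dtype=torch.float32, device="cuda", requires_grad=True)
target = torch.randint(0, 10, (128,), device="cuda")

for reduction in ("none", "mean", "sum"):
    criterion = torch.nn.CrossEntropyLoss(reduction=reduction)
    with flag_gems.use_gems():  # route the loss to the Triton kernels in this commit
        loss = criterion(inp, target)
    # "none" yields one loss per sample; "mean" and "sum" reduce to a scalar
    loss.sum().backward()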
4 changes: 2 additions & 2 deletions src/flag_gems/ops/argmax.py
@@ -35,8 +35,8 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr
mid_ptrs = mid_value + offset
mask = offset < mid_size
mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf"))
sum_val = tl.argmax(mid_val, axis=0)
mid_index_ptrs = mid_index + sum_val
index_val = tl.argmax(mid_val, axis=0)
mid_index_ptrs = mid_index + index_val
out_val = tl.load(mid_index_ptrs)
tl.store(out, out_val)

121 changes: 105 additions & 16 deletions src/flag_gems/ops/cross_entropy_loss.py
@@ -2,8 +2,15 @@
import triton
import triton.language as tl
import logging
from enum import IntEnum
from ..utils import libentry
from .sum import sum
from .sum import sum, sum_dim


class Reduction(IntEnum):
NONE = 0
MEAN = 1
SUM = 2


@libentry()
@@ -56,7 +63,7 @@ def log_softmax_and_mul_kernel(
denominator = tl.sum(numerator, axis=1)[:, None]
softmax_output = tl.log(numerator / denominator)
target = tl.load(target_ptr + offset, mask=mask, other=0.0)
out = softmax_output * target / (-mean_num)
out = softmax_output * target / (mean_num)
output_ptrs = output_ptr + offset
tl.store(output_ptrs, out, mask=mask)

@@ -114,6 +121,68 @@ def softmax_and_sub_kernel(
softmax_output = numerator / denominator
target_ptrs = target_ptr + offset
target = tl.load(target_ptrs, mask=mask, other=0.0)
out_grad_ptr = out_grad + m_offset[:, None] * K + pid_k
out_grad_value = tl.load(out_grad_ptr)
out = out_grad_value * (softmax_output - target) / mean_num
output_ptrs = output_ptr + offset

tl.store(output_ptrs, out, mask=mask)


@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 1}, num_stages=4),
triton.Config({"BLOCK_M": 1}, num_stages=5),
triton.Config({"BLOCK_M": 2}, num_stages=4),
triton.Config({"BLOCK_M": 2}, num_stages=5),
triton.Config({"BLOCK_M": 4}, num_stages=4),
triton.Config({"BLOCK_M": 4}, num_stages=5),
triton.Config({"BLOCK_M": 8}, num_stages=4),
triton.Config({"BLOCK_M": 8}, num_stages=5),
],
key=[
"M",
"N",
],
)
@triton.heuristics(
values={
"BLOCK_N": lambda args: triton.next_power_of_2(args["N"]),
"num_warps": lambda args: (
4 if args["N"] <= 1024 else (8 if args["N"] <= 2048 else 16)
),
},
)
@triton.jit
def softmax_and_sub_reduce_kernel(
output_ptr,
input_ptr,
target_ptr,
out_grad,
mean_num,
M,
N,
K,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_k = tl.program_id(1)
m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
n_offset = tl.arange(0, BLOCK_N)
offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
mask = m_offset[:, None] < M and n_offset[None, :] < N
input_ptrs = input_ptr + offset
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
row_minus_max = inp - tl.max(inp, axis=1)[:, None]
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=1)[:, None]
# todo: reduce unnecessary calculations through mask operations to improve performance
softmax_output = numerator / denominator
target_ptrs = target_ptr + offset
target = tl.load(target_ptrs, mask=mask, other=0.0)

out_grad_value = tl.load(out_grad)
out = out_grad_value * (softmax_output - target) / mean_num
output_ptrs = output_ptr + offset
@@ -125,15 +194,18 @@ class CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, input, target, weight, reduction, ignore_index, label_smoothing):
logging.debug("GEMS CrossEntropyLoss")
assert reduction in Reduction._value2member_map_, "Invalid reduction"
assert isinstance(input, torch.Tensor), "input is not a tensor"
if input.ndim >= 2:
dim = 1
else:
dim = 0

if reduction != Reduction.MEAN.value:
mean_num = -1
else:
mean_num = -target.numel()
shape = list(input.shape)
shape[dim] = 1
mean_num = target.numel()
target = torch.zeros_like(input).scatter(dim, target.view(shape), 1)

M = 1
@@ -157,11 +229,15 @@ def forward(ctx, input, target, weight, reduction, ignore_index, label_smoothing
N,
K,
)
out_result = sum(out)
if reduction != Reduction.NONE.value:
out_result = sum(out)
else:
out_result = sum_dim(out, dim=[dim])

ctx.save_for_backward(input, target)
ctx.dim = dim
ctx.mean_num = mean_num
ctx.mean_num = -mean_num
ctx.reduction = reduction
return out_result

@staticmethod
@@ -170,6 +246,7 @@ def backward(ctx, out_grad):
input, target = ctx.saved_tensors
dim = ctx.dim
mean_num = ctx.mean_num
reduction = ctx.reduction

M = 1
N = input.shape[dim]
@@ -183,16 +260,28 @@
triton.cdiv(M, meta["BLOCK_M"]),
K,
)
softmax_and_sub_kernel[grid](
out,
inp,
target,
out_grad,
mean_num,
M,
N,
K,
)
if reduction != Reduction.NONE.value:
softmax_and_sub_reduce_kernel[grid](
out,
inp,
target,
out_grad,
mean_num,
M,
N,
K,
)
else:
softmax_and_sub_kernel[grid](
out,
inp,
target,
out_grad,
mean_num,
M,
N,
K,
)
return out, None, None, None, None, None


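
To make the sign bookkeeping above easier to follow: the forward kernel stores log_softmax(input) * one_hot(target) / mean_num, with mean_num now -1 for sum/none and -target.numel() for mean, so every stored element is already a non-negative (and, for mean, pre-scaled) loss term; `sum(out)` then gives the reduced loss, while `sum_dim(out, dim=[dim])` gives the per-sample losses for reduction="none". On the backward side, `ctx.mean_num = -mean_num` flips the sign back, and the gradient is out_grad * (softmax - one_hot) / mean_num, computed by `softmax_and_sub_reduce_kernel` (scalar out_grad) for mean/sum and by `softmax_and_sub_kernel` (per-element out_grad) for none. A rough eager-mode equivalent of the forward path, included only as a reading aid (it assumes 2-D input and ignores weight, ignore_index, and label_smoothing, which the kernels shown here do not reference):

import torch

def cross_entropy_forward_reference(inp, target, reduction="mean"):
    # mean_num mirrors the values chosen in forward(): negative so the
    # non-positive log-softmax terms come out as positive loss values
    mean_num = -target.numel() if reduction == "mean" else -1
    one_hot = torch.zeros_like(inp).scatter(1, target.view(-1, 1), 1)
    per_elem = torch.log_softmax(inp, dim=1) * one_hot / mean_num
    if reduction == "none":
        return per_elem.sum(dim=1)  # per-sample losses, like sum_dim(out, dim=[dim])
    return per_elem.sum()           # "mean" is already scaled by 1/numel, "sum" is not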
8 changes: 4 additions & 4 deletions src/flag_gems/ops/max.py
@@ -20,9 +20,9 @@ def max_kernel_1(
inp_ptrs = inp + offset
mask = offset < M
inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
sum_val = tl.max(inp_val)
max_val = tl.max(inp_val)
mid_ptr = mid + pid
tl.store(mid_ptr, sum_val)
tl.store(mid_ptr, max_val)


@libentry()
@@ -32,8 +32,8 @@ def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
mid_ptrs = mid + offset
mask = offset < mid_size
mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf"))
sum_val = tl.max(mid_val)
tl.store(out, sum_val)
max_val = tl.max(mid_val)
tl.store(out, max_val)


@libentry()
8 changes: 4 additions & 4 deletions src/flag_gems/ops/min.py
@@ -20,9 +20,9 @@ def min_kernel_1(
inp_ptrs = inp + offset
mask = offset < M
inp_val = tl.load(inp_ptrs, mask=mask, other=float("inf"))
sum_val = tl.min(inp_val)
min_val = tl.min(inp_val)
mid_ptr = mid + pid
tl.store(mid_ptr, sum_val)
tl.store(mid_ptr, min_val)


@libentry()
@@ -32,8 +32,8 @@ def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
mid_ptrs = mid + offset
mask = offset < mid_size
mid_val = tl.load(mid_ptrs, mask=mask, other=float("inf"))
sum_val = tl.min(mid_val)
tl.store(out, sum_val)
min_val = tl.min(mid_val)
tl.store(out, min_val)


@libentry()
9 changes: 7 additions & 2 deletions tests/test_reduction_ops.py
@@ -143,9 +143,12 @@ def test_accuracy_argmax(shape, dim, keepdim, dtype):
gems_assert_equal(res_out, ref_out)


@pytest.mark.parametrize("size_average", [None, True, False])
@pytest.mark.parametrize("reduce", [None, True, False])
@pytest.mark.parametrize("reduction", ["mean", "none", "sum"])
@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_cross_entropy_loss(shape, dtype):
def test_accuracy_cross_entropy_loss(shape, dtype, size_average, reduce, reduction):
inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
dim = 1
up_limit = shape[dim] - 1
@@ -156,7 +159,9 @@ def test_accuracy_cross_entropy_loss(shape, dtype):
ref_inp = to_reference(inp, True)
ref_target = to_reference(target)

criterion = torch.nn.CrossEntropyLoss()
criterion = torch.nn.CrossEntropyLoss(
size_average=size_average, reduce=reduce, reduction=reduction
)

ref_out = criterion(ref_inp, ref_target)
with flag_gems.use_gems():
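
The test now also sweeps the deprecated `size_average` and `reduce` flags. Whenever either of them is not None, `torch.nn.CrossEntropyLoss` derives the effective reduction from that pair and the `reduction` string is ignored, so the combined parametrization still covers all three kernel paths added in this commit. Roughly, as a paraphrase of PyTorch's `torch.nn._reduction.legacy_get_string` (a reading aid, not part of the diff):

def effective_reduction(size_average, reduce, reduction):
    # legacy flags win whenever either one is given explicitly
    if size_average is None and reduce is None:
        return reduction
    size_average = True if size_average is None else size_average
    reduce = True if reduce is None else reduce
    if size_average and reduce:
        return "mean"
    return "sum" if reduce else "none"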
