[FRONTEND] Rename tl.reduction -> tl.reduce and improve testing (triton-lang#1521)

`tl.reduction` is currently tested indirectly through the existing
reduction operators, but it's good to have a direct test for the
function itself.

---------

Co-authored-by: Philippe Tillet <[email protected]>
peterbell10 and ptillet authored Apr 14, 2023
1 parent bfd1f65 commit 0d76c4c
Showing 7 changed files with 62 additions and 19 deletions.
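For orientation (not part of the commit): after the rename, tl.reduce takes a tensor or tuple of tensors, an axis, and a @triton.jit combine function. A minimal sketch of a call site, with illustrative kernel and helper names:

import torch
import triton
import triton.language as tl


@triton.jit
def _max_combine(a, b):
    return tl.maximum(a, b)


@triton.jit
def max_kernel(X, Out, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # Reduce along axis 0 with a user-supplied combine function.
    tl.store(Out, tl.reduce(x, 0, _max_combine))


x = torch.rand(128, device='cuda')
out = torch.empty((), device='cuda')
max_kernel[(1,)](x, out, BLOCK=128)
torch.testing.assert_close(out, x.max())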
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -24,7 +24,7 @@ repos:
rev: v1.6.0
hooks:
- id: autopep8
args: ["-a", "-i", "--max-line-length", "88"]
args: ["-i"]
stages: [commit, push, manual]
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
4 changes: 4 additions & 0 deletions docs/python-api/triton.language.rst
@@ -96,9 +96,13 @@ Reduction Ops
:toctree: generated
:nosignatures:

argmax
argmin
max
min
reduce
sum
xor_sum


Atomic Ops
8 changes: 8 additions & 0 deletions python/pyproject.toml
@@ -0,0 +1,8 @@

[build-system]
requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18"]

[tool.autopep8]
aggressive = 1
ignore = "E501,E701,E731,W690"
max_line_length = 88
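Note: the settings previously passed to the hook as command-line flags now live in [tool.autopep8]: aggressive = 1 corresponds to the old -a flag and max_line_length = 88 to --max-line-length 88, while the ignore list presumably carries over from the deleted python/setup.cfg below. autopep8 reads [tool.autopep8] from pyproject.toml automatically (supported since autopep8 1.5.3).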
8 changes: 0 additions & 8 deletions python/setup.cfg

This file was deleted.

37 changes: 37 additions & 0 deletions python/test/unit/language/test_core.py
@@ -1334,6 +1334,43 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
    np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)


@triton.jit
def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = weight_2 / new_weight
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )


def test_generic_reduction(device='cuda'):

    @triton.jit
    def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr):
        xindex = tl.arange(0, BLOCK)
        x = tl.load(X + xindex)
        mean = x
        m2 = tl.zeros_like(x)
        weight = tl.full(x.shape, 1, x.dtype)
        (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine)
        tl.store(out_mean, mean)
        tl.store(out_var, m2 / weight)

    SIZE = 512
    x = torch.rand(SIZE, device=device)
    out_mean = torch.empty((), device=device)
    out_var = torch.empty((), device=device)

    var_mean_kernel[(1,)](x, out_mean, out_var, BLOCK=SIZE)

    expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0)
    torch.testing.assert_close(out_mean, expect_mean)
    torch.testing.assert_close(out_var, expect_var)


# ---------------
# test permute
# ---------------
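The new test exercises the tuple form of tl.reduce with Welford's parallel-variance combine: merging groups with means m1, m2, squared-deviation sums M2_1, M2_2 and counts n1, n2 gives delta = m2 - m1, n = n1 + n2, mean = m1 + delta * n2 / n, and M2 = M2_1 + M2_2 + delta^2 * n1 * n2 / n, so m2 / weight is the population variance, matching torch.var_mean(..., correction=0). A host-side sanity check of the same algebra (an illustrative sketch, not part of the commit):

import numpy as np


def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    # Same update as the @triton.jit _welford_combine above.
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = weight_2 / new_weight
    return (mean_1 + delta * w2_over_w,
            m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
            new_weight)


def stats(v):
    # (mean, sum of squared deviations, count) for one group.
    return v.mean(), ((v - v.mean()) ** 2).sum(), float(v.size)


a, b = np.random.rand(100), np.random.rand(50)
mean, m2, n = welford_combine(*stats(a), *stats(b))
both = np.concatenate([a, b])
assert np.allclose([mean, m2 / n], [both.mean(), both.var()])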
2 changes: 2 additions & 0 deletions python/triton/language/__init__.py
@@ -57,6 +57,7 @@
pointer_type,
program_id,
ravel,
+    reduce,
reshape,
sigmoid,
sin,
@@ -164,6 +165,7 @@
"randn",
"randn4x",
"ravel",
"reduce",
"reshape",
"sigmoid",
"sin",
20 changes: 10 additions & 10 deletions python/triton/language/core.py
@@ -1199,7 +1199,7 @@ def _insertion_guard(builder):


@builtin
-def reduction(input, axis, combine_fn, _builder=None, _generator=None):
+def reduce(input, axis, combine_fn, _builder=None, _generator=None):
"""Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`
:param input: the input tensor, or tuple of tensors
@@ -1208,8 +1208,8 @@ def reduction(input, axis, combine_fn, _builder=None, _generator=None):
"""
if isinstance(input, tensor):
-        return reduction((input,), axis, combine_fn,
-                         _builder=_builder, _generator=_generator)[0]
+        return reduce((input,), axis, combine_fn,
+                      _builder=_builder, _generator=_generator)[0]

def make_combine_region(reduce_op):
in_scalar_tys = [t.type.scalar for t in input]
@@ -1261,8 +1261,8 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
index = index.__getitem__(expand_dims_index, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)

-    rvalue, rindices = reduction((input, index), axis, combine_fn,
-                                 _builder=_builder, _generator=_generator)
+    rvalue, rindices = reduce((input, index), axis, combine_fn,
+                              _builder=_builder, _generator=_generator)
return rindices


@@ -1275,7 +1275,7 @@ def _max_combine(a, b):
@_add_reduction_docstr("maximum")
def max(input, axis):
input = _promote_reduction_input(input)
-    return reduction(input, axis, _max_combine)
+    return reduce(input, axis, _max_combine)


@triton.jit
@@ -1305,7 +1305,7 @@ def _min_combine(a, b):
@_add_reduction_docstr("minimum")
def min(input, axis):
input = _promote_reduction_input(input)
-    return reduction(input, axis, _min_combine)
+    return reduce(input, axis, _min_combine)


@triton.jit
@@ -1334,7 +1334,7 @@ def _sum_combine(a, b):
@_add_reduction_docstr("sum")
def sum(input, axis):
input = _promote_reduction_input(input)
-    return reduction(input, axis, _sum_combine)
+    return reduce(input, axis, _sum_combine)


@triton.jit
@@ -1350,8 +1350,8 @@ def xor_sum(input, axis, _builder=None, _generator=None):
raise ValueError("xor_sum only supported for integers")

input = _promote_reduction_input(input, _builder=_builder)
-    return reduction(input, axis, _xor_combine,
-                     _builder=_builder, _generator=_generator)
+    return reduce(input, axis, _xor_combine,
+                  _builder=_builder, _generator=_generator)


# -----------------------
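_argreduce above shows how the library itself uses the tuple form: each element is paired with its index, the (value, index) pair is reduced, and only the indices are returned. The combine functions it is paired with are collapsed in this diff; a hypothetical argmin-style combine, for illustration only, might look like:

@triton.jit
def _argmin_combine(value1, index1, value2, index2):
    # Keep the smaller value and carry its index along with it.
    smaller = value1 < value2
    value = tl.where(smaller, value1, value2)
    index = tl.where(smaller, index1, index2)
    return value, index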
