[INTERPRETER] Fix scalar mask when the value is False (#3880)
We hacked the scalar tensor's `__bool__` method so that `if (scalar tensor)` can be executed in interpreter mode:


https://github.com/openai/triton/blob/main/python/triton/runtime/interpreter.py#L700
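
Roughly, the hack looks like this (a minimal plain-Python sketch for illustration, not the actual interpreter code; `ScalarTensor` is a hypothetical stand-in for the interpreter's tensor wrapper):

```python
import numpy as np


class ScalarTensor:
    """Stand-in for the interpreter's tensor wrapper (illustrative name)."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __bool__(self):
        # Only a one-element tensor has an unambiguous truth value, so
        # `if (scalar tensor)` works; anything larger is rejected.
        if self.data.size == 1:
            return bool(self.data.item())
        raise ValueError("Cannot convert a multi-element tensor to a bool")


assert bool(ScalarTensor([True])) is True
assert not ScalarTensor([False])  # falsy, even though a mask object exists
```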

But scalar tensors shouldn't be evaluated for truthiness in `semantic.py`.

We could have `tl.load(..., mask=tl.tensor([False]), other=tl.tensor([1.0]))`, which is still valid. The call is invalid only when `mask is None` and `other is not None`.

Without this PR, `not tl.tensor([False])` evaluates to True, so the `if not mask and other:` check raises the wrong `ValueError` even though a mask was provided.

So we should instead check whether the tensor is None.
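
To make the failure mode concrete, here is a minimal sketch with numpy standing in for a scalar `tl.tensor` (illustrative, not Triton code):

```python
import numpy as np

mask = np.asarray([False])  # stands in for tl.tensor([False]): a mask WAS provided
other = 1.0

# Old check: a present-but-False scalar mask is falsy, so `not mask` is
# True and the error fires even though a mask was given.
if not mask and other:
    print('wrongly raised: `other` cannot be provided without `mask`')

# Fixed check: only `mask is None` means that no mask was passed at all.
if mask is None and other is not None:
    print("not reached: a mask was provided")
```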
Jokeren authored May 11, 2024
1 parent afaf1f0 commit c549281
Showing 2 changed files with 29 additions and 4 deletions.
python/test/unit/language/test_core.py (26 additions, 1 deletion)
@@ -3619,9 +3619,34 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
    torch.testing.assert_close(output, reference_out)


@pytest.mark.interpreter
@pytest.mark.parametrize("num_ctas", num_ctas_list)
@pytest.mark.parametrize("mask_val", [True, False])
@pytest.mark.parametrize("other_val", [0, 1])
def test_masked_load_scalar(num_ctas, mask_val, other_val, device):
    input_val = 4.0
    size = 128
    dtype = torch.float32
    input = torch.full((size, ), input_val, dtype=dtype, device=device)
    output = torch.zeros((size, ), dtype=dtype, device=device)

    @triton.jit
    def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.constexpr):
        offsets = tl.arange(0, size)
        x = tl.load(in_ptr + offsets, mask=mask, other=other)
        tl.store(out_ptr + offsets, x)

    kernel[(1, )](input, output, size, mask_val, other_val, num_ctas=num_ctas)

    if mask_val:
        reference_out = torch.full((size, ), input_val, dtype=dtype, device=device)
    else:
        reference_out = torch.full((size, ), other_val, dtype=dtype, device=device)

    torch.testing.assert_close(output, reference_out)


# Testing masked loads with an intermediate copy to shared memory.
# FIXME: Shape too small for ldmatrix when num_ctas=4
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
python/triton/language/semantic.py (3 additions, 3 deletions)
@@ -944,7 +944,7 @@ def _canonicalize_boundary_check(boundary_check, block_shape):
def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
    # Load by a block pointer: `pointer_type<block_type<>>`
    # Block pointer can not have `mask` and `other` arguments
-    if mask or other:
+    if mask is not None or other is not None:
        raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")

elt_ty = ptr.type.element_ty.element_ty
@@ -969,7 +969,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")

# Check `mask`, `other`, `boundary_check`, and `padding` arguments
if not mask and other:
if mask is None and other is not None:
raise ValueError("`other` cannot be provided without `mask`")
if padding or boundary_check:
raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of"
@@ -1013,7 +1013,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
        dst_ty = elt_ty

    # Build IR
-    if not mask:
+    if mask is None:
        return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
    else:
        return tl.tensor(
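
For reference, a sketch of the end-to-end behavior this commit fixes. It mirrors the new test and assumes interpreter mode is enabled via the `TRITON_INTERPRET` environment variable, as in Triton's test suite:

```python
import os

os.environ["TRITON_INTERPRET"] = "1"  # must be set before importing triton

import torch
import triton
import triton.language as tl


@triton.jit
def masked_copy(in_ptr, out_ptr, SIZE: tl.constexpr, MASK: tl.constexpr):
    offsets = tl.arange(0, SIZE)
    # A scalar False mask with `other` is valid; before this commit the
    # interpreter wrongly raised "`other` cannot be provided without `mask`".
    x = tl.load(in_ptr + offsets, mask=MASK, other=1.0)
    tl.store(out_ptr + offsets, x)


inp = torch.full((16, ), 4.0)
out = torch.zeros((16, ))
masked_copy[(1, )](inp, out, 16, False)
print(out)  # all 1.0 (the `other` value) after the fix, instead of a ValueError
```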
