diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 48794da8e2da..fd6234902b4c 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3619,9 +3619,34 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
     torch.testing.assert_close(output, reference_out)


-# Testing masked loads with an intermate copy to shared memory run.
+@pytest.mark.interpreter
+@pytest.mark.parametrize("num_ctas", num_ctas_list)
+@pytest.mark.parametrize("mask_val", [True, False])
+@pytest.mark.parametrize("other_val", [0, 1])
+def test_masked_load_scalar(num_ctas, mask_val, other_val, device):
+    input_val = 4.0
+    size = 128
+    dtype = torch.float32
+    input = torch.full((size, ), input_val, dtype=dtype, device=device)
+    output = torch.zeros((size, ), dtype=dtype, device=device)
+
+    @triton.jit
+    def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.constexpr):
+        offsets = tl.arange(0, size)
+        x = tl.load(in_ptr + offsets, mask=mask, other=other)
+        tl.store(out_ptr + offsets, x)
+
+    kernel[(1, )](input, output, size, mask_val, other_val, num_ctas=num_ctas)
+    if mask_val:
+        reference_out = torch.full((size, ), input_val, dtype=dtype, device=device)
+    else:
+        reference_out = torch.full((size, ), other_val, dtype=dtype, device=device)
+
+    torch.testing.assert_close(output, reference_out)
+
+
+# Testing masked loads with an intermediate copy to shared memory.
 # FIXME: Shape too small for ldmatrix when num_ctas=4
 @pytest.mark.interpreter
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 1f4a3ec59bd2..83d4dfc8c409 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -944,7 +944,7 @@ def _canonicalize_boundary_check(boundary_check, block_shape):
 def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
     # Load by a block pointer: `pointer_type<block_type<>>`
     # Block pointer can not have `mask` and `other` arguments
-    if mask or other:
+    if mask is not None or other is not None:
         raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")

     elt_ty = ptr.type.element_ty.element_ty
@@ -969,7 +969,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
         raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")

     # Check `mask`, `other`, `boundary_check`, and `padding` arguments
-    if not mask and other:
+    if mask is None and other is not None:
         raise ValueError("`other` cannot be provided without `mask`")
     if padding or boundary_check:
         raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of"
@@ -1013,7 +1013,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
         dst_ty = elt_ty

     # Build IR
-    if not mask:
+    if mask is None:
         return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
     else:
         return tl.tensor(
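
Note on the semantic.py hunks: the old guards tested truthiness, so a mask that is present but falsy (a scalar `False`) or an `other` of `0` was indistinguishable from an omitted argument; replacing the truthiness tests with explicit `is None` / `is not None` checks is what lets the new `test_masked_load_scalar` pass with `mask_val=False` and `other_val=0`. A minimal, Triton-free sketch of the failure mode follows; the `check_args_*` helpers are illustrative stand-ins for the checks in `_load_legacy`, not part of the patch:

```python
# Illustrative stand-ins for the argument validation in _load_legacy;
# this is the plain-Python truthiness pitfall the patch removes.


def check_args_buggy(mask=None, other=None):
    # Truthiness test: mask=False or other=0 looks "not provided".
    if not mask and other:
        raise ValueError("`other` cannot be provided without `mask`")


def check_args_fixed(mask=None, other=None):
    # Identity test: only a literal None means "not provided".
    if mask is None and other is not None:
        raise ValueError("`other` cannot be provided without `mask`")


# mask=False with other=1 is a legitimate combination (every element
# takes the `other` value), but the truthiness version rejects it:
try:
    check_args_buggy(mask=False, other=1)
except ValueError as exc:
    print("buggy check:", exc)

check_args_fixed(mask=False, other=1)  # accepted, matching the new test
```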