diff --git a/cute_kernels/kernels/contiguous_count/triton_implementation/kernels_forward.py b/cute_kernels/kernels/contiguous_count/triton_implementation/kernels_forward.py index f1c76627..ba1208da 100644 --- a/cute_kernels/kernels/contiguous_count/triton_implementation/kernels_forward.py +++ b/cute_kernels/kernels/contiguous_count/triton_implementation/kernels_forward.py @@ -33,15 +33,14 @@ def contiguous_count_triton_kernel(x_ptr, output_ptr, B, C, BLOCK_SIZE_B: tl.con x = tl.load(x_ptr + indices_b, mask=mask_b, other=-1) - equal = (x[:, None] == indices_c[None, :]) * 1 + equal = (x[:, None] == indices_c[None, :]).to(tl.int32) counts += tl.sum(equal, axis=0) - output_ptrs = output_ptr + pid * C + indices_c - tl.store(output_ptrs, counts, mask=mask_c) + tl.atomic_add(output_ptr + indices_c, counts, mask=mask_c) def _fake(x: torch.Tensor, size: int, BLOCK_SIZE_B: int) -> torch.Tensor: - return torch.empty(size, dtype=torch.long, device=x.device) + return torch.empty(size, dtype=torch.int32, device=x.device) @cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={}, fake_func=_fake) @@ -52,7 +51,7 @@ def contiguous_count_triton(x: torch.Tensor, size: int, BLOCK_SIZE_B: int) -> to sm_count = get_sm_count(x.device) num_programs = min(sm_count, ceil_divide(B, BLOCK_SIZE_B)) - output = torch.zeros(num_programs, size, dtype=torch.long, device=x.device) + output = torch.zeros(size, dtype=torch.int32, device=x.device) with torch.device(x.device): contiguous_count_triton_kernel[(num_programs,)]( @@ -64,4 +63,4 @@ def contiguous_count_triton(x: torch.Tensor, size: int, BLOCK_SIZE_B: int) -> to BLOCK_SIZE_C=BLOCK_SIZE_C, ) - return output.sum(dim=0) + return output