Skip to content

Commit

Permalink
improve contiguous count using atomic_add (#107)
Browse files (browse the repository at this point in the history)
Signed-off-by: Mayank Mishra <[email protected]>
  • Loading branch information
mayank31398 authored Dec 20, 2024
1 parent 99b4111 commit b08a1c4
Showing 1 changed file with 5 additions and 6 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,15 +33,14 @@ def contiguous_count_triton_kernel(x_ptr, output_ptr, B, C, BLOCK_SIZE_B: tl.con

x = tl.load(x_ptr + indices_b, mask=mask_b, other=-1)

- equal = (x[:, None] == indices_c[None, :]) * 1
+ equal = (x[:, None] == indices_c[None, :]).to(tl.int32)
counts += tl.sum(equal, axis=0)

- output_ptrs = output_ptr + pid * C + indices_c
- tl.store(output_ptrs, counts, mask=mask_c)
+ tl.atomic_add(output_ptr + indices_c, counts, mask=mask_c)


def _fake(x: torch.Tensor, size: int, BLOCK_SIZE_B: int) -> torch.Tensor:
- return torch.empty(size, dtype=torch.long, device=x.device)
+ return torch.empty(size, dtype=torch.int32, device=x.device)


@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={}, fake_func=_fake)
Expand All @@ -52,7 +51,7 @@ def contiguous_count_triton(x: torch.Tensor, size: int, BLOCK_SIZE_B: int) -> to
sm_count = get_sm_count(x.device)
num_programs = min(sm_count, ceil_divide(B, BLOCK_SIZE_B))

- output = torch.zeros(num_programs, size, dtype=torch.long, device=x.device)
+ output = torch.zeros(size, dtype=torch.int32, device=x.device)

with torch.device(x.device):
contiguous_count_triton_kernel[(num_programs,)](
Expand All @@ -64,4 +63,4 @@ def contiguous_count_triton(x: torch.Tensor, size: int, BLOCK_SIZE_B: int) -> to
BLOCK_SIZE_C=BLOCK_SIZE_C,
)

- return output.sum(dim=0)
+ return output

0 comments on commit b08a1c4

Please sign in to comment.