Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Permalink
Fix a bug in argmax function
Browse files (browse the repository at this point in the history)
  • Loading branch information
danny.jang authored Sep 22, 2023
1 parent 60707a8 commit a70a73e
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion trident/kernel/argmax.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,6 +18,7 @@

class Argmax:
@staticmethod
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def forward(
output_ptr: tl.tensor,
Expand All @@ -28,8 +29,10 @@ def forward(
x_stride: tl.int32,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
y_offset = tl.program_id(0)

output_block_ptr = tl.make_block_ptr(
output_ptr,
shape=(y_size,),
Expand All @@ -46,6 +49,13 @@ def forward(
block_shape=(1, x_block_size),
order=(1, 0),
)
input = tl.load(input_block_ptr, boundary_check=(1, 0), padding_option="zero")

if require_x_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(1,))
condition = tl.arange(0, x_block_size) < x_size
input = tl.where(condition, input, float("-inf"))
else:
input = tl.load(input_block_ptr)

output = tl.argmax(input, 1)
tl.store(output_block_ptr, output.to(dtype))

0 comments on commit a70a73e

Please sign in to comment.