Skip to content

Commit

Permalink
[TEST] fix error in argmin UT when dtype=int16
Browse files Browse the repository at this point in the history
  • Loading branch information
junjian.zhan committed Jan 21, 2025
1 parent 73acbe6 commit e05e3ee
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/flag_gems/ops/argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def argmin_kernel(
m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

# min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf"))
if tl_dtype is tl.int16:
tl_dtype = tl.int32
min_values = tl.full([BLOCK_M], dtype=tl_dtype, value=dtype_max_value)
argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
for start_n in range(0, N, BLOCK_N):
Expand Down
27 changes: 27 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/heuristics_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ def argmax_heur_block_n(args):
return min(4096, triton.next_power_of_2(args["N"]))


def argmin_heur_block_m(args):
    """Choose the BLOCK_M heuristic value registered for ``argmin``.

    Args:
        args: Mapping of kernel launch arguments; only ``args["M"]`` is read.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    # Threshold is exclusive: M == 4096 already selects the larger block.
    if args["M"] < 4096:
        return 4
    return 8


def argmin_heur_block_n(args):
    """Choose the BLOCK_N heuristic value registered for ``argmin``.

    Args:
        args: Mapping of kernel launch arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 4096.
    """
    return min(4096, triton.next_power_of_2(args["N"]))


def bmm_heur_divisible_m(args):
    """Report whether ``M`` is an exact multiple of ``BLOCK_M``.

    Args:
        args: Mapping of kernel launch arguments; reads ``args["M"]`` and
            ``args["BLOCK_M"]``.

    Returns:
        True when ``args["M"] % args["BLOCK_M"]`` is zero, False otherwise.
    """
    return args["M"] % args["BLOCK_M"] == 0

Expand Down Expand Up @@ -195,11 +203,26 @@ def upsample_nearest2d_SAME_W(args):
return args["OW"] == args["IW"]


def batch_norm_heur_block_m(args):
    """Choose the BLOCK_M heuristic value registered for ``batch_norm``.

    Args:
        args: Mapping of kernel launch arguments; only ``args["batch_dim"]``
            is read.

    Returns:
        ``triton.next_power_of_2(args["batch_dim"])``, capped at 2048.
    """
    return min(2048, triton.next_power_of_2(args["batch_dim"]))


def batch_norm_heur_block_n(args):
    """Choose the BLOCK_N heuristic value registered for ``batch_norm``.

    BLOCK_N is limited so that BLOCK_M * BLOCK_N does not exceed 2**14
    (16384) elements, where BLOCK_M is the value returned by
    ``batch_norm_heur_block_m`` for the same arguments.

    Args:
        args: Mapping of kernel launch arguments; reads ``args["spatial_dim"]``
            here and ``args["batch_dim"]`` via ``batch_norm_heur_block_m``.

    Returns:
        The smaller of ``triton.next_power_of_2(args["spatial_dim"])`` and
        the per-BLOCK_M element budget, never below 1.
    """
    # A maximum of 16384 elements are loaded at once.
    block_m = batch_norm_heur_block_m(args)
    block_n = triton.next_power_of_2(args["spatial_dim"])
    # max(1, ...) guards against the budget rounding down to zero.
    budget_per_row = max(1, 2**14 // block_m)
    return min(block_n, budget_per_row)


HEURISTICS_CONFIGS = {
"argmax": {
"BLOCK_M": argmax_heur_block_m,
"BLOCK_N": argmax_heur_block_n,
},
"argmin": {
"BLOCK_M": argmin_heur_block_m,
"BLOCK_N": argmin_heur_block_n,
},
"bmm": {
"DIVISIBLE_M": bmm_heur_divisible_m,
"DIVISIBLE_N": bmm_heur_divisible_n,
Expand Down Expand Up @@ -262,4 +285,8 @@ def upsample_nearest2d_SAME_W(args):
"var_mean": {
"BLOCK_N": var_mean_heur_block_n,
},
"batch_norm": {
"BLOCK_M": batch_norm_heur_block_m,
"BLOCK_N": batch_norm_heur_block_n,
},
}
22 changes: 22 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,16 @@ bmm:
num_stages: 2
num_warps: 16
argmax:
- META:
BLOCK_M: 8
num_warps: 8
- META:
BLOCK_M: 16
num_warps: 8
- META:
BLOCK_M: 32
num_warps: 8
argmin:
- META:
BLOCK_M: 8
num_warps: 8
Expand Down Expand Up @@ -3230,3 +3240,15 @@ var_mean:
- 4
- 8
- 16
# Tuning entry for the batch_norm op.
# NOTE(review): `gen: true` presumably tells the config loader to generate the
# candidate configs from the lists below instead of listing them one by one —
# confirm against the loader. Nesting/indentation is not visible in this view.
batch_norm:
- gen: true
param_map:
META: {}
# num_warps takes its values from the `warps` list that follows.
num_warps: warps
warps:
- 1
- 2
- 4
- 8
- 16
- 32

0 comments on commit e05e3ee

Please sign in to comment.