Skip to content

Commit 8907a65

Browse files
qsang-nvevezhier
authored andcommitted
update spec_dec (#6079)
Signed-off-by: Qidi Sang <[email protected]>
1 parent 4f15810 commit 8907a65

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

cpp/kernels/xqa/gen_cubins.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,12 @@
117117
};
118118
"""
119119

120-
is_medusa = False
120+
is_spec_dec = False
121121

122122

123123
def generate_cubin_meta_info_line(arch: int, compile_macros: List[CompileMacro],
124124
function_name: str, cubin_size: int,
125-
is_last: bool, is_medusa: bool):
125+
is_last: bool, is_spec_dec: bool):
126126
data_type_str = None
127127
kv_data_type_str = None
128128
head_dim = None
@@ -160,7 +160,7 @@ def generate_cubin_meta_info_line(arch: int, compile_macros: List[CompileMacro],
160160
assert (tokens_per_page % 2 == 0)
161161
paged_kv_cache = 'true' if tokens_per_page > 0 else 'false'
162162

163-
use_medusa = 'true' if is_medusa else 'false'
163+
use_medusa = 'true' if is_spec_dec else 'false'
164164
assert data_type_str is not None
165165
assert kv_data_type_str is not None
166166
assert head_dim is not None
@@ -376,7 +376,7 @@ def generate_compile_arch_macro_list(compile_macro_options: list):
376376
option_macro_names, option_short_names, option_combination)
377377
]
378378
if arch in (90, ) and option_combination[
379-
3] == 2 and option_combination[2] == 1 and not is_medusa:
379+
3] == 2 and option_combination[2] == 1 and not is_spec_dec:
380380
input_file_name = "mha_sm90.cu"
381381
else:
382382
input_file_name = "mha.cu"
@@ -387,7 +387,7 @@ def generate_compile_arch_macro_list(compile_macro_options: list):
387387

388388
def generate_header_file_contents(
389389
all_arch_macros: List[CompileArchMacrosAndFile],
390-
name_size_list: List[Tuple[str, int]], is_medusa: bool):
390+
name_size_list: List[Tuple[str, int]], is_spec_dec: bool):
391391
cubin_data_array = []
392392
cubin_length_array = []
393393
meta_line_array = []
@@ -406,7 +406,7 @@ def generate_header_file_contents(
406406
generate_cubin_meta_info_line(arch, macros, function_name,
407407
cubin_size,
408408
i == len(all_arch_macros) - 1,
409-
is_medusa))
409+
is_spec_dec))
410410
cubin_data = ''.join(cubin_data_array)
411411
cubin_length = ''.join(cubin_length_array)
412412
meta_struct = ''.join([
@@ -422,8 +422,8 @@ def generate_header_file_contents(
422422
shutil.rmtree(cubin_dir)
423423
os.mkdir(cubin_dir)
424424

425-
if len(sys.argv) > 1 and sys.argv[1] == 'medusa':
426-
is_medusa = True
425+
if len(sys.argv) > 1 and sys.argv[1] == 'spec_dec':
426+
is_spec_dec = True
427427
nvcc_flags = '-std=c++17 -O3 -cubin -DGENERATE_CUBIN=1 -DNDEBUG -DSPEC_DEC --use_fast_math -Xptxas=-v --allow-unsupported-compiler --expt-relaxed-constexpr -t 0'
428428
arch_options = [80, 86, 89, 90]
429429
config_list = [[
@@ -444,7 +444,7 @@ def generate_header_file_contents(
444444
with multiprocessing.Pool(processes=thread_count) as pool:
445445
name_size_list = pool.map(run_cubin_gen, arch_macro_lists)
446446
header_file_contents = generate_header_file_contents(
447-
arch_macro_lists, name_size_list, is_medusa)
447+
arch_macro_lists, name_size_list, is_spec_dec)
448448

449449
with open(cubin_dir + build_func_name_prefix + '_cubin.h', "w") as f:
450450
f.write("".join(

cpp/kernels/xqa/mha.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1609,7 +1609,7 @@ CUBIN_EXPORT __global__
16091609

16101610
uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
16111611
#if SPEC_DEC
1612-
uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - qSeqLen) / ctaTile.x;
1612+
uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
16131613
#endif
16141614

16151615
uint32_t const seqStrideIters = nbSubSeqPerSeq;

0 commit comments

Comments
 (0)