117117};
118118"""
119119
120- is_medusa = False
120+ is_spec_dec = False
121121
122122
123123def generate_cubin_meta_info_line (arch : int , compile_macros : List [CompileMacro ],
124124 function_name : str , cubin_size : int ,
125- is_last : bool , is_medusa : bool ):
125+ is_last : bool , is_spec_dec : bool ):
126126 data_type_str = None
127127 kv_data_type_str = None
128128 head_dim = None
@@ -160,7 +160,7 @@ def generate_cubin_meta_info_line(arch: int, compile_macros: List[CompileMacro],
160160 assert (tokens_per_page % 2 == 0 )
161161 paged_kv_cache = 'true' if tokens_per_page > 0 else 'false'
162162
163- use_medusa = 'true' if is_medusa else 'false'
163+ use_medusa = 'true' if is_spec_dec else 'false'
164164 assert data_type_str is not None
165165 assert kv_data_type_str is not None
166166 assert head_dim is not None
@@ -376,7 +376,7 @@ def generate_compile_arch_macro_list(compile_macro_options: list):
376376 option_macro_names , option_short_names , option_combination )
377377 ]
378378 if arch in (90 , ) and option_combination [
379- 3 ] == 2 and option_combination [2 ] == 1 and not is_medusa :
379+ 3 ] == 2 and option_combination [2 ] == 1 and not is_spec_dec :
380380 input_file_name = "mha_sm90.cu"
381381 else :
382382 input_file_name = "mha.cu"
@@ -387,7 +387,7 @@ def generate_compile_arch_macro_list(compile_macro_options: list):
387387
388388def generate_header_file_contents (
389389 all_arch_macros : List [CompileArchMacrosAndFile ],
390- name_size_list : List [Tuple [str , int ]], is_medusa : bool ):
390+ name_size_list : List [Tuple [str , int ]], is_spec_dec : bool ):
391391 cubin_data_array = []
392392 cubin_length_array = []
393393 meta_line_array = []
@@ -406,7 +406,7 @@ def generate_header_file_contents(
406406 generate_cubin_meta_info_line (arch , macros , function_name ,
407407 cubin_size ,
408408 i == len (all_arch_macros ) - 1 ,
409- is_medusa ))
409+ is_spec_dec ))
410410 cubin_data = '' .join (cubin_data_array )
411411 cubin_length = '' .join (cubin_length_array )
412412 meta_struct = '' .join ([
@@ -422,8 +422,8 @@ def generate_header_file_contents(
422422 shutil .rmtree (cubin_dir )
423423 os .mkdir (cubin_dir )
424424
425- if len (sys .argv ) > 1 and sys .argv [1 ] == 'medusa ' :
426- is_medusa = True
425+ if len (sys .argv ) > 1 and sys .argv [1 ] == 'spec_dec ' :
426+ is_spec_dec = True
427427 nvcc_flags = '-std=c++17 -O3 -cubin -DGENERATE_CUBIN=1 -DNDEBUG -DSPEC_DEC --use_fast_math -Xptxas=-v --allow-unsupported-compiler --expt-relaxed-constexpr -t 0'
428428 arch_options = [80 , 86 , 89 , 90 ]
429429 config_list = [[
@@ -444,7 +444,7 @@ def generate_header_file_contents(
444444 with multiprocessing .Pool (processes = thread_count ) as pool :
445445 name_size_list = pool .map (run_cubin_gen , arch_macro_lists )
446446 header_file_contents = generate_header_file_contents (
447- arch_macro_lists , name_size_list , is_medusa )
447+ arch_macro_lists , name_size_list , is_spec_dec )
448448
449449 with open (cubin_dir + build_func_name_prefix + '_cubin.h' , "w" ) as f :
450450 f .write ("" .join (
0 commit comments