1- use std:: env;
1+ use std:: env:: { self , VarError } ;
22use std:: fs:: { read_dir, File } ;
33use std:: io:: Write ;
44use std:: path:: { Path , PathBuf } ;
55use std:: process:: Command ;
6+ use std:: str:: FromStr ;
67
78use cc:: Build ;
89use once_cell:: sync:: Lazy ;
10+ use glob:: glob;
911
1012// This build file is based on:
1113// https://github.com/mdrokz/rust-llama.cpp/blob/master/build.rs
@@ -365,23 +367,16 @@ fn compile_blis(cx: &mut Build) {
365367}
366368
367369fn compile_hipblas ( cx : & mut Build , cxx : & mut Build , mut hip : Build ) -> & ' static str {
368- const DEFAULT_ROCM_PATH_STR : & str = " /opt/rocm/";
370+ let rocm_path_str = env :: var ( "ROCM_PATH" ) . or ( Ok :: < String , VarError > ( String :: from_str ( " /opt/rocm/") . unwrap ( ) ) ) . unwrap ( ) ;
369371
370- let rocm_path_str = env:: var ( "ROCM_PATH" )
371- . map_err ( |_| DEFAULT_ROCM_PATH_STR . to_string ( ) )
372- . unwrap ( ) ;
373- println ! ( "Compiling HIPBLAS GGML. Using ROCm from {rocm_path_str}" ) ;
372+ println ! ( "Compiling hipBLAS GGML. Using ROCm from {rocm_path_str}" ) ;
374373
375374 let rocm_path = PathBuf :: from ( rocm_path_str) ;
376375 let rocm_include = rocm_path. join ( "include" ) ;
377376 let rocm_lib = rocm_path. join ( "lib" ) ;
378377 let rocm_hip_bin = rocm_path. join ( "bin/hipcc" ) ;
379378
380- let cuda_lib = "ggml-cuda" ;
381- let cuda_file = cuda_lib. to_string ( ) + ".cu" ;
382- let cuda_header = cuda_lib. to_string ( ) + ".h" ;
383-
384- let defines = [ "GGML_USE_HIPBLAS" , "GGML_USE_CUBLAS" ] ;
379+ let defines = [ "GGML_USE_HIPBLAS" , "GGML_USE_CUDA" ] ;
385380 for def in defines {
386381 cx. define ( def, None ) ;
387382 cxx. define ( def, None ) ;
@@ -390,24 +385,39 @@ fn compile_hipblas(cx: &mut Build, cxx: &mut Build, mut hip: Build) -> &'static
390385 cx. include ( & rocm_include) ;
391386 cxx. include ( & rocm_include) ;
392387
388+ let ggml_cuda = glob ( LLAMA_PATH . join ( "ggml-cuda" ) . join ( "*.cu" ) . to_str ( ) . unwrap ( ) )
389+ . unwrap ( ) . filter_map ( Result :: ok) . collect :: < Vec < _ > > ( ) ;
390+ let ggml_template_fattn = glob ( LLAMA_PATH . join ( "ggml-cuda" ) . join ( "template-instances" ) . join ( "fattn-vec*.cu" ) . to_str ( ) . unwrap ( ) )
391+ . unwrap ( ) . filter_map ( Result :: ok) . collect :: < Vec < _ > > ( ) ;
392+ let ggml_template_wmma = glob ( LLAMA_PATH . join ( "ggml-cuda" ) . join ( "template-instances" ) . join ( "fattn-wmma*.cu" ) . to_str ( ) . unwrap ( ) )
393+ . unwrap ( ) . filter_map ( Result :: ok) . collect :: < Vec < _ > > ( ) ;
394+ let ggml_template_mmq = glob ( LLAMA_PATH . join ( "ggml-cuda" ) . join ( "template-instances" ) . join ( "mmq*.cu" ) . to_str ( ) . unwrap ( ) )
395+ . unwrap ( ) . filter_map ( Result :: ok) . collect :: < Vec < _ > > ( ) ;
396+
393397 hip. compiler ( rocm_hip_bin)
394398 . std ( "c++11" )
395- . file ( LLAMA_PATH . join ( cuda_file) )
396- . include ( LLAMA_PATH . join ( cuda_header) )
399+ . define ( "LLAMA_CUDA_DMMV_X" , Some ( "32" ) )
400+ . define ( "LLAMA_CUDA_MMV_Y" , Some ( "1" ) )
401+ . define ( "LLAMA_CUDA_KQUANTS_ITER" , Some ( "2" ) )
402+ . file ( LLAMA_PATH . join ( "ggml-cuda.cu" ) )
403+ . files ( ggml_cuda)
404+ . files ( ggml_template_fattn)
405+ . files ( ggml_template_wmma)
406+ . files ( ggml_template_mmq)
407+ . include ( LLAMA_PATH . join ( "" ) )
408+ . include ( LLAMA_PATH . join ( "ggml-cuda" ) )
397409 . define ( "GGML_USE_HIPBLAS" , None )
398- . compile ( cuda_lib) ;
410+ . define ( "GGML_USE_CUDA" , None )
411+ . compile ( "ggml-cuda" ) ;
399412
400- println ! (
401- "cargo:rustc-link-search=native={}" ,
402- rocm_lib. to_string_lossy( )
403- ) ;
413+ println ! ( "cargo:rustc-link-search=native={}" , rocm_lib. to_string_lossy( ) ) ;
404414
405415 let rocm_libs = [ "hipblas" , "rocblas" , "amdhip64" ] ;
406416 for lib in rocm_libs {
407417 println ! ( "cargo:rustc-link-lib={lib}" ) ;
408418 }
409419
410- cuda_lib
420+ "ggml-cuda"
411421}
412422
413423fn compile_cuda ( cx : & mut Build , cxx : & mut Build , featless_cxx : Build ) -> & ' static str {
0 commit comments