Skip to content

Commit

Permalink
fix: ROCm build (#817)
Browse files Browse the repository at this point in the history
* Some fixes (ig)

* Oopsie :3

* Remove a comment

* indent the block

* another indentation

* Revert stuff (hopefully)

---------

Co-authored-by: AlpinDale <[email protected]>
Co-authored-by: AlpinDale <[email protected]>
  • Loading branch information
3 people authored Dec 5, 2024
1 parent 9b56927 commit 4f9fea4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
30 changes: 14 additions & 16 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
set(CUDA_SUPPORTED_ARCHS "6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")

#
# Supported/expected torch versions for CUDA/ROCm.
Expand Down Expand Up @@ -65,20 +65,19 @@ endif()
# etc.
#
find_package(Torch REQUIRED)
find_package(CUDA REQUIRED)
find_package(CUDAToolkit REQUIRED)

# Add cuBLAS to the list of libraries to link against
list(APPEND LIBS CUDA::cublas)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

# Replace -std=c++20 with -std=c++17 in APHRODITE_GPU_FLAGS
if(APHRODITE_GPU_LANG STREQUAL "CUDA")
list(APPEND APHRODITE_GPU_FLAGS "--std=c++17" "-Xcompiler -Wno-return-type")
if(MSVC)
find_package(CUDA REQUIRED)
find_package(CUDAToolkit REQUIRED)
# Add cuBLAS to the list of libraries to link against
list(APPEND LIBS CUDA::cublas)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
# Append --std=c++17 (plus -Xcompiler -Wno-return-type) to APHRODITE_GPU_FLAGS for CUDA builds
if(APHRODITE_GPU_LANG STREQUAL "CUDA")
list(APPEND APHRODITE_GPU_FLAGS "--std=c++17" "-Xcompiler -Wno-return-type")
endif()
endif()

#
Expand Down Expand Up @@ -222,7 +221,6 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
"kernels/permute_cols.cu"
"kernels/sampling/sampling.cu")

# Add CUTLASS and GPTQ Marlin kernels if not MSVC
if(NOT MSVC)
# Include CUTLASS only when needed
include(FetchContent)
Expand Down
2 changes: 1 addition & 1 deletion amdpatch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

ROCM_PATH=$(hipconfig --rocmpath)

sudo patch $ROCM_PATH/lib/llvm/lib/clang/18/include/__clang_hip_cmath.h ./patches/amd.patch
sudo patch $ROCM_PATH/lib/llvm/lib/clang/*/include/__clang_hip_cmath.h ./patches/amd.patch

0 comments on commit 4f9fea4

Please sign in to comment.