
BitsandBytes Enablement on ROCm #1207


Merged
Changes from all commits (132 commits)
2e10f67
hipify the csrc repo
Lzy17 Jul 7, 2023
1928960
hipify pythoninterface
Lzy17 Jul 9, 2023
8ca0b5c
copy from agrocylo
Lzy17 Jul 9, 2023
8acbcf2
hipify cuparse and cublas calls
Lzy17 Jul 9, 2023
e80a60c
fix compile error and Makefile
Lzy17 Jul 10, 2023
fb780a0
fixed runtime error (low accuracy)
Lzy17 Jul 12, 2023
1048264
FIX LOW ACCURACY
Lzy17 Jul 12, 2023
c330020
Update README.md
Lzy17 Jul 12, 2023
fcee2d6
add benchmarks
Lzy17 Jul 12, 2023
4c0ca08
Update README.md
Lzy17 Jul 18, 2023
c798616
First draft, getting error
jpvillam-amd Oct 11, 2023
37045e5
Small transform fix, still errors on igemm
jpvillam-amd Oct 19, 2023
524fa57
create HIP_ENVIRONMENT variable
pnunna93 Nov 15, 2023
d7f7a82
Skip failing tests on rocm
pnunna93 Nov 15, 2023
28b8056
Add default value for HIP_ENVIRONMENT
pnunna93 Nov 20, 2023
9dca4fa
Merge pull request #1 from ROCmSoftwarePlatform/skip_rocm_failing_tests
amathews-amd Nov 21, 2023
38c934e
skip failing triton tests on rocm
pnunna93 Nov 21, 2023
71bf2df
Merge pull request #2 from ROCmSoftwarePlatform/skip_triton
amathews-amd Nov 21, 2023
657ca4b
Enable col to row transformation
pnunna93 Jan 12, 2024
a390e0c
Add make functions for row to col transformation
pnunna93 Jan 12, 2024
99ad6b5
Update get_transform_buffer for row to col in HIP
pnunna93 Jan 12, 2024
039b808
Update igemmlt for col format
pnunna93 Jan 12, 2024
1a052ee
Unskip test_igemmlt_int on ROCm
pnunna93 Jan 12, 2024
b7ca5cf
Update igemmlt_int test for col inputs
pnunna93 Jan 12, 2024
a2cd90d
Skip transpose igemmlt test on ROCm
pnunna93 Jan 12, 2024
5b6c5ac
Revert "Update igemmlt_int test for col inputs"
pnunna93 Jan 12, 2024
218bf66
Return nvidia_transform from transform for HIP
pnunna93 Jan 12, 2024
8bb5c2f
Fix syntax error
pnunna93 Jan 12, 2024
eb2edf7
Add comment for shape change
pnunna93 Jan 16, 2024
a38ea0f
Enable nvidia_transform tests
pnunna93 Jan 16, 2024
fbacd7a
Merge branch 'fix_igemmlt_int' of https://github.com/pnunna93/bitsand…
pnunna93 Jan 16, 2024
67c383b
Enable igemmlt_half tests
pnunna93 Jan 16, 2024
42b860f
Revert col32 check in nvidia_transform test
pnunna93 Jan 16, 2024
7198d6b
Merge pull request #3 from pnunna93/fix_igemmlt_int
amathews-amd Jan 17, 2024
b1d484a
Merge remote-tracking branch 'upstream/main' into IFU-master-2024-01-24
pnunna93 Jan 26, 2024
c36085d
Update README.md
Lzy17 Jan 26, 2024
0e91e48
Update hip files with upstream changes
pnunna93 Jan 26, 2024
1295d53
Skip failing tests for now
pnunna93 Jan 27, 2024
48b7fa9
Merge pull request #4 from ROCm/IFU-master-2024-01-24
amathews-amd Jan 30, 2024
f1a0b8b
ops.hip: adapt to enum naming changes in ROCm/hipBLASLt@95131d6 and R…
iiisak Feb 2, 2024
a84c369
fix wmma api parity
Lzy17 Feb 6, 2024
b044010
hipify wmma datatype
Lzy17 Feb 7, 2024
7aa42be
Enable estimate quantile tests
pnunna93 Feb 12, 2024
85377e1
Merge pull request #5 from iiisak/rocm_enabled
pnunna93 Feb 13, 2024
ffb0c5d
Merge pull request #7 from ROCm/fix_estimate_quantiles
amathews-amd Feb 13, 2024
2b77380
Merge pull request #6 from ROCm/rocwmma_merge
Lzy17 Feb 19, 2024
fad7918
Enable transpose flag for row to col transform
pnunna93 Feb 20, 2024
e3021ee
Update descriptors for transpose flag
pnunna93 Feb 20, 2024
8c3476f
revert nvidia_transform to transform
pnunna93 Feb 20, 2024
5e1b152
update changes
Feb 20, 2024
386e16c
Merge pull request #8 from ROCm/enable_transform_with_transpose
pnunna93 Feb 23, 2024
389bb7d
fixed minor mistakes
Feb 23, 2024
b6770bf
Merge pull request #9 from ROCm/rocm_enabled_fix_bfloat16
pnunna93 Feb 23, 2024
fa28828
remove blocksize 64 on rocm
pnunna93 Mar 6, 2024
d86d24c
remove block size 64 and enable remaining tests
pnunna93 Mar 6, 2024
cf4a506
Fix cuda build errors
pnunna93 Mar 6, 2024
7077195
remove workspace in igemmlt
pnunna93 Mar 12, 2024
ec32fc1
Enabled igemmlt in matmul
pnunna93 Mar 12, 2024
4536b25
Fix shape issue in transform function
pnunna93 Mar 12, 2024
66e34c1
Enable igemmlt int8 output
pnunna93 Mar 12, 2024
7e5e223
Add col format for extract outliers
pnunna93 Mar 12, 2024
2e42adb
Enable dequant_mm
pnunna93 Mar 12, 2024
e32d277
Enable matmullt tests
pnunna93 Mar 12, 2024
8206bd1
Enabled linear_serialization tests
pnunna93 Mar 12, 2024
973a9f8
fix error with dequant_mm change
pnunna93 Mar 12, 2024
387a9b7
Enable extract outliers test
pnunna93 Mar 12, 2024
93dfb51
Enable test overflow
pnunna93 Mar 12, 2024
90bbdc6
Skip overflow and linear serialization for now
pnunna93 Mar 12, 2024
9890d5d
Merge pull request #10 from ROCm/remove_blocksize_64
pnunna93 Mar 12, 2024
1b6dd48
Merge pull request #11 from ROCm/fix_cuda_build_errs
pnunna93 Mar 12, 2024
fc9bf4d
Merge pull request #12 from ROCm/igemm_workspace
pnunna93 Mar 12, 2024
f30dc38
Merge pull request #13 from ROCm/enable_matmul
pnunna93 Mar 12, 2024
3dc14e8
improve the gemv 4bit accuracy by forcing the hipcub to 32
Mar 18, 2024
f4ac9ac
Merge pull request #14 from ROCm/fix_gemv_4bit
Lzy17 Mar 19, 2024
485ba8f
Update skip comment
pnunna93 Mar 19, 2024
a36bd1d
Merge pull request #15 from ROCm/gemv_skip_comment
pnunna93 Mar 19, 2024
a551c16
Merge remote-tracking branch 'upstream/main' into IFU-master-2024-03-28
pnunna93 Apr 4, 2024
a267221
update instructions
Apr 9, 2024
bcdcc0b
Merge pull request #19 from ROCm/updated_readme
amathews-amd Apr 9, 2024
ff33371
Update README.md
pnunna93 Apr 9, 2024
1157e73
Merge branch 'rocm_enabled' into IFU-master-2024-03-28
pnunna93 Apr 9, 2024
702ca1a
fix PEP errors
pnunna93 Apr 9, 2024
8c23dc0
Fix typos
pnunna93 Apr 9, 2024
971f4b1
Merge branch 'IFU-master-2024-03-28' of https://github.com/ROCm/bitsa…
pnunna93 Apr 9, 2024
4d6408a
Fix formatting in README file
pnunna93 Apr 10, 2024
79cb554
Update gpu arch setting
pnunna93 Apr 18, 2024
5c0414e
Add ROCM_PATH variable
pnunna93 Apr 18, 2024
47795f5
Add HIP_VERSION variable
pnunna93 Apr 18, 2024
6d90452
Add BNB_HIP_VERSION variable
pnunna93 Apr 18, 2024
049a2dc
Update supports igemmlt based on HIP version
pnunna93 Apr 18, 2024
47a0bc3
Skip failing tests based on HIP version
pnunna93 Apr 18, 2024
1b2a095
pre-commit fixes
pnunna93 Apr 18, 2024
4515a21
Update README file
pnunna93 Apr 18, 2024
e7ef75f
Update default arch list
pnunna93 Apr 19, 2024
c0d244c
update readme
pnunna93 Apr 19, 2024
c037a30
Merge pull request #17 from ROCm/IFU-master-2024-03-28
lcskrishna Apr 19, 2024
73f4f05
Merge remote-tracking branch 'TD_BnB/multi-backend-refactor' into dev…
pnunna93 Apr 22, 2024
79652a5
update igemmlt for hip
pnunna93 Apr 22, 2024
aedfa8f
Update mm_dequant for hip
pnunna93 Apr 22, 2024
7835282
Update transform function for hip
pnunna93 Apr 22, 2024
60d7560
adding arch detection for test_gemv_eye_4bit
Apr 26, 2024
cae33c3
implement get_rocm_gpu_arch
Apr 29, 2024
da53f39
fixing lint
Apr 30, 2024
ae4dcec
fixing lint
Apr 30, 2024
21d5ff6
correct lint error
Apr 30, 2024
5bada9b
Merge pull request #21 from ROCm/rocm_enabled_arch_detect
pnunna93 Apr 30, 2024
01abfde
Merge branch 'rocm_enabled' into device_abstraction
pnunna93 May 6, 2024
765bfc8
update extract_outliers, quantize_4bit, dequantize_4bit
lcskrishna May 6, 2024
d00c026
minor fixes for extract_outliers
lcskrishna May 6, 2024
e5574bd
update blocksizes for quantize and dequantize
lcskrishna May 6, 2024
a00bd1f
Merge branch 'rocm_enabled' of https://github.com/ROCm/bitsandbytes i…
May 7, 2024
7ab3a05
update reg expression for detecting arch
lcskrishna May 7, 2024
9cd1d8c
linter updates
lcskrishna May 7, 2024
62f8ed9
Merge branch 'device_abstraction' into cl/update-device-abs
lcskrishna May 7, 2024
d9e4803
Merge pull request #23 from ROCm/cl/update-device-abs
pnunna93 May 8, 2024
2af8568
Merge remote-tracking branch 'upstream/multi-backend-refactor' into d…
pnunna93 May 9, 2024
06f6b25
skip linear no igemmlt test
pnunna93 May 9, 2024
2359452
Remove archive functional file
pnunna93 May 9, 2024
f76d6ab
Sync README with upstream
pnunna93 May 9, 2024
576b62c
Remove bnb_accuracy file
pnunna93 May 9, 2024
dfb531b
Remove cuda_setup
pnunna93 May 9, 2024
31b1cbc
Remove test_delete_later.c
pnunna93 May 9, 2024
ed77476
Sync with upstream
pnunna93 May 9, 2024
943c57a
Sync files with upstream
pnunna93 May 9, 2024
71d1702
Fix lint errors
pnunna93 May 10, 2024
6886bc8
Exclude hip files from typo checks
pnunna93 May 8, 2024
0d445f4
update ops.hip
pnunna93 May 10, 2024
bc6d0b7
Merge pull request #27 from ROCm/dev_abs_IFU
lcskrishna May 10, 2024
15c7f77
Add install steps for ROCm
pnunna93 May 10, 2024
d62c835
Fix lint error
pnunna93 May 10, 2024
8aae7c9
Merge pull request #28 from ROCm/dev_abs_add_install_steps
lcskrishna May 10, 2024
410f499
Add comments for HIP changes
pnunna93 May 15, 2024
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -21,3 +21,4 @@ repos:
rev: v1.18.2
hooks:
- id: typos
exclude: ^.*\.hip$
82 changes: 79 additions & 3 deletions CMakeLists.txt
@@ -3,7 +3,7 @@
# For GCC: `cmake -B build . && cmake --build build`
# For MSVC: `cmake -B build . && cmake --build build --config Release`
# You can also use the following options and variables
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip` or `mps` to select the backend
# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
# is whatever CMake finds on your path.
@@ -26,13 +26,14 @@ endif()
# Define included source files
set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})

set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps)
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

if(APPLE)
@@ -49,16 +50,28 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
endif()
option(NO_CUBLASLT "Disable CUBLAS" OFF)
set(BUILD_CUDA ON)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
elseif(${COMPUTE_BACKEND} STREQUAL "hip")
if(APPLE)
message(FATAL_ERROR "HIP is not supported on macOS" )
endif()
option(NO_CUBLASLT "Disable HIPBLASLT" OFF)
set(BUILD_CUDA OFF)
set(BUILD_HIP ON)
set(BUILD_MPS OFF)
message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS ON)
else()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
endif()

@@ -158,6 +171,34 @@ if(BUILD_CUDA)
string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
endif()
add_compile_definitions(BUILD_CUDA)
elseif(BUILD_HIP)
enable_language(HIP)
message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
if(DEFINED BNB_ROCM_ARCH)
set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH})
else()
if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx940;gfx941;gfx942")
elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
endif()
endif()
message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")

list(APPEND SRC_FILES ${HIP_FILES})

string(APPEND BNB_OUTPUT_NAME "_hip")

# get hip version
execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION)
string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")

if(NO_CUBLASLT OR HIP_VERSION VERSION_LESS "6.1")
string(APPEND BNB_OUTPUT_NAME "_nohipblaslt")
endif()
add_compile_definitions(__HIP_PLATFORM_AMD__)
add_compile_definitions(__HIP_PLATFORM_HCC__)
add_compile_definitions(BUILD_HIP)
elseif(BUILD_MPS)
if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -213,6 +254,41 @@ if(BUILD_CUDA)
CUDA_SEPARABLE_COMPILATION ON
)
endif()
if(BUILD_HIP)
if(NOT DEFINED ENV{ROCM_PATH})
set(ROCM_PATH /opt/rocm)
else()
set(ROCM_PATH $ENV{ROCM_PATH})
endif()
list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
macro(find_package_and_print_version PACKAGE_NAME)
find_package("${PACKAGE_NAME}" ${ARGN})
message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
endmacro()
find_package_and_print_version(hipblas REQUIRED)
find_package_and_print_version(hiprand REQUIRED)
find_package_and_print_version(hipsparse REQUIRED)

## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")

target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)

target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)

if(NO_CUBLASLT OR HIP_VERSION VERSION_LESS "6.1")
target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT)
else()
find_package(hipblaslt)
target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt)
endif()
endif()
if(BUILD_MPS)
add_dependencies(bitsandbytes metallib)
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
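The CMake hunks above probe the HIP version with `hipconfig --version` and only link hipBLASLt when ROCm is at least 6.1; an older ROCm (or `NO_CUBLASLT=ON`) produces a `_nohipblaslt` library instead. Below is a minimal Python sketch of the same probe, handy for checking a ROCm install before configuring; it assumes only that `hipconfig` is on PATH and is an illustration, not code from this PR.

```python
# Illustration only (not part of this PR): replicate the CMake probe that runs
# `hipconfig --version` and keys hipBLASLt support off ROCm >= 6.1.
import re
import subprocess
from typing import Optional


def hip_version() -> Optional[str]:
    try:
        out = subprocess.run(
            ["hipconfig", "--version"], capture_output=True, text=True, check=True
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        return None  # no ROCm toolchain on this machine
    match = re.search(r"[0-9]+\.[0-9]+", out)  # e.g. "6.1.40093-..." -> "6.1"
    return match.group(0) if match else None


version = hip_version()
if version is None:
    print("hipconfig not found; the HIP backend cannot be built here")
else:
    major, minor = (int(x) for x in version.split("."))
    # Same cutoff as the CMake check above: hipBLASLt needs ROCm >= 6.1
    suffix = "_hip" if (major, minor) >= (6, 1) else "_hip_nohipblaslt"
    print(f"HIP {version}: library name ends with {suffix}")
```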
3 changes: 3 additions & 0 deletions bitsandbytes/autograd/_functions.py
@@ -7,6 +7,7 @@

import torch

from bitsandbytes.cextension import BNB_HIP_VERSION
import bitsandbytes.functional as F


@@ -222,6 +223,8 @@ def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if device == torch.device("cpu"):
return True
if torch.version.hip:
return False if BNB_HIP_VERSION < 601 else True
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
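The new check above disables the optimized int8 igemmlt path on ROCm older than 6.1 (`BNB_HIP_VERSION < 601`), since that path relies on hipBLASLt. A hedged sketch of how such a version integer can be derived from `torch.version.hip` follows; the real constant is defined in `bitsandbytes/cextension.py`, which is not part of this hunk, so the helper below is hypothetical.

```python
# Hedged sketch: derive a BNB_HIP_VERSION-style integer from torch.version.hip
# ("6.1.40093-..." -> 6 * 100 + 1 = 601). The real constant comes from
# bitsandbytes/cextension.py (not shown in this diff); this helper is hypothetical.
import torch


def bnb_hip_version() -> int:
    if not torch.version.hip:  # CUDA or CPU build of PyTorch
        return 0
    major, minor = (int(x) for x in torch.version.hip.split(".")[:2])
    return major * 100 + minor  # ROCm 6.1 -> 601


# The int8 igemmlt path needs hipBLASLt, which this PR enables from ROCm 6.1 on.
print("igemmlt available on this ROCm build:", bnb_hip_version() >= 601)
```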
97 changes: 73 additions & 24 deletions bitsandbytes/backends/cuda.py
@@ -3,7 +3,7 @@

import torch

from bitsandbytes.cextension import lib
from bitsandbytes.cextension import HIP_ENVIRONMENT, lib
from bitsandbytes.functional import (
CUBLAS_Context,
coo_zeros,
@@ -14,6 +14,7 @@
get_ptr,
get_transform_buffer,
is_on_gpu,
nvidia_transform,
post_call,
pre_call,
prod,
@@ -184,6 +185,11 @@ def transform(
state: Optional[Tuple[torch.Size, str]] = None,
ld=None,
):
if HIP_ENVIRONMENT:
# transform kernel formats (col32/col_turing/col_ampere) are not applicable to ROCm
# Use nvidia_transform instead
return nvidia_transform(A, to_order, from_order, out, transpose, state, ld)

prev_device = pre_call(A.device)
if state is None:
state = (A.shape, from_order)
@@ -266,19 +272,33 @@ def igemmlt(
return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)

if dimsA == 2 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
if HIP_ENVIRONMENT:
# Use col format for HIP
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col", "row")
else:
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
elif dimsA == 3 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")
if HIP_ENVIRONMENT:
# Use col format for HIP
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col", "row")
else:
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")

assert dimsB != 3, "len(B.shape)==3 not supported"
assert A.device.type == "cuda"
assert B.device.type == "cuda"
assert A.dtype == torch.int8
assert B.dtype == torch.int8
assert out.dtype == dtype
assert SA[1] == "col32"
assert SB[1] in ["col_turing", "col_ampere"]
assert Sout[1] == "col32"
if HIP_ENVIRONMENT:
# Use col format for HIP
assert SA[1] == "col"
assert SB[1] == "col"
assert Sout[1] == "col"
else:
assert SA[1] == "col32"
assert SB[1] in ["col_turing", "col_ampere"]
assert Sout[1] == "col32"
assert (
shapeA[-1] == shapeB[-1]
), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
@@ -293,17 +313,23 @@ def igemmlt(
ptrC = get_ptr(out)

k = shapeA[-1]
lda = ct.c_int32(m * 32)
if formatB == "col_turing":
# turing: tiles with rows filled up to multiple of 8 rows by 32 columns
# n = rows
ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
if HIP_ENVIRONMENT:
# Set ld values for col format
lda = ct.c_int32(m)
ldb = ct.c_int32(shapeB[0])
ldc = ct.c_int32(m)
else:
# ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
# n = rows
ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
lda = ct.c_int32(m * 32)
if formatB == "col_turing":
# turing: tiles with rows filled up to multiple of 8 rows by 32 columns
# n = rows
ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
else:
# ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
# n = rows
ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)

ldc = ct.c_int32(m * 32)
ldc = ct.c_int32(m * 32)
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
@@ -312,7 +338,7 @@ def igemmlt(
ptrRowScale = get_ptr(None)
is_on_gpu([A, B, out])

if formatB == "col_turing":
if formatB == "col_turing" or HIP_ENVIRONMENT:
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
@@ -324,7 +350,7 @@ def igemmlt(
else:
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)

if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`, `ops.hip`
raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")

if has_error:
@@ -348,6 +374,9 @@ def mm_dequant(
new_col_stats: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
):
if HIP_ENVIRONMENT:
# HIP kernel requires 'row' format
A, quant_state = nvidia_transform(A, "row", state=quant_state)
assert A.dtype == torch.int32
if bias is not None:
assert bias.dtype == torch.float16
@@ -386,7 +415,11 @@ def mm_dequant(
def extract_outliers(self, A: torch.Tensor, SA: Tuple[torch.Size, str], idx: torch.Tensor):
shapeA = SA[0]
formatA = SA[1]
assert formatA in ["col_turing", "col_ampere"]
if not HIP_ENVIRONMENT:
assert formatA in ["col_turing", "col_ampere"]
else:
# HIP uses col format
assert formatA in ["col"]
assert A.device.type == "cuda"

out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
@@ -400,7 +433,7 @@ def extract_outliers(self, A: torch.Tensor, SA: Tuple[torch.Size, str], idx: torch.Tensor):

prev_device = pre_call(A.device)

if formatA == "col_turing":
if formatA == "col_turing" or HIP_ENVIRONMENT:
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
@@ -414,11 +447,15 @@ def quantize_4bit(
A: torch.Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
blocksize: Optional[int] = None,
compress_statistics=False,
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
if blocksize is None:
# Some AMD GPUs have warpsize 64
# Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
blocksize = 64 if not HIP_ENVIRONMENT else 128
if A.device.type != "cuda":
raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
if quant_type not in ["fp4", "nf4"]:
@@ -436,7 +473,12 @@ def quantize_4bit(
mod = dtype2bytes[quant_storage] * 2
out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)

assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
# Some AMD GPUs have warpsize 64
# Set min blocksize to 128 (~warpsize 64 in kernel) for HIP
if not HIP_ENVIRONMENT:
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
else:
assert blocksize in [4096, 2048, 1024, 512, 256, 128]

prev_device = pre_call(A.device)
is_on_gpu([A, out, absmax])
@@ -507,12 +549,19 @@ def dequantize_4bit(
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
blocksize: Optional[int] = None,
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
# Some AMD GPUs have warpsize 64
# Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
if HIP_ENVIRONMENT:
supported_blocksizes = supported_blocksizes[:-1]
if blocksize not in supported_blocksizes:
raise ValueError(
f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
)

if quant_type not in ["fp4", "nf4"]:
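The last hunks above make 128 the minimum and default 4-bit blocksize on HIP, because several AMD GPUs use a wavefront of 64 threads, so a 64-element block maps to a single wavefront in the kernel. A short usage sketch follows, assuming a ROCm build of this branch; passing `blocksize=128` explicitly works on both CUDA and ROCm, whereas 64 is rejected only on ROCm.

```python
# Usage sketch, assuming a ROCm build of this branch. On HIP the smallest
# supported 4-bit blocksize is 128; blocksize=128 also works on CUDA, while
# blocksize=64 is rejected only on ROCm.
import torch

import bitsandbytes.functional as F

# On ROCm builds of PyTorch, device="cuda" maps to the AMD GPU.
A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

q, state = F.quantize_4bit(A, blocksize=128, quant_type="nf4")
out = F.dequantize_4bit(q, state, blocksize=128, quant_type="nf4")
print(out.shape, state.blocksize)
```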