Skip to content

Commit

Permalink
CUDA: faster q2_K, q3_K MMQ + int8 tensor cores (#7921)
Browse files Browse the repository at this point in the history
* CUDA: faster q2_K, q3_K MMQ + int8 tensor cores

* try CI fix

* try CI fix

* try CI fix

* fix data race

* rever q2_K precision related changes
  • Loading branch information
JohannesGaessler authored Jun 14, 2024
1 parent 66ef1ce commit 76d66ee
Show file tree
Hide file tree
Showing 6 changed files with 457 additions and 319 deletions.
6 changes: 4 additions & 2 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.default_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;

info.devices[id].nsm = prop.multiProcessorCount;
info.devices[id].smpb = prop.sharedMemPerBlock;
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpbo = prop.sharedMemPerBlock;
info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
#else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpb = prop.sharedMemPerBlock;
info.devices[id].nsm = prop.multiProcessorCount;
}

for (int id = 0; id < info.device_count; ++id) {
Expand Down
1 change: 1 addition & 0 deletions ggml-cuda/argsort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);

// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

if (order == GGML_SORT_ORDER_ASC) {
Expand Down
5 changes: 5 additions & 0 deletions ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
#define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL

#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
Expand Down Expand Up @@ -661,6 +665,7 @@ struct ggml_cuda_device_info {
int cc; // compute capability
int nsm; // number of streaming multiprocessors
size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
size_t total_vram;
Expand Down
Loading

0 comments on commit 76d66ee

Please sign in to comment.