From cec14e051595a64c95ed987764117f3715dbe2b4 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Wed, 21 Aug 2024 14:36:19 +0530 Subject: [PATCH] Merging main (#4) * Fixed single GPU issue without setting up mp. Added toggles for server request batching parameters (#114) * Fixed single GPU issue without setting up mp. Added toggles for server request batching parameters * Adding HTTP headers * Add distributed executor backend to benchmark scripts (#118) * Add weight padding for moe (#119) * add weight padding for moe * enable padding by default * fix linter * fix linter * fix linter * using envs.py * fix linter * [BugFix] Fix navi build after many custom for MI kernels added (#116) * fix navi build * Created dummy kernels of unsupported on Navi to avoid function not found crashes at runtime * replacing ifdefs on host code with those on kernels * refactoring code to avoid unsupported call on Navi * syntactic change * import statements fix * moving env variables to envs.py * style fixes * cosmetic changes for isort * remved extra include * moving use_skinny to be member --------- Co-authored-by: lcskrishna Co-authored-by: maleksan85 Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> * add emtpy_cache() after each padding (#120) * [FIX] Gradlib OOM on Navi and sometimes on MI (#124) * add memory clean up after every shape and parameter to reduce cache invalidation buffers * small typo * syntax change --------- Co-authored-by: maleksan85 * save shape when fp8 solution not found (#123) Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> * Fix unit test for moe by adding padding (#128) * fix test_moe * fix linter * Llama3.1 (#129) * Add support for a rope extension method (#6553) * [BugFix] Fix RoPE error in Llama 3.1 (#6693) --------- Co-authored-by: Simon Mo Co-authored-by: Woosuk Kwon * chat/completions endpoint (#121) * Initial implementation of chat/completions endpoint and its streaming variant * Reusing datatypes from the openai entrypoints * Response role from arg * Added models endpoint and model validation from the request * Optimize custom all reduce (#130) * First version * Revert error. While there, add missing finalize. * Use the correct defaults for ROCm. Increase sampling area to capture crossover. * Scope end_sync as well. * Guard only volatile keyword for ifndef USE_ROCM * Document crossover * Add BF16 support to custom PA (#133) * tightened atol for custom PA; enable supported head size, block sizes in testing * update num_blocks and num_iters in benchmark PA to realistic settings * move to generic b16 type * bf16 first port * enabled all bf16 tests, set atol for bf16 * enable custom PA for bf16 as well as block size 32 and head size 64 * fix cast to zero in custom PA reduce * py linter fixes * clang format fixes * div round up clang-format --------- Co-authored-by: Charlie Fu Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> * Making check for output match in original types. It saves some memory. (#135) Co-authored-by: maleksan85 * Make CAR ROCm 6.1 compatible. (#137) * remove scoping * while there fix a typo * while there remove unused variable * Car revert (#140) * Per @iotamudelta suggestion until the deadlocks issue is better understood Revert "Make CAR ROCm 6.1 compatible. (#137)" This reverts commit 4d2dda61c18bf93fa591cd84a5481ee9dd8ee428. 
* Per @iotamudelta suggestion until the deadlocks issue is better understood Revert "Optimize custom all reduce (#130)" This reverts commit 636ff019a1c9164321ae4414b1b933cddf853b7e. * Using the correct datatypes for streaming non-chat completions (#134) * Adding UNREACHABLE_CODE macro for non MI300 and MI250 cards (#138) * Adding UNREACHABLE_CODE macro * clang format fixes * clang formatting fix * minor updates in syntax * clang format update * clang format fix one more try * clang format one more try * clang format fix one more try --------- Co-authored-by: Aleksandr Malyshev * gfx90a typo fix (#142) Co-authored-by: maleksan85 * wvsplitk templatized and better tuned for MI300 (#132) * improvements to wvSpltK * wvsplt gemm; better handle MI300 and large A[] sizes * lint fix * Adjustments to better handle small weights in TP8. * early-out bug fix * better wave load balancing in wvSplt * add missing skip for wvsplt_big * Bug fix for wvSplt_big in load balancing at M4, lint fix. * [Bugfix] Dockerfile.rocm (#141) * Dockerfile.rocm bug fix * naming preference --------- Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> * Update test-template.j2 (#145) * Adding Triton implementations awq_dequantize and awq_gemm to ROCm (#136) * basic support for AWQ added * awq_dequantize implementation in Triton * awq_gemm implementation in Triton * unit tests in tests/kernels/test_awq_triton.py --------- Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Matt Wong <156021403+mawong-amd@users.noreply.github.com> Co-authored-by: Charlie Fu Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: lcskrishna Co-authored-by: maleksan85 Co-authored-by: Simon Mo Co-authored-by: Woosuk Kwon Co-authored-by: iotamudelta Co-authored-by: sanyalington Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Zachary Streeter <90640993+zstreet87@users.noreply.github.com> Co-authored-by: omkar kakarparthi <75638701+okakarpa@users.noreply.github.com> Co-authored-by: rasmith --- .buildkite/test-template.j2 | 4 +- CMakeLists.txt | 2 +- Dockerfile.rocm | 4 +- benchmarks/benchmark_latency.py | 53 +- benchmarks/benchmark_throughput.py | 15 +- .../kernels/benchmark_paged_attention.py | 4 +- csrc/custom/custom_kernels.cu | 1472 +++++------------ .../custom/paged_attention/attention_ll4mi.cu | 560 ++++--- gradlib/gradlib/GemmTuner.py | 12 +- tests/kernels/test_attention_custom.py | 46 +- tests/kernels/test_awq_triton.py | 237 +++ tests/kernels/test_moe.py | 21 +- vllm/_custom_ops.py | 37 +- vllm/attention/ops/paged_attn.py | 13 +- vllm/config.py | 58 +- vllm/engine/arg_utils.py | 2 +- vllm/entrypoints/fast_sync_llm.py | 21 +- vllm/entrypoints/sync_openai/api_server.py | 257 ++- vllm/entrypoints/sync_openai/protocol.py | 170 -- vllm/envs.py | 31 + .../layers/fused_moe/fused_moe.py | 9 +- .../layers/quantization/awq_triton.py | 311 ++++ .../layers/quantization/fp8_rocm.py | 2 + .../model_executor/layers/rotary_embedding.py | 166 +- vllm/model_executor/layers/tuned_gemm.py | 7 + vllm/model_executor/models/mixtral.py | 11 + 26 files changed, 2019 insertions(+), 1506 deletions(-) create mode 100644 tests/kernels/test_awq_triton.py delete mode 100644 vllm/entrypoints/sync_openai/protocol.py create mode 100644 vllm/model_executor/layers/quantization/awq_triton.py diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index bdb6e05fd337c..fe83f1e6f7d68 100644 --- 
a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -11,7 +11,7 @@ steps: - "docker push {{ docker_image_amd }}" plugins: - docker-login#v3.0.0: - username: rocmshared + username: rocm key: "amd-build" env: DOCKER_BUILDKIT: "1" @@ -38,4 +38,4 @@ steps: priority: 100 soft_fail: true {% endif %} -{% endfor %} \ No newline at end of file +{% endfor %} diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b3679a0548c9..38b67cd707be6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") # Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101") # # Supported/expected torch versions for CUDA/ROCm. diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 39590cef0bd6a..726090fa212e1 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -22,8 +22,8 @@ USER root ARG BASE_IMAGE ARG COMMON_WORKDIR # Used as ARCHes for all components -ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" -ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} +ARG ARG_PYTORCH_ROCM_ARCH="gfx90a;gfx942" +ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH} # Install some basic utilities RUN apt-get update && apt-get install python3 python3-pip - diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 400bb9936e02d..0532550359d6d 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -19,27 +19,30 @@ def main(args: argparse.Namespace): # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
- llm = LLM(model=args.model, - speculative_model=args.speculative_model, - num_speculative_tokens=args.num_speculative_tokens, - tokenizer=args.tokenizer, - quantization=args.quantization, - quantized_weights_path=args.quantized_weights_path, - tensor_parallel_size=args.tensor_parallel_size, - trust_remote_code=args.trust_remote_code, - dtype=args.dtype, - enforce_eager=args.enforce_eager, - kv_cache_dtype=args.kv_cache_dtype, - quantization_param_path=args.quantization_param_path, - device=args.device, - ray_workers_use_nsight=args.ray_workers_use_nsight, - worker_use_ray=args.worker_use_ray, - use_v2_block_manager=args.use_v2_block_manager, - enable_chunked_prefill=args.enable_chunked_prefill, - download_dir=args.download_dir, - block_size=args.block_size, - disable_custom_all_reduce=args.disable_custom_all_reduce, - gpu_memory_utilization=args.gpu_memory_utilization) + llm = LLM( + model=args.model, + speculative_model=args.speculative_model, + num_speculative_tokens=args.num_speculative_tokens, + tokenizer=args.tokenizer, + quantization=args.quantization, + quantized_weights_path=args.quantized_weights_path, + tensor_parallel_size=args.tensor_parallel_size, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + enforce_eager=args.enforce_eager, + kv_cache_dtype=args.kv_cache_dtype, + quantization_param_path=args.quantization_param_path, + device=args.device, + ray_workers_use_nsight=args.ray_workers_use_nsight, + worker_use_ray=args.worker_use_ray, + use_v2_block_manager=args.use_v2_block_manager, + enable_chunked_prefill=args.enable_chunked_prefill, + download_dir=args.download_dir, + block_size=args.block_size, + disable_custom_all_reduce=args.disable_custom_all_reduce, + gpu_memory_utilization=args.gpu_memory_utilization, + distributed_executor_backend=args.distributed_executor_backend, + ) sampling_params = SamplingParams( n=args.n, @@ -237,5 +240,13 @@ def run_to_completion(profile_dir: Optional[str] = None): help='the fraction of GPU memory to be used for ' 'the model executor, which can range from 0 to 1.' 'If unspecified, will use the default value of 0.9.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp', 'torchrun'], + default=None, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, on CUDA this will be automatically set to "ray" if ' + 'installed or "mp" (multiprocessing) otherwise. On ROCm, this is ' + 'instead set to torchrun by default.') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index eb59e38fa2c9d..302746e316514 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -79,6 +79,7 @@ def run_vllm( enable_prefix_caching: bool, enable_chunked_prefill: bool, max_num_batched_tokens: int, + distributed_executor_backend: Optional[str], gpu_memory_utilization: float = 0.9, worker_use_ray: bool = False, download_dir: Optional[str] = None, @@ -104,6 +105,7 @@ def run_vllm( download_dir=download_dir, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, ) # Add the requests to the engine. 
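
As context for the new --distributed-executor-backend flag added to both benchmark scripts in this patch, here is a minimal standalone sketch of how it might be wired up. It is not part of the diff; it only reuses the argparse definition and the LLM(...) keyword argument the diff itself introduces, and the comment about platform defaults simply restates the flag's help text.

import argparse

from vllm import LLM

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True)
parser.add_argument('--tensor-parallel-size', type=int, default=1)
parser.add_argument(
    '--distributed-executor-backend',
    choices=['ray', 'mp', 'torchrun'],
    default=None,
    help='Backend to use for distributed serving.')
args = parser.parse_args()

# Leaving the flag unset (None) lets vLLM fall back to its platform default:
# per the help text above, "ray" or "mp" on CUDA and "torchrun" on ROCm.
llm = LLM(
    model=args.model,
    tensor_parallel_size=args.tensor_parallel_size,
    distributed_executor_backend=args.distributed_executor_backend,
)
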
@@ -229,8 +231,9 @@ def main(args: argparse.Namespace): args.max_model_len, args.enforce_eager, args.kv_cache_dtype, args.quantization_param_path, args.device, args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.gpu_memory_utilization, - args.worker_use_ray, args.download_dir) + args.max_num_batched_tokens, args.distributed_executor_backend, + args.gpu_memory_utilization, args.worker_use_ray, + args.download_dir) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -384,6 +387,14 @@ def main(args: argparse.Namespace): type=str, default=None, help='Path to save the throughput results in JSON format.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp', 'torchrun'], + default=None, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, on CUDA this will be automatically set to "ray" if ' + 'installed or "mp" (multiprocessing) otherwise. On ROCm, this is ' + 'instead set to torchrun by default.') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index d0d990410bc6e..f95a1f488bcf7 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -9,7 +9,7 @@ from vllm._custom_C import paged_attention_custom from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random -NUM_BLOCKS = 1024 +NUM_BLOCKS = 1024 * 1024 PARTITION_SIZE = 256 @@ -176,7 +176,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: if do_profile: latency = run_benchmark(num_iters=1, profile=True) else: - latency = run_benchmark(num_iters=100, profile=False) + latency = run_benchmark(num_iters=1000, profile=False) print(f"Kernel running time: {latency * 1000000:.3f} us") diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index 6b7969f035d8d..18679f86e82c1 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -3,6 +3,20 @@ #include #include +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + constexpr int WARP_SIZE = 64; template @@ -312,93 +326,221 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, ///////////////////////////////////////////// -using half8 = __attribute__((__vector_size__(4 * sizeof(float)))) float; - -/*template -__device__ __forceinline__ T loadnt(T* addr) { - return __builtin_nontemporal_load(addr); - //return *((T*)addr); -}*/ - -#define THRDS 64 -#define YTILE 2 -#define WvPrGrp 16 -#define A_CHUNK 8 -#define UNRL 2 -#define M 1 #define DTYPE half -__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { +__device__ __forceinline__ int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int 
rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] fits LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; - __int128_t b128; half8 h8; }; + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- __shared__ half s[1024 * 32]; - uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + // uint32_t commitColumn[YTILE]; + // for (uint32_t i = 0; i < YTILE; i++) { + // commitColumn[i] = 1; + //} + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- for (uint32_t k = 0; k < min(K * M, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + if (k_in >= min(K * M, 32 * 1024)) break; - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); + if (threadIdx.y >= _WvPrGrp) return; + float sum[M][YTILE]; + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- for (int i = 0; i < YTILE; i++) for (int m = 0; m < M; m++) sum[m][i] = 0; bigType bigA[M][UNRL]; bigType bigB0[UNRL]; -#if (YTILE >= 2) bigType bigB1[UNRL]; -#endif + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { -#pragma unroll + // Fetch the weight matrix from memory! 
+ #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_]; bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); -#if (YTILE >= 2) - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); } + // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; // Fetch A activation matrix in interleaved fashion from LDS or memory + for (int m = 0; m < M; m++) { + // if (k_ + K * m < 32 * 1024) bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + // else + // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); } } // Do the matrix multiplication in interleaved manner -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; -#pragma unroll - for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! 
-#pragma unroll + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][0]) @@ -407,11 +549,34 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- -#if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); -#endif + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); } } } @@ -442,28 +607,50 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); } } - if (threadIdx.x == 63) { for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); C[n + i + m * N] = __float2half(sum[m][i]); } } } - n += CuCount * WvPrGrp * YTILE; + n += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} } } - -__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] marginally exceeds LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; - __int128_t b128; half8 h8; }; @@ -484,12 +671,15 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, commitColumn[i] = 1; } + // It's worth trying to load-balance... 
+ int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + //---------------------------------------------------- // Indexing function into the column of weight matrix B // Algorithm does 64 lane k-splitting / wave and uses // WG ID and Thread ID to find the index. //---------------------------------------------------- - uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; // Check whether there will be fragmenation! // This will happen only for the last wave! @@ -520,11 +710,14 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, if (k_in >= min(K * M, 32 * 1024)) break; - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); + if (threadIdx.y >= _WvPrGrp) return; + float sum[M][YTILE]; //---------------------------------------------------- @@ -555,36 +748,14 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, bigType bigA[M][UNRL]; bigType bigB0[UNRL]; -#if (YTILE >= 2) bigType bigB1[UNRL]; -#endif -#if (YTILE >= 3) bigType bigB2[UNRL]; -#endif -#if (YTILE >= 4) bigType bigB3[UNRL]; -#endif -#if (YTILE >= 5) bigType bigB4[UNRL]; -#endif -#if (YTILE >= 6) bigType bigB5[UNRL]; -#endif -#if (YTILE >= 7) bigType bigB6[UNRL]; -#endif -#if (YTILE >= 8) bigType bigB7[UNRL]; -#endif -#if (YTILE >= 9) bigType bigB8[UNRL]; -#endif -#if (YTILE >= 10) - bigType bigB9[UNRL]; -#endif -#if (YTILE >= 11) - bigType bigB10[UNRL]; -#endif //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -604,61 +775,28 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! 
-#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // if (k_ >= K) break; - // bool skip = (k_ >= K); - // bool dummy = (k_ >= K); - const half* B_ = &B[(n + 0) * K + k_]; bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- -#if (YTILE >= 2) - // if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) - // if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) - // if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) - // if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) - // if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) - // if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) - // if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif - /* - #if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = - (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) - continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if - (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = - (loadnt((half8*)(&B_[10 * K]))); #endif - */ + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -675,16 +813,16 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, } // Do the matrix multiplication in interleaved manner -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; -#pragma unroll - for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! -#pragma unroll + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
+ #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][0]) @@ -693,56 +831,34 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- -#if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][8]) - : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][9]) - : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][10]) - : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); -#endif + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); } } } @@ -782,11 +898,7 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, } } - n += CuCount * WvPrGrp * YTILE; - - // if (threadIdx.x == 0) - // n = atomicAdd(((unsigned int*)(C)), YTILE); - // n = __shfl(n, 0, 64); + n += CuCount * _WvPrGrp * YTILE; // Check whether there will be fragmenation! // This will happen only for the last wave! 
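
Before the next hunk, a rough Python restatement (not from this patch) of the two indexing ideas the kernel comments above describe: the mindiv() heuristic that trims the number of active waves per workgroup, and the persistent-workgroup striding n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE with n += CuCount * _WvPrGrp * YTILE. The coverage check at the bottom is only a sanity illustration and its parameter values are made up.

def mindiv(n_cols, div1, div2):
    # Mirrors the CUDA mindiv() above: drop 3..9 waves from div2 as long as
    # the number of passes over the n_cols columns stays the same as with
    # the full div2; the largest such drop wins (later checks overwrite).
    rounds_full = n_cols // (div1 * div2)
    best = div2
    for drop in range(3, 10):
        waves = div2 - drop
        if waves <= 0:
            break
        if n_cols // (div1 * waves) == rounds_full:
            best = waves
    return best

def columns_visited(N, CuCount, WvPrGrp, YTILE):
    # Persistent-WG striding: every (block, wave) pair starts at its own
    # YTILE-wide column tile and jumps ahead by the total number of tiles
    # handled per "round" until it runs past N.
    _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp)
    seen = set()
    for block in range(CuCount):
        for wave in range(_WvPrGrp):
            n = (block * _WvPrGrp + wave) * YTILE
            while n < N:
                # The real kernels slide the last tile back to N - YTILE and
                # mask writes via commitColumn; clamping here is equivalent
                # for a pure coverage check.
                seen.update(range(n, min(n + YTILE, N)))
                n += CuCount * _WvPrGrp * YTILE
    return seen

assert columns_visited(N=4096, CuCount=80, WvPrGrp=16, YTILE=2) == set(range(4096))
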
@@ -800,23 +912,29 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, } } -#undef YTILE -#undef UNRL -#undef M +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support -#define YTILE 2 -#define UNRL 2 -#define M 2 +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets big A[] cases, where it is much larger than LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; -__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; - __int128_t b128; half8 h8; }; @@ -837,12 +955,16 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, commitColumn[i] = 1; } + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + if (threadIdx.y >= _WvPrGrp) return; + //---------------------------------------------------- // Indexing function into the column of weight matrix B // Algorithm does 64 lane k-splitting / wave and uses // WG ID and Thread ID to find the index. //---------------------------------------------------- - uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; // Check whether there will be fragmenation! // This will happen only for the last wave! @@ -863,6 +985,8 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, // - Then the WG will move to another 8 K elements // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- + #define PCML + #ifndef PCML for (uint32_t k = 0; k < min(K * M, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); @@ -873,10 +997,24 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, if (k_in >= min(K * M, 32 * 1024)) break; - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); + #endif + + #define TUC (THRDS * UNRL * A_CHUNK) + uint32_t kBase = 0; + // find biggest k size that fits in LDS + uint32_t kFit = (32 * 1024) / M; + // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple + // of TUC + kFit = (kFit % TUC == 0) + ? kFit + : (kFit - kFit % TUC); // round up to multiple of TUC + // if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); float sum[M][YTILE]; @@ -895,7 +1033,13 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, // - After completing first set of columns, WGs start // working on the next set of available columns //---------------------------------------------------- + #ifdef PCML + int YW = (YTILE * _WvPrGrp); + uint32_t Nrndp = (N % YW == 0) ? 
N : (N - N % YW + YW); + while (n < Nrndp) { + #else while (n < N) { + #endif //---------------------------------------------------- // 'sum' accumulates the matrix A x B computation // split across 64 lanes. @@ -908,36 +1052,16 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, bigType bigA[M][UNRL]; bigType bigB0[UNRL]; -#if (YTILE >= 2) bigType bigB1[UNRL]; -#endif -#if (YTILE >= 3) bigType bigB2[UNRL]; -#endif -#if (YTILE >= 4) bigType bigB3[UNRL]; -#endif -#if (YTILE >= 5) bigType bigB4[UNRL]; -#endif -#if (YTILE >= 6) bigType bigB5[UNRL]; -#endif -#if (YTILE >= 7) bigType bigB6[UNRL]; -#endif -#if (YTILE >= 8) bigType bigB7[UNRL]; -#endif -#if (YTILE >= 9) bigType bigB8[UNRL]; -#endif -#if (YTILE >= 10) bigType bigB9[UNRL]; -#endif -#if (YTILE >= 11) bigType bigB10[UNRL]; -#endif //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -956,62 +1080,48 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { - // Fetch the weight matrix from memory! -#pragma unroll + #ifdef PCML + if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS + if (k1 != 0) kBase += kFit; + __syncthreads(); + for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { + uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (kBase + kOff >= K) break; + if (kOff >= kFit) break; + for (uint32_t m = 0; m < M; m++) { + uint32_t k_in = kBase + m * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + if (n >= N) continue; + #endif + + // Fetch the weight matrix from memory! 
+ #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // if (k_ >= K) break; - // bool skip = (k_ >= K); - // bool dummy = (k_ >= K); - const half* B_ = &B[(n + 0) * K + k_]; bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- -#if (YTILE >= 2) - // if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) - // if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) - // if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) - // if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) - // if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) - // if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) - // if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif - /* - #if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = - (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) - continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if - (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = - (loadnt((half8*)(&B_[10 * K]))); #endif - */ + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -1020,24 +1130,28 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, // Fetch A activation matrix in interleaved fashion from LDS or memory for (int m = 0; m < M; m++) { + #ifdef PCML + bigA[m][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * m]))); + #else if (k_ + K * m < 32 * 1024) bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); else bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + #endif } } // Do the matrix multiplication in interleaved manner -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; -#pragma unroll + #pragma unroll for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! 
-#pragma unroll + #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][0]) @@ -1046,61 +1160,47 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- -#if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][8]) - : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][9]) - : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][10]) - : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); -#endif + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); } } } } + #ifdef PCML + if (n >= N) { + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; + continue; + } + #endif + //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- @@ -1135,11 +1235,8 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, } } - n += CuCount * WvPrGrp * YTILE; - - // if (threadIdx.x == 0) - // n = atomicAdd(((unsigned int*)(C)), YTILE); - // n = __shfl(n, 0, 64); + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; // Check whether there will be fragmenation! // This will happen only for the last wave! 
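
A small Python sketch (not part of the patch) of the LDS chunking the PCML path above performs in the "big" case: kFit is the largest slice of K for which M rows of A fit into the 32K-half LDS buffer, snapped down to a multiple of TUC = THRDS * UNRL * A_CHUNK, and the activation is then streamed chunk by chunk starting at kBase. The example parameters in the final call are made up.

THRDS = 64      # lanes per wave, as used in this file
A_CHUNK = 8     # half elements fetched per lane per step

def lds_chunks(K, M, UNRL, lds_halfs=32 * 1024):
    # kFit: largest K-slice such that M rows of A fit in LDS, rounded down
    # to a multiple of TUC (the amount one unrolled fetch step consumes).
    TUC = THRDS * UNRL * A_CHUNK
    kFit = lds_halfs // M
    kFit -= kFit % TUC
    kFit = min(kFit, K)
    assert kFit > 0, "activation rows too large for even one TUC-sized chunk"
    chunks, kBase = [], 0
    while kBase < K:
        chunks.append((kBase, min(kBase + kFit, K)))
        kBase += kFit
    return chunks

# K=20480, M=2, UNRL=2: the per-row LDS budget is 16384 halves and TUC=1024,
# so A would be staged in two chunks: [0, 16384) and [16384, 20480).
print(lds_chunks(K=20480, M=2, UNRL=2))
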
@@ -1152,742 +1249,53 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, } } } +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support -#undef YTILE -#undef UNRL -#undef M - -#define YTILE 5 -#define UNRL 2 -#define M 3 - -__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - __int128_t b128; - half8 h8; - }; - - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, + const int K_in, const int N_in, cudaStream_t stream, + const int CuCount = 0) { + dim3 grid(CuCount); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - uint32_t commitColumn[YTILE]; - for (uint32_t i = 0; i < YTILE; i++) { - commitColumn[i] = 1; +#define WVSPLTK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + /*wvSpltK_hf:*/ \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } else { \ + wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } \ } - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! 
- if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - - //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32 * 1024); - k += THRDS * WvPrGrp * A_CHUNK) { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - - if (k_in >= min(K * M, 32 * 1024)) break; - - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- - while (n < N) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; - - bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; -#if (YTILE >= 2) - bigType bigB1[UNRL]; -#endif -#if (YTILE >= 3) - bigType bigB2[UNRL]; -#endif -#if (YTILE >= 4) - bigType bigB3[UNRL]; -#endif -#if (YTILE >= 5) - bigType bigB4[UNRL]; -#endif -#if (YTILE >= 6) - bigType bigB5[UNRL]; -#endif -#if (YTILE >= 7) - bigType bigB6[UNRL]; -#endif -#if (YTILE >= 8) - bigType bigB7[UNRL]; -#endif -#if (YTILE >= 9) - bigType bigB8[UNRL]; -#endif -#if (YTILE >= 10) - bigType bigB9[UNRL]; -#endif -#if (YTILE >= 11) - bigType bigB10[UNRL]; -#endif - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. 
- // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { - // Fetch the weight matrix from memory! -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // if (k_ >= K) break; - // bool skip = (k_ >= K); - // bool dummy = (k_ >= K); - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- -#if (YTILE >= 2) - // if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) - // if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) - // if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) - // if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) - // if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) - // if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) - // if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif - /* - #if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = - (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) - continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if - (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = - (loadnt((half8*)(&B_[10 * K]))); #endif - */ - } - - // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m = 0; m < M; m++) { - if (k_ + K * m < 32 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); - else - bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); - } - } - - // Do the matrix multiplication in interleaved manner -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; -#pragma unroll - for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! 
-#pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- -#if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][8]) - : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][9]) - : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][10]) - : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); -#endif - } - } - } - } - - //---------------------------------------------------- - // Final reduction step using shuffle - //---------------------------------------------------- - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - } - } - - if (threadIdx.x == 63) { - for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - } - } - } - - n += CuCount * WvPrGrp * YTILE; - - // if (threadIdx.x == 0) - // n = atomicAdd(((unsigned int*)(C)), YTILE); - // n = __shfl(n, 0, 64); - - // Check whether there will be fragmenation! - // This will happen only for the last wave! 
- if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - } -} - -#undef YTILE -#undef UNRL -#undef M - -#define YTILE 7 -#define UNRL 1 -#define M 4 - -__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - __int128_t b128; - half8 h8; - }; - - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; - - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - uint32_t commitColumn[YTILE]; - for (uint32_t i = 0; i < YTILE; i++) { - commitColumn[i] = 1; - } - - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - - //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32 * 1024); - k += THRDS * WvPrGrp * A_CHUNK) { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - - if (k_in >= min(K * M, 32 * 1024)) break; - - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! 
- // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- - while (n < N) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; - - bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; -#if (YTILE >= 2) - bigType bigB1[UNRL]; -#endif -#if (YTILE >= 3) - bigType bigB2[UNRL]; -#endif -#if (YTILE >= 4) - bigType bigB3[UNRL]; -#endif -#if (YTILE >= 5) - bigType bigB4[UNRL]; -#endif -#if (YTILE >= 6) - bigType bigB5[UNRL]; -#endif -#if (YTILE >= 7) - bigType bigB6[UNRL]; -#endif -#if (YTILE >= 8) - bigType bigB7[UNRL]; -#endif -#if (YTILE >= 9) - bigType bigB8[UNRL]; -#endif -#if (YTILE >= 10) - bigType bigB9[UNRL]; -#endif -#if (YTILE >= 11) - bigType bigB10[UNRL]; -#endif - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { - // Fetch the weight matrix from memory! 
-#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // if (k_ >= K) break; - // bool skip = (k_ >= K); - // bool dummy = (k_ >= K); - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- -#if (YTILE >= 2) - // if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) - // if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) - // if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) - // if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) - // if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) - // if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) - // if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif - /* - #if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = - (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) - continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if - (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = - (loadnt((half8*)(&B_[10 * K]))); #endif - */ - } - - // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m = 0; m < M; m++) { - if (k_ + K * m < 32 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); - else - bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); - } - } - - // Do the matrix multiplication in interleaved manner -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; -#pragma unroll - for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! 
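One detail of the activation fetch above worth spelling out: the activations are staged in a 64 KB LDS buffer of 32 * 1024 halves, so only the first 32768 elements of the flattened M x K activation are served from LDS and anything past that index falls back to a global read (the k_ + K * m < 32 * 1024 check). A small sketch of that capacity arithmetic (lds_coverage is an invented helper; the constant comes from the kernel's __shared__ declaration):

# --- editorial sketch, not part of the patch -------------------------------
LDS_HALF_ELEMS = 32 * 1024            # 64 KB of LDS / sizeof(half)

def lds_coverage(K, M):
    """How many activation elements come from LDS vs. global memory."""
    total = K * M
    from_lds = min(total, LDS_HALF_ELEMS)
    return from_lds, total - from_lds

print(lds_coverage(8192, 4))   # (32768, 0): M=4 rows of K=8192 fit exactly
print(lds_coverage(9216, 4))   # (32768, 4096): the tail spills to global loads
# --- end of editorial sketch ------------------------------------------------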
-#pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- -#if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][8]) - : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][9]) - : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][10]) - : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); -#endif - } - } - } - } - - //---------------------------------------------------- - // Final reduction step using shuffle - //---------------------------------------------------- - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - } - } - - if (threadIdx.x == 63) { - for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - } - } - } - - n += CuCount * WvPrGrp * YTILE; - - // if (threadIdx.x == 0) - // n = atomicAdd(((unsigned int*)(C)), YTILE); - // n = __shfl(n, 0, 64); - - // Check whether there will be fragmenation! - // This will happen only for the last wave! 
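The fragmentation case mentioned above (and handled by the hunk that follows) is easier to see in isolation: when the final tile would run past column N, the tile start n is pulled back to N - YTILE and commitColumn masks out the leading columns that an earlier tile already wrote. A small Python sketch of that clamping, using the kernel's names (clamp_last_tile itself is invented):

# --- editorial sketch, not part of the patch -------------------------------
def clamp_last_tile(n, N, YTILE):
    """Pull the last tile back so it ends at N; mask columns already written."""
    commit = [1] * YTILE
    if n < N and (n + YTILE) >= N:
        start = N - YTILE
        for i in range(n - start):
            commit[i] = 0          # column already produced by a prior tile
        n = start
    return n, commit

print(clamp_last_tile(12, 15, 7))  # -> (8, [0, 0, 0, 0, 1, 1, 1])
# --- end of editorial sketch ------------------------------------------------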
- if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - } -} - -void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, - const int K_in, const int N_in, cudaStream_t stream, - const int CuCount = 0) { - dim3 grid(CuCount); - dim3 block(THRDS, WvPrGrp); - half* af4 = reinterpret_cast(in_a); - const half* bf4 = reinterpret_cast(in_b); - auto* c = reinterpret_cast(out_c); switch (N_in) { case 1: - if ((K_in <= 32 * 1024) && (M_in % 2 == 0)) { - wvSpltK_hf_m1_sml_<<>>(K_in, M_in, af4, bf4, c, - CuCount); - } else { - wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, - CuCount); - } + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 break; case 2: - wvSpltK_hf_m2_<<>>(K_in, M_in, af4, bf4, c, - CuCount); + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 break; case 3: - wvSpltK_hf_m3_<<>>(K_in, M_in, af4, bf4, c, - CuCount); + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 break; case 4: - wvSpltK_hf_m4_<<>>(K_in, M_in, af4, bf4, c, - CuCount); + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 break; default: throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + @@ -1899,4 +1307,4 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, if (cudaSuccess != err) { throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } -} +} \ No newline at end of file diff --git a/csrc/custom/paged_attention/attention_ll4mi.cu b/csrc/custom/paged_attention/attention_ll4mi.cu index dcabc7932cfd5..97674cccb15fb 100644 --- a/csrc/custom/paged_attention/attention_ll4mi.cu +++ b/csrc/custom/paged_attention/attention_ll4mi.cu @@ -2,16 +2,33 @@ #include #include #include +#include #include +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define WARP_SIZE 64 -#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 -#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support + + #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 + #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float16x4 = @@ -20,9 +37,17 @@ typedef float16x4 _Half4; typedef struct _Half8 { _Half4 xy[2]; } _Half8; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; +typedef struct _B16x8 { + _B16x4 xy[2]; +} _B16x8; + ////// Non temporal load stores /////// -#if 1 + #if 1 template __device__ __forceinline__ T load(T* addr) { @@ -34,7 +59,7 @@ __device__ __forceinline__ void store(T value, T* addr) { addr[0] = value; } -#else + #else template __device__ __forceinline__ T load(const T* addr) { @@ -109,7 +134,103 @@ __device__ __forceinline__ void store(T value, T* addr) { return __builtin_nontemporal_store(value, addr); } -#endif + #endif + +template +__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid, + blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.f = (_Float16)inp[i]; + ret[i] = t16.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.b = __float2bfloat16(inp[i]); + ret[i] = t16.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, + const _B16x4& inp2) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t1, t2, res; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.f = t1.f + t2.f; + ret[i] = res.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.b = t1.b + t2.b; + ret[i] = res.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} /////////////////////////////////////// @@ -135,9 +256,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* 
__restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] -#if 0 + #if 0 scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] -#endif + #endif int max_ctx_blocks) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; @@ -161,19 +282,19 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( constexpr int GQA_RATIO4 = 4 * QHLOOP; __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; - _Half8 Qlocal[QHLOOP]; + _B16x8 Qlocal[QHLOOP]; constexpr int x = 16 / sizeof(scalar_t); constexpr int KHELOOP = HEAD_SIZE / x; - _Half8 Klocal[KHELOOP]; + _B16x8 Klocal[KHELOOP]; constexpr int VHELOOP = HEAD_SIZE / WARP_SIZE; // v head_size dimension is distributed across lanes constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 // 8xtokens - _Half8 Vlocal[VHELOOP][VTLOOP]; + _B16x8 Vlocal[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] = {0}; qk_max[h] = -FLT_MAX; @@ -186,7 +307,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( partition_start_token_idx + warpid * WARP_SIZE; if (warp_start_token_idx >= context_len) { // warp out of context -#pragma unroll + #pragma unroll for (int h = 0; h < GQA_RATIO4; h++) { shared_qk_max[warpid][h] = -FLT_MAX; shared_exp_sum[warpid][h] = 0.0f; @@ -204,18 +325,30 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const int block_idx = (global_token_idx < context_len) ? global_token_idx / BLOCK_SIZE : last_ctx_block; - + // fetch block number for q and k // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride const int64_t physical_block_number = static_cast(block_table[block_idx]); + // fetch vphysical block numbers up front + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; + + const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + const int vblock_idx = warp_start_block_idx + b; + const int vblock_idx_ctx = + (vblock_idx <= last_ctx_block) ? 
vblock_idx : last_ctx_block; + vphysical_blocks[b] = block_table[vblock_idx_ctx]; + } // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; - const _Half8* q_ptrh8 = reinterpret_cast(q_ptr); + const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); const int qhead_elemh8 = laneid / 4; -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP - 1; h++) { const int qhead_idx = h * 4 + lane4id; Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; @@ -231,20 +364,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + wg_start_kv_head_idx * kv_head_stride; - const _Half8* k_ptrh8 = reinterpret_cast(k_ptr); const int physical_block_offset = local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset // is already cast as _H8 -#pragma unroll + const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; } float alibi_slope[QHLOOP]; if (alibi_slopes != nullptr) { -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { const int qhead_idx = h * 4 + lane4id; alibi_slope[h] = (qhead_idx < GQA_RATIO) @@ -253,119 +386,106 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } - constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; - int vphysical_blocks[VBLOCKS]; - - const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; -// fetch vphysical block numbers -#pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - const int vblock_idx = warp_start_block_idx + b; - const int vblock_idx_ctx = - (vblock_idx <= last_ctx_block) ? 
vblock_idx : last_ctx_block; - vphysical_blocks[b] = block_table[vblock_idx_ctx]; - } - const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; - const _Half8* v_ptrh8 = reinterpret_cast(v_ptr); -// iterate over each v block -#pragma unroll + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride const int64_t vphysical_block_number = static_cast(vphysical_blocks[b]); - const _Half8* v_ptrh8b = + const _B16x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; -// iterate over each head elem (within head_size) -#pragma unroll + // iterate over each head elem (within head_size) + #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; - const _Half8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; -// iterate over all velems within block -#pragma unroll + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[0].xy[0], dout[h], 4, 0, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[0].xy[1], dout[h], 4, 0, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[1].xy[0], dout[h], 4, 1, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[1].xy[1], dout[h], 4, 1, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[2].xy[0], dout[h], 4, 2, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[2].xy[1], dout[h], 4, 2, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[3].xy[0], dout[h], 4, 3, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[3].xy[1], dout[h], 4, 3, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[4].xy[0], dout[h], 4, 4, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[4].xy[1], dout[h], 4, 4, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[5].xy[0], dout[h], 4, 5, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[5].xy[1], dout[h], 4, 5, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[6].xy[0], dout[h], 4, 6, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[6].xy[1], dout[h], 4, 6, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[7].xy[0], dout[h], 4, 7, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[7].xy[1], dout[h], 4, 7, 0); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[0].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[0].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[1].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[1].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[2].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[2].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[3].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[3].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[4].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[4].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[5].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[5].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + 
Klocal[6].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[6].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[7].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[7].xy[1], dout[h]); if constexpr (KHELOOP > 8) { - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[8].xy[0], dout[h], 4, 8, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[8].xy[1], dout[h], 4, 8, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[9].xy[0], dout[h], 4, 9, 0); - dout[h] = - GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[9].xy[1], dout[h], 4, 9, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[10].xy[0], dout[h], 4, - 10, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[10].xy[1], dout[h], 4, - 10, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[11].xy[0], dout[h], 4, - 11, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[11].xy[1], dout[h], 4, - 11, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[12].xy[0], dout[h], 4, - 12, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[12].xy[1], dout[h], 4, - 12, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[13].xy[0], dout[h], 4, - 13, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[13].xy[1], dout[h], 4, - 13, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[14].xy[0], dout[h], 4, - 14, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[14].xy[1], dout[h], 4, - 14, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[15].xy[0], dout[h], 4, - 15, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[15].xy[1], dout[h], 4, - 15, 0); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[8].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[8].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[9].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[9].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[10].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[10].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[11].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[11].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[12].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[12].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[13].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[13].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[14].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[14].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[15].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[15].xy[1], dout[h]); } // KHELOOP>8 dout[h] *= scale; } -// transpose dout so that 4 token ids are in each lane, and 4 heads are across 4 -// lanes -#pragma unroll + // transpose dout so that 4 token ids are in each lane, and 4 heads are across + // 4 lanes + #pragma unroll for (int h = 0; h < QHLOOP; h++) { floatx4 tmp = {0}; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { const float B = (lane4id == i) ? 1.0f : 0.0f; // const float A = (global_token_idx < context_len) ? 
dout[h][i] : 0.0f; @@ -378,48 +498,48 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const int lane4_token_idx = 4 * (global_token_idx >> 2); const int alibi_offset = lane4_token_idx - context_len + 1; if (alibi_slopes != nullptr) { -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] += alibi_slope[h] * (alibi_offset + i); } } } -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { qk_max[h] = -FLT_MAX; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { qk_max[h] = (lane4_token_idx + i < context_len) ? fmaxf(qk_max[h], dout[h][i]) : qk_max[h]; } -#pragma unroll + #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); } } float exp_sum[QHLOOP]; -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { exp_sum[h] = 0.0f; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] = (lane4_token_idx + i < context_len) ? __expf(dout[h][i] - qk_max[h]) : 0.0f; exp_sum[h] += dout[h][i]; } -#pragma unroll + #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { exp_sum[h] += __shfl_xor(exp_sum[h], mask); } } -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { const int head_idx = 4 * h + lane4id; shared_qk_max[warpid][head_idx] = qk_max[h]; @@ -434,18 +554,18 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { float global_qk_max = -FLT_MAX; float warp_qk_max[NWARPS]; const int head_idx = 4 * h + lane4id; -#pragma unroll + #pragma unroll for (int w = 0; w < NWARPS; w++) { warp_qk_max[w] = shared_qk_max[w][head_idx]; global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); } float global_exp_sum = 0.0f; -#pragma unroll + #pragma unroll for (int w = 0; w < NWARPS; w++) { global_exp_sum += shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); @@ -462,56 +582,64 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there // are 4x16 tokens across warp - float16x4 logits[QHLOOP]; -#pragma unroll + _B16x4 logits[QHLOOP]; + #pragma unroll for (int h = 0; h < QHLOOP; h++) { -#pragma unroll - for (int i = 0; i < 4; i++) { - logits[h][i] = (scalar_t)dout[h][i]; - } + logits[h] = from_floatx4(dout[h]); } - __shared__ float16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; if (warp_start_token_idx >= context_len) { // warp out of context -#pragma unroll + #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { -#pragma unroll + #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout_shared[qh][vh][laneid][warpid] = {0}; } } } else { // warp in context -// iterate across heads -#pragma unroll + // iterate across heads + #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { -// iterate over each v head elem (within head_size) -#pragma unroll + // iterate over each v head elem (within head_size) + #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { floatx4 acc = {0}; // iterate over tokens - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][0].xy[0], acc, 4, 0, 0); - acc = 
GCN_MFMA_INSTR(logits[qh], Vlocal[vh][0].xy[1], acc, 4, 1, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][1].xy[0], acc, 4, 2, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][1].xy[1], acc, 4, 3, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][2].xy[0], acc, 4, 4, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][2].xy[1], acc, 4, 5, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][3].xy[0], acc, 4, 6, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][3].xy[1], acc, 4, 7, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][4].xy[0], acc, 4, 8, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][4].xy[1], acc, 4, 9, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][5].xy[0], acc, 4, 10, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][5].xy[1], acc, 4, 11, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][6].xy[0], acc, 4, 12, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][6].xy[1], acc, 4, 13, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][7].xy[0], acc, 4, 14, 0); - acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][7].xy[1], acc, 4, 15, 0); - float16x4 tmp; -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp[i] = (scalar_t)acc[i]; - } - vout_shared[qh][vh][laneid][warpid] = tmp; + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[1], acc); + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); } } } // warp in context @@ -519,7 +647,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( __syncthreads(); if (warpid == 0) { - float16x4 vout[QHLOOP][VHELOOP]; + _B16x4 vout[QHLOOP][VHELOOP]; // iterate across heads scalar_t* out_ptr; int out_num_partitions; @@ -531,38 +659,38 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( out_num_partitions = 1; out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; } -#pragma unroll + #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { -// iterate over each v head elem (within head_size) -#pragma unroll + // iterate over each v head elem (within head_size) + #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout[qh][vh] = {0}; -#pragma unroll + #pragma unroll for (int w = 0; w < NWARPS; w++) { - vout[qh][vh] += vout_shared[qh][vh][laneid][w]; + vout[qh][vh] = + addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); } const int head_size_elem = vh * WARP_SIZE + laneid; -#pragma unroll + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + #pragma unroll for (int i = 0; i < 4; i++) { const int head_idx = 4 * qh + i; if (head_idx < GQA_RATIO) { - // out_ptr[(wg_start_head_idx + head_idx) * max_num_partitions * - // HEAD_SIZE + 
head_size_elem] = vout[qh][vh][i]; - out_ptr[(wg_start_head_idx + head_idx) * out_num_partitions * - HEAD_SIZE + - head_size_elem] = vout[qh][vh][i]; + out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * + HEAD_SIZE + + head_size_elem] = vout[qh][vh][i]; } } } } } -#if 0 + #if 0 const int num_seqs = gridDim.x; const int global_token4id = global_token_idx/4; - #pragma unroll + #pragma unroll for (int t=0;t<4;t++) { - #pragma unroll + #pragma unroll for (int h=0;h= 1; mask /= 2) { max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); } @@ -643,7 +771,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( shared_exp_sums[threadIdx.x] = rescaled_exp_sum; shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; -#pragma unroll + #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { global_exp_sum += __shfl_xor(global_exp_sum, mask); } @@ -656,9 +784,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; constexpr int MAX_NPAR = 64; scalar_t tmps[MAX_NPAR]; -#pragma unroll + const float dzero = 0.0f; + #pragma unroll for (int j = 0; j < MAX_NPAR; j++) { - tmps[j] = 0.0f; + tmps[j] = from_float(dzero); } const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; const int num_partition_offset = (num_partitions)*HEAD_SIZE; @@ -666,7 +795,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( constexpr int JCHUNK = 16; -#pragma unroll + #pragma unroll for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { // lastj is last valid partition const int lastj_offset = @@ -677,7 +806,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( __syncthreads(); if (num_partitions > JCHUNK) { -#pragma unroll + #pragma unroll for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { const int lastj_offset = @@ -687,7 +816,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } if (num_partitions > 2 * JCHUNK) { -#pragma unroll + #pragma unroll for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; j += HEAD_SIZE) { const int lastj_offset = @@ -700,26 +829,26 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // Aggregate tmp_out to out. 
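The accumulation that follows is the split-KV merge step: each partition's output is weighted by exp_sums[j] * exp(max_logits[j] - global_max) (the rescaled_exp_sum written to shared memory above) and the total is divided by the globally rescaled exponential sum. A NumPy sketch of that merge, assuming, as in the usual split-KV scheme, that tmp_out rows are already softmax-normalized within their own partition (merge_partitions is an invented name):

# --- editorial sketch, not part of the patch -------------------------------
import numpy as np

def merge_partitions(tmp_out, exp_sums, max_logits):
    """out = sum_j tmp_out[j] * rescaled[j] / sum_j rescaled[j]"""
    global_max = max_logits.max()
    rescaled = exp_sums * np.exp(max_logits - global_max)  # shared_exp_sums
    return (tmp_out * rescaled[:, None]).sum(axis=0) / (rescaled.sum() + 1e-6)

# Two partitions over a toy context reproduce the single-pass softmax result.
logits = np.array([0.5, 2.0, -1.0, 1.5])
v = np.arange(8, dtype=np.float64).reshape(4, 2)
parts = [slice(0, 2), slice(2, 4)]
max_logits = np.array([logits[p].max() for p in parts])
exp_sums = np.array([np.exp(logits[p] - m).sum()
                     for p, m in zip(parts, max_logits)])
tmp_out = np.stack([np.exp(logits[p] - m) @ v[p] / s
                    for p, m, s in zip(parts, max_logits, exp_sums)])
ref = (np.exp(logits - logits.max()) / np.exp(logits - logits.max()).sum()) @ v
assert np.allclose(merge_partitions(tmp_out, exp_sums, max_logits), ref)
# --- end of editorial sketch ------------------------------------------------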
float acc = 0.0f; -#pragma unroll + #pragma unroll for (int j = 0; j < JCHUNK; j++) { - acc += tmps[j] * shared_exp_sums[j]; + acc += to_float(tmps[j]) * shared_exp_sums[j]; } if (num_partitions > JCHUNK) { -#pragma unroll + #pragma unroll for (int j = JCHUNK; j < 2 * JCHUNK; j++) { - acc += tmps[j] * shared_exp_sums[j]; + acc += to_float(tmps[j]) * shared_exp_sums[j]; } if (num_partitions > 2 * JCHUNK) { -#pragma unroll + #pragma unroll for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { - acc += tmps[j] * shared_exp_sums[j]; + acc += to_float(tmps[j]) * shared_exp_sums[j]; } } } if (num_partitions > MAX_NPAR) { idx = 0; -#pragma unroll + #pragma unroll for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; j += HEAD_SIZE) { // lastj is last valid partition @@ -729,21 +858,66 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( idx++; } -#pragma unroll + #pragma unroll for (int j = 0; j < MAX_NPAR; j++) { - acc += tmps[j] * shared_exp_sums[j + MAX_NPAR]; + acc += to_float(tmps[j]) * shared_exp_sums[j + MAX_NPAR]; } } const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); acc *= inv_global_exp_sum; - // from_float(out_ptr[threadIdx.x], acc); scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - out_ptr[threadIdx.x] = (scalar_t)acc; + out_ptr[threadIdx.x] = from_float(acc); } +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + #if 0 + scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] + #endif + int max_ctx_blocks) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). 
+template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions){UNREACHABLE_CODE} + +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ paged_attention_ll4mi_QKV_kernel \ <<>>( \ @@ -886,9 +1060,6 @@ void paged_attention_custom_launcher( #define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \ switch (block_size) { \ - case 8: \ - CALL_CUSTOM_LAUNCHER(T, 8, HEAD_SIZE); \ - break; \ case 16: \ CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \ break; \ @@ -934,9 +1105,12 @@ void paged_attention_custom( #endif const c10::optional& alibi_slopes, const std::string& kv_cache_dtype) { + assert(kv_cache_dtype == "auto"); const int head_size = query.size(2); if (query.dtype() == at::ScalarType::Half) { CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16); } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index 50fb8e70bd471..8e10934f7f7ef 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -1,3 +1,4 @@ +import os import random from pathlib import Path @@ -13,6 +14,8 @@ rtol = 1e-5 atol = 1 +CACHE_INVALIDATE_BUFFERS = int(os.getenv("CACHE_INVALIDATE_BUFFERS", "37")) + class Gemm: @@ -24,7 +27,7 @@ def __init__(self, m, n, k, indtype, outdtype, rocblas_decode=False): self.outdtype = outdtype self.use_rocblas = (indtype == outdtype and indtype is not torch.float8_e4m3fnuz) - self.nb = 37 + self.nb = CACHE_INVALIDATE_BUFFERS self.inp = torch.randn((self.n, self.k), device='cuda').to(self.indtype) self.weights = torch.randn((self.m, self.k), @@ -73,8 +76,8 @@ def check_gemm_ref(self, libtype, solidx): self.outdtype) elif libtype == 'rocblas': c = rocsolidxgemm.rocb_mm(self.inp, self.weights.t(), solidx) - if torch.allclose(c.to(torch.float32), - ref.to(torch.float32), + if torch.allclose(c.to(self.outdtype), + ref.to(self.outdtype), atol=self.atol, rtol=self.rtol): return True @@ -283,6 +286,9 @@ def find_best_sols(self): soldf.loc[i, 'libtype'] = gemmobj.best_libtype soldf.loc[i, 'solidx'] = gemmobj.best_solidx soldf.loc[i, 'soltimems'] = gemmobj.best_soltime + del gemmobj + torch.cuda.empty_cache() + soldf['indtype'] = self.indtype soldf['outdtype'] = self.outdtype finaldf = pd.concat([self.gemm_problems, soldf], axis=1) diff --git a/tests/kernels/test_attention_custom.py b/tests/kernels/test_attention_custom.py index 5bdbf126c22fa..6ecc348e017e9 100644 --- a/tests/kernels/test_attention_custom.py +++ b/tests/kernels/test_attention_custom.py @@ -3,35 +3,29 @@ import pytest import torch -from allclose_default import get_default_atol, get_default_rtol -from vllm._C import cache_ops, ops +from vllm import _custom_ops as ops from vllm._custom_C import paged_attention_custom -from vllm.utils import get_max_shared_memory_bytes, is_hip +from vllm.utils import is_hip -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 -# This will change depending on the compute capability. 
-# - 512 as a buffer -MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +from .allclose_default import get_default_atol, get_default_rtol + +MAX_SEQ_LEN = 32 * 1024 # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. -NUM_BLOCKS = 4321 # Arbitrary values for testing +NUM_BLOCKS = 128 * 1024 + 4321 # Arbitrary values for testing PARTITION_SIZE = 256 -# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} -DTYPES = [torch.half, torch.bfloat16, torch.float - ] if not is_hip() else [torch.half] -NUM_GEN_SEQS = [1, 17, 64] # Arbitrary values for testing +DTYPES = [torch.bfloat16, torch.half] +NUM_GEN_SEQS = [1, 17] # Arbitrary values for testing NUM_HEADS = [(8 * x, 8) for x in range(1, 17)] # Arbitrary values for testing -# FlashAttention forward only supports head dimension at most 128 -# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [128] -BLOCK_SIZES = [16] -USE_ALIBI = [False, True] +HEAD_SIZES = [64, 128] +BLOCK_SIZES = [16, 32] +USE_ALIBI = [True, False] KV_CACHE_DTYPE = ["auto"] -SEEDS = [0] +SEEDS = [37] CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1) ] @@ -254,14 +248,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(key_cache, dequantized_key_cache) + ops.convert_fp8(key_cache, dequantized_key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(value_cache, dequantized_value_cache) + ops.convert_fp8(value_cache, dequantized_value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) @@ -285,8 +279,14 @@ def test_paged_attention( # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # so we use a relaxed tolerance for the test. - atol, rtol = 1e-3, 1e-5 - atol = 5e-3 + atol, rtol = 1e-4, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 2e-4, 1e-5 + if use_alibi: + if dtype == torch.half: + atol, rtol = 5e-4, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 1e-3, 1e-5 if kv_cache_dtype == "fp8": atol, rtol = 1e-2, 1e-5 assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py new file mode 100644 index 0000000000000..55ab0451d828f --- /dev/null +++ b/tests/kernels/test_awq_triton.py @@ -0,0 +1,237 @@ +"""Tests for the AWQ Triton kernel. + +Run `pytest tests/kernels/test_awq_triton.py`. +""" +import argparse + +import pytest +import torch + +from vllm.model_executor.layers.quantization.awq_triton import ( + awq_dequantize_triton, awq_gemm_triton) + +device = "cuda" + +dequantize_threshold = 0.5 +# This seems large, but this is using float16 with splitK and large sizes. 
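As background for the AWQ reference helpers defined further down in this new test file (reverse_awq_order and awq_dequantize_torch): AWQ packs eight 4-bit weights into each int32 in an interleaved nibble order, and AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] undoes that interleave after the shift-and-mask unpack. The pack order used in the sketch below is inferred from that reverse order rather than taken from the patch, so treat it as an illustration of the convention, not a specification:

# --- editorial sketch, not part of the patch -------------------------------
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]  # as used in the reference below
PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]         # inferred: nibble n stores w[PACK_ORDER[n]]

def pack_awq(w):
    """Pack eight 4-bit weights into one 32-bit word in the interleaved order."""
    word = 0
    for n in range(8):
        word |= (w[PACK_ORDER[n]] & 0xF) << (4 * n)
    return word

def unpack_awq(word):
    """Shift out the eight nibbles, then undo the interleave (reverse order)."""
    raw = [(word >> (4 * n)) & 0xF for n in range(8)]
    return [raw[i] for i in AWQ_REVERSE_ORDER]

assert unpack_awq(pack_awq([1, 2, 3, 4, 5, 6, 7, 8])) == [1, 2, 3, 4, 5, 6, 7, 8]
# --- end of editorial sketch ------------------------------------------------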
+gemm_threshold = 6 + + +def reverse_awq_order(t: torch.Tensor): + bits = 4 + AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + reverse_order_tensor = torch.arange( + t.shape[-1], + dtype=torch.int32, + device=t.device, + ) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] + reverse_order_tensor = reverse_order_tensor.view(-1) + + t = t[:, reverse_order_tensor] & 0xF + return t + + +# qweights - [R , C // 8], int32 +# scales - [R // G, C ], float16 +# zeros - [R // G, C // 8], int32 +def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor, + qzeros: torch.Tensor) -> torch.Tensor: + bits = 4 + group_size = 128 + shifts = torch.arange(0, 32, bits, device=qzeros.device) + + iweights = torch.bitwise_right_shift(qweight[:, :, None], + shifts[None, None, :]).to(torch.int8) + + iweights = iweights.view(iweights.shape[0], -1) + + zeros = torch.bitwise_right_shift(qzeros[:, :, None], + shifts[None, None, :]).to(torch.int8) + zeros = zeros.view(qzeros.shape[0], -1) + zeros = reverse_awq_order(zeros) + + iweights = reverse_awq_order(iweights) + + iweights = torch.bitwise_and(iweights, (2**bits) - 1) + zeros = torch.bitwise_and(zeros, (2**bits) - 1) + + scales = scales.repeat_interleave(group_size, dim=0) + zeros = zeros.repeat_interleave(group_size, dim=0) + return (iweights - zeros) * scales, zeros + + +# input - [N, K] +# qweight - [K, M // 8] +# qzeros - [K // G, M // 8] +# scales - [K // G, M] +# split_k_iters - parallelism along K-dimension, int, power of 2. +def awq_gemm_torch(input: torch.Tensor, qweight: torch.Tensor, + scales: torch.Tensor, qzeros: torch.Tensor, + split_k_iters: int) -> torch.Tensor: + input_rows, input_cols = input.shape + qweight_rows, qweight_cols = qweight.shape + scales_rows, scales_cols = scales.shape + print(f"awq_gemm_torch:input_rows = {input_rows} input_cols = {input_cols}" + f" qweight_rows = {qweight_rows} qweight_cols = {qweight_cols}" + f" scales_rows = {scales_rows} scales_cols = {scales_cols}") + weights, zeros = awq_dequantize_torch(qweight, scales, qzeros) + return torch.matmul(input, weights) + + +# qweights - [R , C // 8], int32 +# scales - [R // G, C ], float16 +# zeros - [R // G, C // 8], int32 + + +@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024]) +@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128]) +def test_dequantize(qweight_rows, qweight_cols): + print("=" * 10 + " TESTING DEQUANTIZE" + "=" * 10) + + group_size = 128 + + print(f"qweight_rows = {qweight_rows}, qweight_cols = {qweight_cols}") + qweight_dtype = torch.int32 + scales_rows = qweight_rows // group_size + scales_cols = qweight_cols * 8 + scales_dtype = torch.float16 + zeros_rows = scales_rows + zeros_cols = qweight_cols + zeros_dtype = torch.int32 + + torch.manual_seed(0) + + qweight = torch.randint(0, + 10000000, (qweight_rows, qweight_cols), + dtype=qweight_dtype, + device=device) + scales = torch.rand(scales_rows, + scales_cols, + dtype=scales_dtype, + device=device) + zeros = torch.randint(0, + 10000000, (zeros_rows, zeros_cols), + dtype=zeros_dtype, + device=device) + print(f"qweight = {qweight}") + + iweights_triton = awq_dequantize_triton(qweight, scales, zeros) + + print(f"Triton result:iweights_triton = {iweights_triton}") + print("Any infs in triton result? 
-->" + f"{torch.any(torch.isinf(iweights_triton))}") + + iweights_torch, _ = awq_dequantize_torch(qweight, scales, zeros) + print(f"Torch result:iweights_torch = {iweights_torch}") + + diff = iweights_torch - iweights_triton + error = torch.sum(torch.sqrt(diff * diff)) + print(f"error = {error}") + + assert error < dequantize_threshold + + +# input - [N, K] +# qweight - [K, M // 8] +# qzeros - [K // G, M // 8] +# scales - [K // G, M] +@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 16, 32, 64, 128]) +@pytest.mark.parametrize("K", [3584, 18944, 128, 256, 512, 1024]) +@pytest.mark.parametrize("M", [448, 576, 4736, 16, 32, 64, 128]) +@pytest.mark.parametrize("splitK", [1, 8, 16]) +def test_gemm(N, K, M, splitK): + print("=" * 10 + " TESTING GEMM " + "=" * 10) + + split_k_iters = splitK + group_size = 128 + + input_rows = N + input_cols = K + input_dtype = torch.float16 + qweight_rows = input_cols + qweight_cols = M // 8 + scales_rows = qweight_rows // group_size + scales_cols = M + scales_dtype = torch.float16 + qzeros_rows = scales_rows + qzeros_cols = qweight_cols + print(f"input_rows = {input_rows} input_cols = {input_cols}" + f" qweight_rows = {qweight_rows} qweight_cols = {qweight_cols}" + f" scales_rows = {scales_rows} scales_cols = {scales_cols}") + + torch.manual_seed(2) + input = torch.rand((input_rows, input_cols), + dtype=input_dtype, + device=device) + qweight = torch.randint(0, + torch.iinfo(torch.int32).max, + (qweight_rows, qweight_cols), + device=device) + qzeros = torch.randint(0, + torch.iinfo(torch.int32).max, + (qzeros_rows, qzeros_cols), + device=device) + scales = torch.rand((scales_rows, scales_cols), + dtype=scales_dtype, + device=device) + + # NOTE: Use to see more data and accuracy during testing. + # import numpy as np + # import sys + # torch.set_printoptions(precision = 3, + # threshold=10000000000000000000000000000, + # sci_mode = False) + # np.set_printoptions(threshold=sys.maxsize) + + output_torch = awq_gemm_torch(input.cpu(), qweight.cpu(), scales.cpu(), + qzeros.cpu(), split_k_iters) + print(f"output_torch = {output_torch}") + + output_triton = awq_gemm_triton(input, qweight, scales, qzeros, + split_k_iters) + + print(f"output_triton = {output_triton}") + print(f"output_triton.shape = {output_triton.shape}") + print(f"Any infs in triton result? 
--> " + f"{torch.any(torch.isinf(output_triton))}") + + diff = output_torch.cpu() - output_triton.cpu() + error = torch.sum(torch.sqrt(diff * diff) / torch.numel(diff)) + print(f"error = {error}") + + assert error < gemm_threshold + + +def main(): + parser = argparse.ArgumentParser( + description="awq_triton test driver", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--test") + known_args, unknown_args = parser.parse_known_args() + if known_args.test is not None: + if known_args.test == "dequantize": + qweight_rows = 3584 + qweight_cols = 576 + small_test_size = False + if small_test_size: + qweight_rows = 256 + qweight_cols = 128 + test_dequantize(qweight_rows, qweight_cols) + elif known_args.test == "gemm": + small_test_size = True + N = 1 + K = 256 if small_test_size else 3584 + M = 32 if small_test_size else 448 + splitK = 1 + test_gemm(N, K, M, splitK) + else: + print(f"Unknown test {known_args.test}") + else: + print("No test provided.") + parser.print_help() + + +if __name__ == '__main__': + main() diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2356b9ec18b0d..1673e93ba4ea1 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -4,9 +4,12 @@ """ import pytest import torch +from torch.nn import Parameter +from torch.nn import functional as F from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from vllm import envs from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.models.mixtral import MixtralMoE @@ -48,8 +51,15 @@ def test_fused_moe( w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 score = torch.randn((m, e), device='cuda', dtype=dtype) - triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) + + # Pad the input if use padding + if envs.VLLM_MOE_PADDING: + w1 = F.pad(w1, (0, 128), "constant", 0) + torch.cuda.empty_cache() + w2 = F.pad(w2, (0, 128), "constant", 0) + torch.cuda.empty_cache() + triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) @@ -85,6 +95,15 @@ def test_mixtral_moe(dtype: torch.dtype): # vLLM uses 1D query [num_tokens, hidden_dim] vllm_inputs = hf_inputs.flatten(0, 1) + # pad the weight if using padding + if envs.VLLM_MOE_PADDING: + w13_weight = F.pad(vllm_moe.w13_weight, (0, 128), "constant", 0) + torch.cuda.empty_cache() + w2_weight = F.pad(vllm_moe.w2_weight, (0, 128), "constant", 0) + torch.cuda.empty_cache() + vllm_moe.w13_weight = Parameter(w13_weight, requires_grad=False) + vllm_moe.w2_weight = Parameter(w2_weight, requires_grad=False) + # Run forward passes for both MoE blocks hf_states, _ = hf_moe.forward(hf_inputs) vllm_states = vllm_moe.forward(vllm_inputs) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 07e9c9d119906..1bcf6116c158d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2,11 +2,14 @@ import torch +import vllm.envs as envs +from vllm.utils import is_hip + try: from vllm._C import cache_ops as vllm_cache_ops from vllm._C import ops as vllm_ops -except ImportError: - pass +except ImportError as e: + print(f"Failed to import from vllm._C with {e}") # activation ops @@ -128,12 +131,42 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, zeros: 
torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: + print(f"awq_dequantize:qweight.shape = {qweight.shape}" + f"scales = {scales.shape}," + f"zeros = {zeros.shape}," + f"split_k_iters = {split_k_iters}," + f"thx = {thx}" + f"thy = {thy}") + if is_hip() and envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_dequantize_triton) + return awq_dequantize_triton(qweight, scales, zeros) + + if is_hip(): + return torch.zeros(qweight.shape[0], + 8 * qweight.shape[1], + device=qweight.device, + dtype=torch.float16) return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: + if input.shape[0] > 1: + print(f"awq_gemm:input.shape = {input.shape}," + f"qweight = {qweight.shape}," + f"qzeros = {qzeros.shape}," + f"scales.shape = {scales.shape}," + f"split_k_iters = {split_k_iters}") + if is_hip() and envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_gemm_triton) + return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) + if is_hip(): + return torch.zeros((input.shape[0], qweight.shape[1] * 8), + device=qweight.device, + dtype=torch.float16) return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 0d3ee47193306..d78aa975f61ff 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -1,4 +1,3 @@ -import os from dataclasses import dataclass from typing import List, Optional, Tuple @@ -6,10 +5,11 @@ from vllm import _custom_ops as ops from vllm.attention.ops.prefix_prefill import context_attention_fwd +from vllm.envs import VLLM_USE_ROCM_CUSTOM_PAGED_ATTN from vllm.utils import is_hip -custom_attn_available = is_hip() and \ - (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0") +custom_attn_available = is_hip() and VLLM_USE_ROCM_CUSTOM_PAGED_ATTN and \ + "gfx1" not in torch.cuda.get_device_properties('cuda').gcnArchName if custom_attn_available: from vllm._custom_C import paged_attention_custom @@ -117,8 +117,11 @@ def forward_decode( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape gqa_ratio = num_heads // num_kv_heads - use_custom = (custom_attn_available and query.dtype == torch.half - and head_size == 128 and block_size == 16 + use_custom = (custom_attn_available + and (query.dtype == torch.half + or query.dtype == torch.bfloat16) + and (head_size == 128 or head_size == 64) + and (block_size == 16 or block_size == 32) and kv_cache_dtype == "auto" and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) diff --git a/vllm/config.py b/vllm/config.py index 409b0e1a44f7a..84ff7e1e30f53 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,7 +11,8 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron +from vllm.utils import (get_cpu_memory, is_cpu, is_hip, is_neuron, + print_warning_once) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -133,6 +134,17 @@ def __init__( code_revision, rope_scaling) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, 
dtype) + + if (not self.disable_sliding_window + and self.hf_text_config.model_type == "gemma2" + and self.hf_text_config.sliding_window is not None): + print_warning_once( + "Gemma 2 uses sliding window attention for every odd layer, " + "which is currently not supported by vLLM. Disabling sliding " + "window and capping the max length to the sliding window size " + f"({self.hf_text_config.sliding_window}).") + self.disable_sliding_window = True + self.max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, max_model_len=max_model_len, @@ -172,7 +184,7 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = ["gptq", "squeezellm", "fp8"] + rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -599,7 +611,7 @@ def __init__( if self.distributed_executor_backend is None and self.world_size > 1: if is_hip(): logger.info("Using torchrun for multi-GPU on " - "ROCM platform. Use --worker-use-ray or " + "ROCm platform. Use --worker-use-ray or " "--distributed-executor-backend={ray, mp} to " "override") if not os.environ.get("RANK"): @@ -1225,20 +1237,32 @@ def _get_and_verify_max_len( derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) - if rope_scaling is not None and rope_scaling["type"] != "su": - if disable_sliding_window: - # TODO(robertgshaw): Find a model that supports rope_scaling - # with sliding window to see if this case should be allowed. - raise NotImplementedError( - "Disabling sliding window is not supported for models " - "with rope_scaling. Please raise an issue so we can " - "investigate.") - assert "factor" in rope_scaling - scaling_factor = rope_scaling["factor"] - if rope_scaling["type"] == "yarn": - derived_max_model_len = rope_scaling[ - "original_max_position_embeddings"] - derived_max_model_len *= scaling_factor + if rope_scaling is not None: + if "type" in rope_scaling: + rope_type = rope_scaling["type"] + elif "rope_type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + else: + raise ValueError( + "rope_scaling must have a 'type' or 'rope_type' key.") + + # The correct one should be "longrope", kept "su" here + # to be backward compatible + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. Please raise an issue so we can " + "investigate.") + + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] + if rope_type == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor # If the user specified a max length, make sure it is smaller than the # derived length from the HF model config. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e461feb5e05a7..059407c05b220 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -233,7 +233,7 @@ def add_cli_args( help='Backend to use for distributed serving. When more than 1 GPU ' 'is used, on CUDA this will be automatically set to "ray" if ' 'installed or "mp" (multiprocessing) otherwise. 
On ROCm, this is ' - 'instead automatically set to torchrun.') + 'instead set to torchrun by default.') parser.add_argument( '--worker-use-ray', action='store_true', diff --git a/vllm/entrypoints/fast_sync_llm.py b/vllm/entrypoints/fast_sync_llm.py index 84f094a749edf..cc86528e4d62a 100644 --- a/vllm/entrypoints/fast_sync_llm.py +++ b/vllm/entrypoints/fast_sync_llm.py @@ -2,9 +2,12 @@ from queue import Empty from typing import Union +from vllm import envs from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor +from vllm.executor.ray_gpu_executor import RayGPUExecutor from vllm.inputs import PromptInputs, TextTokensPrompt from vllm.logger import init_logger from vllm.pooling_params import PoolingParams @@ -33,6 +36,7 @@ def __init__( self.result_queue = result_queue self.finish = False self.need_restart = False + self.llm_engine: LLMEngine def _add_request( self, @@ -49,7 +53,9 @@ def _poll_requests(self): if not self.llm_engine.has_unfinished_requests(): logger.info("No unfinished requests. Waiting...") (request_id, prompt, sampling_params) = self.input_queue.get() - if self.need_restart: + if self.need_restart and isinstance( + self.llm_engine.model_executor, + MultiprocessingGPUExecutor): logger.info("Restarting worker loops") for worker in self.llm_engine.model_executor.workers: worker.execute_method("start_worker_execution_loop") @@ -66,13 +72,24 @@ def _poll_requests(self): def run_engine(self): self.llm_engine = LLMEngine.from_engine_args( self.engine_args, usage_context=UsageContext.LLM_CLASS) + assert not isinstance( + self.llm_engine.model_executor, + RayGPUExecutor), "Ray is not supported in sync openai mode" self.result_queue.put(("Ready", None, None)) request_stats = {} log_interval = 100 + poll_interval = envs.VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS try: while True: - self._poll_requests() + poll_interval -= 1 + if (self.input_queue.qsize() >= + envs.VLLM_SYNC_SERVER_ACCUM_REQUESTS + or poll_interval <= 0 + or not self.llm_engine.has_unfinished_requests()): + self._poll_requests() + poll_interval = \ + envs.VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS step_outputs = self.llm_engine.step() log_interval -= 1 if log_interval == 0: diff --git a/vllm/entrypoints/sync_openai/api_server.py b/vllm/entrypoints/sync_openai/api_server.py index 153a887a67b20..845e6de370ae4 100644 --- a/vllm/entrypoints/sync_openai/api_server.py +++ b/vllm/entrypoints/sync_openai/api_server.py @@ -4,23 +4,34 @@ import threading import time from contextlib import asynccontextmanager -from typing import Dict +from http import HTTPStatus +from typing import Dict, Iterable, List, Union, cast import uvicorn from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import Mount +from openai.types.chat import ChatCompletionContentPartTextParam from prometheus_client import make_asgi_app +import vllm from vllm import FastSyncLLM as LLM from vllm import envs from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.entrypoints.sync_openai.protocol import (CompletionRequest, - CompletionResponse, - CompletionResponseChoice, - UsageInfo) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionContentPartParam, 
ChatCompletionMessageParam, + ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, CompletionRequest, + CompletionResponse, CompletionResponseChoice, + CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage, + ErrorResponse, ModelCard, ModelList, ModelPermission, UsageInfo) +from vllm.entrypoints.openai.serving_chat import (ChatMessageParseResult, + ConversationMessage) from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import random_uuid mp = multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD) @@ -40,7 +51,7 @@ class BackgroundRunner: def __init__(self): self.value = 0 - self.engine_args = None + self.engine_args: EngineArgs self.input_queue: multiprocessing.Queue = mp.Queue() self.result_queue: multiprocessing.Queue = mp.Queue() self.result_queues: Dict[str, asyncio.Queue] = {} @@ -48,6 +59,11 @@ def __init__(self): self.loop = None self.llm: LLM self.proc: multiprocessing.Process + self.tokenizer = None + self.response_role: str + + def set_response_role(self, role): + self.response_role = role def set_engine_args(self, engine_args): self.engine_args = engine_args @@ -74,6 +90,7 @@ async def run_main(self): input_queue=self.input_queue, result_queue=self.result_queue, ) + self.loop = asyncio.get_event_loop() self.proc = mp.Process(target=self.llm.run_engine) self.t.start() @@ -102,6 +119,15 @@ async def lifespan(app: FastAPI): asyncio.create_task(runner.run_main()) await runner.result_queues["Ready"].get() del runner.result_queues["Ready"] + + tokenizer = get_tokenizer( + engine_args.tokenizer, + tokenizer_mode=engine_args.tokenizer_mode, + tokenizer_revision=engine_args.tokenizer_revision, + trust_remote_code=engine_args.trust_remote_code, + truncation_side="left") + runner.tokenizer = tokenizer + yield @@ -114,6 +140,33 @@ async def lifespan(app: FastAPI): app.routes.append(route) +@app.get("/v1/models") +async def show_available_models(): + models = [ + ModelCard(id=runner.engine_args.model, + root=runner.engine_args.model, + permission=[ModelPermission()]) + ] + model_list = ModelList(data=models) + return JSONResponse(content=model_list.model_dump()) + + +@app.get("/version") +async def show_version(): + ver = {"version": vllm.__version__} + return JSONResponse(content=ver) + + +async def _check_model(request: Union[CompletionRequest, + ChatCompletionRequest]): + model = request.model + if model != runner.engine_args.model: + return ErrorResponse(message=f"The model {model} does not exist.", + type="NotFoundError", + code=HTTPStatus.NOT_FOUND) + return None + + async def completion_generator(model, result_queue, choices, created_time, ids): completed = 0 @@ -122,24 +175,25 @@ async def completion_generator(model, result_queue, choices, created_time, request_id, token, stats = await result_queue.get() choice_idx = choices[request_id] - res = CompletionResponse(id=request_id, - created=created_time, - model=model, - choices=[ - CompletionResponseChoice( - index=choice_idx, - text=token, - logprobs=None, - finish_reason=None, - stop_reason=None) - ], - usage=None) + res = CompletionStreamResponse(id=request_id, + created=created_time, + model=model, + choices=[ + CompletionResponseStreamChoice( + index=choice_idx, + text=token, + logprobs=None, + finish_reason=None, + stop_reason=None) + ], + usage=None) if stats is not None: res.usage = UsageInfo() res.usage.completion_tokens = 
stats.get("tokens", 0) res.usage.prompt_tokens = stats.get("prompt", 0) - res.usage.total_tokens = (res.usage.completion_tokens + - res.usage.prompt_tokens) + res.usage.total_tokens = ( + res.usage.completion_tokens + # type: ignore + res.usage.prompt_tokens) res.choices[0].finish_reason = stats["finish_reason"] res.choices[0].stop_reason = stats["stop_reason"] completed += 1 @@ -157,6 +211,10 @@ async def completion_generator(model, result_queue, choices, created_time, @app.post("/v1/completions") async def completions(request: CompletionRequest, raw_request: Request): + error_check_ret = await _check_model(request) + if error_check_ret is not None: + return JSONResponse(content=error_check_ret.model_dump(), + status_code=error_check_ret.code) sampling_params = request.to_sampling_params() ids, result_queue = await runner.add_request(request.prompt, sampling_params) @@ -198,6 +256,153 @@ async def completions(request: CompletionRequest, raw_request: Request): return res +def parse_chat_message_content_parts( + role: str, + parts: Iterable[ChatCompletionContentPartParam], +) -> ChatMessageParseResult: + texts: List[str] = [] + + for _, part in enumerate(parts): + part_type = part["type"] + if part_type == "text": + text = cast(ChatCompletionContentPartTextParam, part)["text"] + + texts.append(text) + else: + raise NotImplementedError(f"Unknown part type: {part_type}") + + messages = [ConversationMessage(role=role, content="\n".join(texts))] + + return ChatMessageParseResult(messages=messages) + + +def parse_chat_message_content( + message: ChatCompletionMessageParam, ) -> ChatMessageParseResult: + role = message["role"] + content = message.get("content") + + if content is None: + return ChatMessageParseResult(messages=[]) + if isinstance(content, str): + messages = [ConversationMessage(role=role, content=content)] + return ChatMessageParseResult(messages=messages) + + return parse_chat_message_content_parts(role, content) + + +async def chat_completion_generator(model, result_queue, created_time, id): + try: + first_token = ChatCompletionStreamResponse( + id=id, + created=created_time, + model=model, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role=runner.response_role), + logprobs=None, + finish_reason=None, + stop_reason=None) + ], + usage=None) + response_json = first_token.model_dump_json(exclude_unset=True) + yield f"data: {response_json}\n\n" + + while True: + request_id, token, stats = await result_queue.get() + assert request_id == id + + res = ChatCompletionStreamResponse( + id=request_id, + created=created_time, + model=model, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=token), + logprobs=None, + finish_reason=None, + stop_reason=None) + ], + usage=None) + if stats is not None: + res.usage = UsageInfo() + res.usage.completion_tokens = stats.get("tokens", 0) + res.usage.prompt_tokens = stats.get("prompt", 0) + res.usage.total_tokens = ( + res.usage.completion_tokens + # type: ignore + res.usage.prompt_tokens) + res.choices[0].finish_reason = stats["finish_reason"] + res.choices[0].stop_reason = stats["stop_reason"] + response_json = res.model_dump_json(exclude_unset=True) + yield f"data: {response_json}\n\n" + if stats is not None: + runner.remove_result_queues([id]) + break + + yield "data: [DONE]\n\n" + except Exception as e: + logger.error("Error in completion_generator: %s", e) + return + + +@app.post("/v1/chat/completions") +async def chat_completions(request: ChatCompletionRequest, + 
raw_request: Request): + error_check_ret = await _check_model(request) + if error_check_ret is not None: + return JSONResponse(content=error_check_ret.model_dump(), + status_code=error_check_ret.code) + sampling_params = request.to_sampling_params() + conversation: List[ConversationMessage] = [] + + res = ChatCompletionResponse(model=request.model, + choices=[], + usage=UsageInfo(prompt_tokens=0, + total_tokens=0, + completion_tokens=0)) + + for msg in request.messages: + parsed_msg = parse_chat_message_content(msg) + conversation.extend(parsed_msg.messages) + + prompt = runner.tokenizer.apply_chat_template( # type: ignore + conversation=conversation, + tokenize=False, + add_generation_prompt=request.add_generation_prompt, + ) + + ids, result_queue = await runner.add_request(prompt, sampling_params) + assert len(ids) == 1 + + if request.stream: + created_time = int(time.time()) + return StreamingResponse(content=chat_completion_generator( + request.model, result_queue, created_time, ids[0]), + media_type="text/event-stream") + + res.choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role=runner.response_role, content=""), + finish_reason=None, + stop_reason=None)) + + while True: + _, token, stats = await result_queue.get() + res.choices[0].message.content += str(token) + if stats is not None: + res.usage.completion_tokens += stats["tokens"] # type: ignore + res.usage.prompt_tokens += stats["prompt"] # type: ignore + res.choices[0].finish_reason = stats["finish_reason"] + res.choices[0].stop_reason = stats["stop_reason"] + runner.remove_result_queues(ids) + break + res.usage.total_tokens = ( # type: ignore + res.usage.completion_tokens + res.usage.prompt_tokens) # type: ignore + return res + + def parse_args(): parser = make_arg_parser() return parser.parse_args() @@ -207,4 +412,14 @@ def parse_args(): args = parse_args() engine_args = EngineArgs.from_cli_args(args) runner.set_engine_args(engine_args) + runner.set_response_role(args.response_role) + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + uvicorn.run(app, port=args.port, host=args.host) diff --git a/vllm/entrypoints/sync_openai/protocol.py b/vllm/entrypoints/sync_openai/protocol.py deleted file mode 100644 index c9e0537110c17..0000000000000 --- a/vllm/entrypoints/sync_openai/protocol.py +++ /dev/null @@ -1,170 +0,0 @@ -import time -from typing import Dict, List, Optional, Union - -import torch -from pydantic import BaseModel, Field -from typing_extensions import Annotated - -from vllm import SamplingParams -from vllm.utils import random_uuid - - -class UsageInfo(BaseModel): - prompt_tokens: int = 0 - total_tokens: int = 0 - completion_tokens: int = 0 - - -class CompletionLogProbs(BaseModel): - text_offset: List[int] = Field(default_factory=list) - token_logprobs: List[Optional[float]] = Field(default_factory=list) - tokens: List[str] = Field(default_factory=list) - top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None - - -class CompletionResponseChoice(BaseModel): - index: int - text: str - logprobs: Optional[CompletionLogProbs] = None - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = Field( - default=None, - description=( - "The stop string or token id that caused the completion " - "to stop, None if the completion finished for some other reason " - "including encountering the EOS token"), - ) - - -class 
CompletionResponse(BaseModel): - id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") - object: str = "text_completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: List[CompletionResponseChoice] - usage: Optional[UsageInfo] = Field(default=None) - - -class CompletionRequest(BaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/completions/create - model: str - prompt: Union[List[int], List[List[int]], str, List[str]] - best_of: Optional[int] = None - echo: Optional[bool] = False - frequency_penalty: Optional[float] = 0.0 - logit_bias: Optional[Dict[str, float]] = None - logprobs: Optional[int] = None - max_tokens: Optional[int] = 16 - n: int = 1 - presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = Field(None, - ge=torch.iinfo(torch.long).min, - le=torch.iinfo(torch.long).max) - stop: Optional[Union[str, List[str]]] = Field(default_factory=list) - stream: Optional[bool] = False - suffix: Optional[str] = None - temperature: Optional[float] = 1.0 - top_p: Optional[float] = 1.0 - user: Optional[str] = None - - # doc: begin-completion-sampling-params - use_beam_search: Optional[bool] = False - top_k: Optional[int] = -1 - min_p: Optional[float] = 0.0 - repetition_penalty: Optional[float] = 1.0 - length_penalty: Optional[float] = 1.0 - early_stopping: Optional[bool] = False - stop_token_ids: Optional[List[int]] = Field(default_factory=list) - ignore_eos: Optional[bool] = False - min_tokens: Optional[int] = 0 - skip_special_tokens: Optional[bool] = True - spaces_between_special_tokens: Optional[bool] = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None - # doc: end-completion-sampling-params - - # doc: begin-completion-extra-params - include_stop_str_in_output: Optional[bool] = Field( - default=False, - description=( - "Whether to include the stop string in the output. " - "This is only applied when the stop or stop_token_ids is set."), - ) - guided_json: Optional[Union[str, dict, BaseModel]] = Field( - default=None, - description=("If specified, the output will follow the JSON schema."), - ) - guided_regex: Optional[str] = Field( - default=None, - description=( - "If specified, the output will follow the regex pattern."), - ) - guided_choice: Optional[List[str]] = Field( - default=None, - description=( - "If specified, the output will be exactly one of the choices."), - ) - guided_grammar: Optional[str] = Field( - default=None, - description=( - "If specified, the output will follow the context free grammar."), - ) - guided_decoding_backend: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default guided decoding backend " - "of the server for this specific request. 
If set, must be one of " - "'outlines' / 'lm-format-enforcer'")) - guided_whitespace_pattern: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default whitespace pattern " - "for guided json decoding.")) - - # doc: end-completion-extra-params - - def to_sampling_params(self): - echo_without_generation = self.echo and self.max_tokens == 0 - - logits_processors = None - if self.logit_bias: - - def logit_bias_logits_processor( - token_ids: List[int], - logits: torch.Tensor) -> torch.Tensor: - assert self.logit_bias is not None - for token_id, bias in self.logit_bias.items(): - # Clamp the bias between -100 and 100 per OpenAI API spec - bias = min(100, max(-100, bias)) - logits[int(token_id)] += bias - return logits - - logits_processors = [logit_bias_logits_processor] - - return SamplingParams( - n=self.n, - best_of=self.best_of, - presence_penalty=self.presence_penalty, - frequency_penalty=self.frequency_penalty, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - min_p=self.min_p, - seed=self.seed, - stop=self.stop, - stop_token_ids=self.stop_token_ids, - ignore_eos=self.ignore_eos, - max_tokens=self.max_tokens if not echo_without_generation else 1, - min_tokens=self.min_tokens, - logprobs=self.logprobs, - use_beam_search=self.use_beam_search, - early_stopping=self.early_stopping, - prompt_logprobs=self.logprobs if self.echo else None, - skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=(self.spaces_between_special_tokens), - include_stop_str_in_output=self.include_stop_str_in_output, - length_penalty=self.length_penalty, - logits_processors=logits_processors, - truncate_prompt_tokens=self.truncate_prompt_tokens, - ) diff --git a/vllm/envs.py b/vllm/envs.py index 35421b9026f1e..817fe2fc91b4d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -9,6 +9,8 @@ VLLM_NCCL_SO_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = True + VLLM_USE_ROCM_SKINNY_GEMM: bool = True + VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True RANK: int = 0 LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -38,6 +40,9 @@ VLLM_INSTALL_PUNICA_KERNELS: bool = False CMAKE_BUILD_TYPE: Optional[str] = None VERBOSE: bool = False + VLLM_SYNC_SERVER_ACCUM_REQUESTS: int = 1 + VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1 + VLLM_MOE_PADDING: bool = True # The begin-* and end* here are used by the documentation generator # to extract the used env vars. 
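The two VLLM_SYNC_SERVER_* knobs declared above are consumed by the run_engine loop in fast_sync_llm.py earlier in this patch. As a rough, standalone sketch of the intended accumulate-then-poll behaviour (the names serve, engine and input_queue are placeholders for illustration, not part of the change; engine stands in for an LLMEngine-like object):

import os
import queue

ACCUM_REQUESTS = int(os.getenv("VLLM_SYNC_SERVER_ACCUM_REQUESTS", "1"))
STEPS_BETWEEN_POLLS = int(
    os.getenv("VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS", "1"))


def serve(engine, input_queue: queue.Queue) -> None:
    # `engine` is assumed to expose add_request(), step() and
    # has_unfinished_requests(), as LLMEngine does.
    steps_left = STEPS_BETWEEN_POLLS
    while True:
        steps_left -= 1
        # Only stop to pull new requests when enough have queued up,
        # the step budget is spent, or the engine has nothing to do.
        if (input_queue.qsize() >= ACCUM_REQUESTS or steps_left <= 0
                or not engine.has_unfinished_requests()):
            if not engine.has_unfinished_requests():
                engine.add_request(*input_queue.get())  # block until work arrives
            while not input_queue.empty():
                engine.add_request(*input_queue.get())
            steps_left = STEPS_BETWEEN_POLLS
        engine.step()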
@@ -135,6 +140,16 @@ lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), + # custom kernels for small (skinny) GEMMs on MI3* cards + "VLLM_USE_ROCM_SKINNY_GEMM": + lambda: (os.getenv("VLLM_USE_ROCM_SKINNY_GEMM", "True").lower() in + ("true", "1")), + + # custom paged attention implemented for MI3* cards + "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN": + lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in + ("true", "1")), + # rank of the process in the distributed setting, used to determine # the driver worker "RANK": @@ -219,6 +234,22 @@ # Both spawn and fork work "VLLM_WORKER_MULTIPROC_METHOD": lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), + + # Try to accumulate this many requests before proceeding + "VLLM_SYNC_SERVER_ACCUM_REQUESTS": + lambda: int(os.getenv("VLLM_SYNC_SERVER_ACCUM_REQUESTS", "1")), + + # Poll for new requests every this many steps + "VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS": + lambda: int(os.getenv("VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS", "1")), + + # Whether to pad the MoE expert weights for the fused MoE kernel + "VLLM_MOE_PADDING": + lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "1"))), + + # Use the Triton implementations of the AWQ dequantize/GEMM kernels on ROCm + "VLLM_USE_TRITON_AWQ": + lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "1"))), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7a3c6ec773358..e759d63b588b3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -10,9 +10,11 @@ import vllm._moe_C as moe_kernels from vllm import _custom_ops as ops +from vllm import envs from vllm.logger import init_logger logger = init_logger(__name__) +padding_size = 128 if envs.VLLM_MOE_PADDING else 0 @triton.jit @@ -262,7 +264,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, expert_ids, num_tokens_post_padded, B.shape[1], - B.shape[2], + B.shape[2] - padding_size, sorted_token_ids.shape[0], topk_ids.numel(), A.stride(0), @@ -365,7 +367,8 @@ def fused_experts(hidden_states: torch.Tensor, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None): # Check constraints.
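To make the bookkeeping around padding_size concrete before the shape checks below: when VLLM_MOE_PADDING is on, each expert weight gains 128 trailing zeros on its last (hidden/K) dimension, and the kernels recover the logical size by subtracting padding_size. A small illustration with made-up sizes:

import torch
import torch.nn.functional as F

padding_size = 128            # mirrors the module-level constant above
e, n, k = 4, 128, 256         # illustrative expert count and layer sizes

w1 = torch.randn(e, 2 * n, k, dtype=torch.float16)
w1_padded = F.pad(w1, (0, padding_size), "constant", 0)

assert w1_padded.shape == (e, 2 * n, k + padding_size)
assert w1_padded.shape[2] - padding_size == k   # what the kernel call uses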
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[ + 1] == w1.shape[2] - padding_size, "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -381,7 +384,7 @@ def fused_experts(hidden_states: torch.Tensor, config = override_config else: # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], + configs = get_moe_configs(E, w2.shape[2] - padding_size, "float8" if use_fp8 else None) if configs: diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py new file mode 100644 index 0000000000000..0d6faf4e0ca3c --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -0,0 +1,311 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def awq_dequantize_kernel( + qweight_ptr, # quantized matrix + scales_ptr, # scales, per group + zeros_ptr, # zeros, per group + group_size, # currently is always 128 when model quantized + result_ptr, # Output matrix + num_cols, # input num cols in qweight + num_rows, # input num rows in qweight + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr): + # Setup the pids. + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + + # Compute offsets and masks for qweight_ptr. + offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8 + offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] + + masks_y = offsets_y < num_rows + masks_x = offsets_x < num_cols + + masks = masks_y[:, None] & masks_x[None, :] + + # Compute offsets and masks for result output ptr. + result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange( + 0, BLOCK_SIZE_X * 8) + result_offsets = (8 * num_cols * result_offsets_y[:, None] + + result_offsets_x[None, :]) + + result_masks_y = result_offsets_y < num_rows + result_masks_x = result_offsets_x < num_cols * 8 + result_masks = result_masks_y[:, None] & result_masks_x[None, :] + + # Load the weights. + iweights = tl.load(qweight_ptr + offsets, masks) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Use this to compute a set of shifts that can be used to unpack and + # reorder the values in iweights and zeros. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + iweights = (iweights >> shifts) & 0xF + + # Compute zero offsets and masks. + zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size + + tl.arange(0, BLOCK_SIZE_Y) // group_size) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8 + zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] + + zero_masks_y = zero_offsets_y < num_rows // group_size + zero_masks_x = zero_offsets_x < num_cols + zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] + + # Load the zeros. 
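For reference, the unpack-and-reorder trick used above (and repeated for the zeros just below) can be reproduced in plain PyTorch: each int32 holds eight 4-bit values, and the reverse AWQ order selects which nibble goes to which output column. A tiny example with one packed word:

import torch

packed = torch.tensor([0x76543210], dtype=torch.int32)    # one int32 word
reverse_awq_order = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7])
shifts = reverse_awq_order * 4
nibbles = (packed >> shifts) & 0xF
# nibbles == tensor([0, 4, 1, 5, 2, 6, 3, 7]) for this word, i.e. the eight
# packed values come back in AWQ's interleaved column order.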
+ zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + zeros = (zeros >> shifts) & 0xF + + # Compute scale offsets and masks. + scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size + + tl.arange(0, BLOCK_SIZE_Y) // group_size) + scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + + tl.arange(0, BLOCK_SIZE_X * 8)) + scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + + scale_offsets_x[None, :]) + scale_masks_y = scale_offsets_y < num_rows // group_size + scale_masks_x = scale_offsets_x < num_cols * 8 + scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] + + # Load the scales. + scales = tl.load(scales_ptr + scale_offsets, scale_masks) + + # Dequantize. + iweights = (iweights - zeros) * scales + iweights = iweights.to(result_ptr.type.element_ty) + + # Finally, store. + tl.store(result_ptr + result_offsets, iweights, result_masks) + + +@triton.jit +def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, + awq_group_size, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = c_ptr.type.element_ty + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # accumulator = tl.arange(0, BLOCK_SIZE_N) + # accumulator = tl.broadcast_to(accumulator[None, :], + # (BLOCK_SIZE_M, BLOCK_SIZE_N)) + # accumulator = accumulator & 0x0 + # accumulator = accumulator.to(accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), + dtype=accumulator_dtype) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Create the necessary shifts to use to unpack. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], + (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + # Offsets and masks. + offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + masks_am = offsets_am < M + + offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) + + tl.arange(0, BLOCK_SIZE_N) // 8) + masks_bn = offsets_bn < N // 8 + + offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) + + tl.arange(0, BLOCK_SIZE_N) // 8) + masks_zn = offsets_zn < N // 8 + + offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + masks_sn = offsets_sn < N + + offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offsets_a = K * offsets_am[:, None] + offsets_k[None, :] + offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :] + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv + # block_offset = BLOCK_SIZE_K * SPLIT_K + # for k in range(0, (K + block_offset - 1) // (block_offset)): + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b) + + # Dequantize b. 
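The dequantization step below applies the same group-wise formula the kernel uses elsewhere: for each group of awq_group_size rows along K, weight = (quantized - zero) * scale. A short eager-mode reference for a single output column (sizes are illustrative):

import torch

group_size = 128
k = 256                                               # two groups
w_q = torch.randint(0, 16, (k,), dtype=torch.int32)   # unpacked 4-bit weights
zeros = torch.randint(0, 16, (k // group_size,), dtype=torch.int32)
scales = torch.rand(k // group_size, dtype=torch.float16)

# Broadcast each group's zero and scale over its rows, then dequantize.
w = (w_q - zeros.repeat_interleave(group_size)).to(torch.float16) \
    * scales.repeat_interleave(group_size)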
+ offsets_szk = ((BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // + awq_group_size + + tl.arange(0, BLOCK_SIZE_K) // awq_group_size) + offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] + masks_zk = offsets_szk < K // awq_group_size + masks_z = masks_zk[:, None] & masks_zn[None, :] + zeros_ptrs = zeros_ptr + offsets_z + zeros = tl.load(zeros_ptrs, mask=masks_z) + + offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] + masks_sk = offsets_szk < K // awq_group_size + masks_s = masks_sk[:, None] & masks_sn[None, :] + scales_ptrs = scales_ptr + offsets_s + scales = tl.load(scales_ptrs, mask=masks_s) + + b = (b >> shifts) & 0xF + zeros = (zeros >> shifts) & 0xF + b = (b - zeros) * scales + b = b.to(c_ptr.type.element_ty) + + # Accumulate results. + accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K * SPLIT_K + a_ptrs += BLOCK_SIZE_K * SPLIT_K + b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) + + +# Example input: +# qweight.size=torch.Size([3584, 576]), +# qweight.dtype = torch.int32, +# scales.size=torch.Size([28, 4608]), +# scales.dtype=torch.float16, +# zeros.size=torch.Size([28, 576]), +# zeros.dtype=torch.int32 +# split_k_iters=0 +# thx=0 +# thy=0 + + +# qweights - [K , M // 8], int32 +# scales - [K // G, M ], float16 +# zeros - [K // G, M // 8], int32 +def awq_dequantize_triton( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, +) -> torch.Tensor: + K = qweight.shape[0] + M = scales.shape[1] + awq_group_size = 128 + + assert K > 0 and M > 0 + assert scales.shape[0] == K // awq_group_size and scales.shape[1] == M + assert zeros.shape[0] == K // awq_group_size and zeros.shape[1] == M // 8 + + # Result tensor: + # number of rows = same as input tensor + # number of cols = 8 x input tensor num cols + result = torch.empty(qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype) + + block_size_x = 32 + block_size_y = 32 + + Y = qweight.shape[0] # num rows + X = qweight.shape[1] # num cols + group_size = 128 + grid = lambda META: ( + triton.cdiv(X, META['BLOCK_SIZE_X']), + triton.cdiv(Y, META['BLOCK_SIZE_Y']), + ) + awq_dequantize_kernel[grid](qweight, + scales, + zeros, + group_size, + result, + X, + Y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y) + + return result + + +# input - [M, K] +# qweight - [K, N // 8] +# qzeros - [K // G, N // 8] +# scales - [K // G, N] +# split_k_iters - parallelism along K-dimension, int, power of 2. 
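Before the GEMM wrapper below, here is one possible call to awq_dequantize_triton as defined above, using the layouts from its comment block and the sizes from the example log higher up (a ROCm/CUDA device is assumed; the tensors are random placeholders):

import torch

K, M, G = 3584, 4608, 128
qweight = torch.randint(0, 2**31 - 1, (K, M // 8),
                        dtype=torch.int32, device="cuda")
scales = torch.rand(K // G, M, dtype=torch.float16, device="cuda")
zeros = torch.randint(0, 2**31 - 1, (K // G, M // 8),
                      dtype=torch.int32, device="cuda")

deq = awq_dequantize_triton(qweight, scales, zeros)
assert deq.shape == (K, M)    # eight output columns per packed input column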
+def awq_gemm_triton(input: torch.Tensor, qweight: torch.Tensor, + scales: torch.Tensor, qzeros: torch.Tensor, + split_k_iters: int) -> torch.Tensor: + M, K = input.shape + N = qweight.shape[1] * 8 + awq_group_size = 128 + + assert N > 0 and K > 0 and M > 0 + assert qweight.shape[0] == K and qweight.shape[1] == N // 8 + assert qzeros.shape[0] == K // awq_group_size and qzeros.shape[1] == N // 8 + assert scales.shape[0] == K // awq_group_size and scales.shape[1] == N + assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0 + assert split_k_iters <= 32 + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( + N, META['BLOCK_SIZE_N']), + split_k_iters, + ) + block_size_m = 32 + block_size_n = 32 + block_size_k = 32 + + result = torch.zeros((M, N), dtype=scales.dtype, device=input.device) + + # A = input, B = qweight, C = result + # A = M x K, B = K x N, C = M x N + awq_gemm_kernel[grid](input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + awq_group_size, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + SPLIT_K=split_k_iters) + + return result diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py index 411d634503f23..20ffcfb6a8f91 100644 --- a/vllm/model_executor/layers/quantization/fp8_rocm.py +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -242,6 +242,8 @@ def apply( k = x.shape[1] solidx = self._config._tuned.get((m, n, k), 0) + if solidx == 0: + self._config.save_shape(m, n, k) res = ops.fp8_mm(x_quant, weight.t(), out_dtype, asf, wsf, osf, int(solidx)) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index d03903d206d33..c15be15a9f0ce 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -503,6 +503,159 @@ def forward( return query.flatten(-2), key.flatten(-2) +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. 
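For intuition about the mscale ratio computed on the next line, here is a small worked example of yarn_get_mscale from above; the numbers are illustrative and not taken from any particular model config:

import math


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    # Same helper as defined earlier in this file.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# With scaling_factor = 40: the numerator 0.1 * 1 * ln(40) + 1 is about 1.369,
# and with mscale_all_dim = 0 the denominator yarn_get_mscale(40, 0.0) is 1.0,
# so self.mscale ends up roughly 1.369 * attn_factor.
print(yarn_get_mscale(40, 1.0))   # ~1.3689
print(yarn_get_mscale(40, 0.0))   # 1.0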
+ self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + device="cuda", + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + print("Cache shape", cache.shape) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) + if offsets is not None else positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. 
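The two branches below only differ in how the half-size cos/sin tables are expanded: NeoX-style rotation pairs dimension i with i + rotary_dim/2, so the table is tiled block-wise, while GPT-J-style rotation pairs adjacent dimensions, so the table is repeated element-wise. A 1-D sketch of the pattern (the real code does the same on [batch, seq, rotary_dim // 2] tensors):

import torch

half = torch.arange(4)                   # stands in for cos (or sin), size rotary_dim // 2

neox_layout = half.repeat(2)             # tensor([0, 1, 2, 3, 0, 1, 2, 3])
gptj_layout = half.repeat_interleave(2)  # tensor([0, 0, 1, 1, 2, 2, 3, 3])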
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key + + +class GemmaRotaryEmbedding(RotaryEmbedding): + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 + inv_freq = 1.0 / (base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() / + self.rotary_dim)) + return inv_freq + + +class ExtendedRotaryEmbedding(RotaryEmbedding): + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + return self.apply_scaling(inv_freqs) + + def apply_scaling(self, freqs: torch.Tensor): + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor) + new_freqs.append((1 - smooth) * freq / scale_factor + + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} @@ -534,10 +687,17 @@ def get_rope( rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, dtype) else: - scaling_type = rope_scaling["type"] - if scaling_type != "su": + scaling_type = rope_scaling[ + "type"] if "type" in rope_scaling else rope_scaling["rope_type"] + # The correct one should be "longrope" but keep "su" here + # for backward compatible + if scaling_type not in {"su", "longrope", "llama3"}: scaling_factor = rope_scaling["factor"] - if scaling_type == "linear": + if scaling_type == "llama3": + rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype) + elif scaling_type == "linear": rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index b3783cecfdce8..3cef301bfa442 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -8,6 +8,8 @@ from rocsolidxgemm import rocb_create_extension, rocb_mm from vllm import _custom_C +from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM +from vllm.utils import is_hip class TunedGemm: @@ -26,6 +28,9 @@ def __init__(self): self.cu_count = torch.cuda.get_device_properties( device='cuda').multi_processor_count + self.use_skinny = is_hip() and VLLM_USE_ROCM_SKINNY_GEMM and \ + "gfx1" not in torch.cuda.get_device_properties('cuda').gcnArchName + if (self.save_gemm == 1): self.tuned_df = pd.DataFrame(columns=['M', 'N', 'K']) else: @@ -52,6 +57,8 
@@ def query_sol(self, m, n, k): return self.solids.get((m, n, k), (0, 0)) def apply_skinny(self, m, n, k, inp_view, weights): + if not self.use_skinny: + return None if inp_view.dtype != torch.float16 or k % 8 != 0: return None if m > 8 and n <= 4: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 2f4237339486e..ee9db7048f1f6 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -24,10 +24,12 @@ from typing import Iterable, List, Optional, Tuple import torch +import torch.nn.functional as F from torch import nn from transformers import MixtralConfig from vllm import _custom_ops as ops +from vllm import envs from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, @@ -181,6 +183,15 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, def process_weights_after_loading(self): # Fp8 is the only case where we need to process after loading. if not self.use_fp8: + if envs.VLLM_MOE_PADDING: + self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data, + (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data, + (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() return # If checkpoint is fp16, quantize here.
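Both the custom paged-attention gate earlier in this patch and use_skinny above combine the same three checks: a ROCm build, an env toggle, and not running on a Navi (gfx1*) part. A condensed sketch of that pattern, using only the torch attributes the patch itself relies on (the function name is ours, not part of the change):

import torch


def rocm_custom_kernel_enabled(env_flag: bool) -> bool:
    # True only on ROCm builds, with the toggle on, and on non-Navi GPUs.
    if not torch.version.hip:
        return False
    arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return env_flag and "gfx1" not in arch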