diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 499f55896c35c..7d9cc52023acf 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -12,6 +12,7 @@ ARG BUILD_TRITON="1" # If "0", it is copied in from the local working directory. ARG REMOTE_VLLM="0" + # ----------------------- # vLLM base image FROM $BASE_IMAGE AS base @@ -27,7 +28,23 @@ ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH} # Install some basic utilities RUN apt-get update -q -y && apt-get install -q -y python3 python3-pip RUN apt-get update -q -y && apt-get install -q -y \ - sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev + ccache sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev + +ENV CCACHE_DIR=/root/.cache/ccache + +RUN python3 -m pip install --upgrade pip +# Remove sccache so it doesn't interfere with ccache +# TODO: implement sccache support across components +RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" +# Install torch == 2.5.0 on ROCm +RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ + *"rocm-6.1"*) \ + python3 -m pip uninstall -y torch torchvision \ + && python3 -m pip install --no-cache-dir --pre \ + torch==2.5.0.dev20240726 \ + torchvision==0.20.0.dev20240726 \ + --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ + *) ;; esac ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer ENV PATH=$PATH:/opt/rocm/bin:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/bin: @@ -36,6 +53,7 @@ ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/opt/conda/envs/py_3.9/lib/python3.9/ WORKDIR ${COMMON_WORKDIR} + # ----------------------- # hipBLASLt build stages FROM base AS build_hipblaslt @@ -52,6 +70,7 @@ COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb / FROM scratch AS export_hipblaslt_0 FROM export_hipblaslt_${BUILD_HIPBLASLT} AS export_hipblaslt + # ----------------------- # RCCL build stages FROM base AS build_rccl @@ -66,14 +85,15 @@ COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb / FROM scratch AS export_rccl_0 FROM export_rccl_${BUILD_RCCL} AS export_rccl + # ----------------------- # flash attn build stages FROM base AS build_flash_attn -ARG FA_BRANCH="ae7928c" +ARG FA_BRANCH="23a2b1c2" ARG FA_REPO="https://github.com/ROCm/flash-attention.git" RUN git clone ${FA_REPO} \ && cd flash-attention \ - && git checkout ${FA_BRANCH} \ + && git checkout "${FA_BRANCH}" \ && git submodule update --init \ && GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist FROM scratch AS export_flash_attn_1 @@ -82,10 +102,11 @@ COPY --from=build_flash_attn ${COMMON_WORKDIR}/flash-attention/dist/*.whl / FROM scratch AS export_flash_attn_0 FROM export_flash_attn_${BUILD_FA} AS export_flash_attn + # ----------------------- # Triton build stages FROM base AS build_triton -ARG TRITON_BRANCH="6ddb79b" +ARG TRITON_BRANCH="e0fc12c" ARG TRITON_REPO="https://github.com/OpenAI/triton.git" RUN git clone ${TRITON_REPO} \ && cd triton \ @@ -98,6 +119,7 @@ COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl / FROM scratch AS export_triton_0 FROM export_triton_${BUILD_TRITON} AS export_triton + # AMD-SMI build stages FROM base AS build_amdsmi RUN cd /opt/rocm/share/amd_smi \ @@ -105,6 +127,7 @@ RUN cd /opt/rocm/share/amd_smi \ FROM scratch AS export_amdsmi COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl / + # ----------------------- # vLLM (and gradlib) fetch stages FROM base AS fetch_vllm_0 @@ -117,6 +140,7 @@ ONBUILD RUN git clone ${VLLM_REPO} \ && git 
checkout ${VLLM_BRANCH} FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm + # ----------------------- # vLLM (and gradlib) build stages FROM fetch_vllm AS build_vllm @@ -130,7 +154,8 @@ if ls /install/*.deb; then \ && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ fi # Build vLLM -RUN cd vllm \ +RUN --mount=type=cache,target=/root/.cache/ccache \ + cd vllm \ && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist # Build gradlib RUN cd vllm/gradlib \ @@ -154,20 +179,9 @@ ARG COMMON_WORKDIR ARG BUILD_FA RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/* -# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. -# Manually remove it so that later steps of numpy upgrade can continue -RUN case "$(which python3)" in \ - *"/opt/conda/envs/py_3.9"*) \ - rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \ - *) ;; esac - -RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ - if ls /install/*.deb; then \ - apt-get purge -y hipblaslt \ - && dpkg -i /install/*.deb \ - && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ - && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ - fi +# Package upgrades for useful functionality or to avoid dependency issues +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install --upgrade numba scipy huggingface-hub[cli] RUN --mount=type=bind,from=export_rccl,src=/,target=/install \ if ls /install/*.deb; then \ @@ -200,16 +214,14 @@ RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \ RUN python3 -m pip install --upgrade numba scipy huggingface-hub[cli] # Install vLLM (and gradlib) -# Make sure punica kernels are built (for LoRA) -ENV VLLM_INSTALL_PUNICA_KERNELS=1 RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ cd /install \ && pip install -U -r requirements-rocm.txt \ && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.0"*) \ - patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \ *"rocm-6.1"*) \ - cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6;; \ + cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \ + # Prevent interference if torch bundles its own HIP runtime + && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \ *) ;; esac \ && pip uninstall -y vllm gradlib \ && pip install *.whl @@ -220,7 +232,6 @@ COPY --from=export_vllm /tests ${COMMON_WORKDIR}/vllm/tests COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples COPY --from=export_vllm /.buildkite ${COMMON_WORKDIR}/vllm/.buildkite - ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 ENV TOKENIZERS_PARALLELISM=false @@ -228,4 +239,3 @@ ENV TOKENIZERS_PARALLELISM=false ENV HIP_FORCE_DEV_KERNARG=1 CMD ["/bin/bash"] - diff --git a/ROCm_performance.md b/ROCm_performance.md index 31d6044801bcb..4ac81ee6e23ea 100644 --- a/ROCm_performance.md +++ b/ROCm_performance.md @@ -1,11 +1,4 @@ # Overview of the optional performance features uinque to https://github.com/ROCm/vllm -## Multi-GPU torchrun -On ROCm the default multi GPU executor is `torchrun` as opposed to `ray` on NVIDIA -This can be overridden by the `--worker-use-ray` flag to vllm or its benchmarks -To utilize torchran parallelism, the run command should be modified from -`python ` -to -`torchrun --standalone --nnodes=1 --nproc-per-node= ` ## Triton attention The default 
attention function on ROCm is using triton attention kernel. To fallback to the https://github.com/ROCm/flash-attention implementation set up the following environment symbol: `VLLM_USE_TRITON_FLASH_ATTN=0` @@ -53,3 +46,8 @@ python3 gradlib/gradlib/gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_f where `/tmp/tuned_fp8_16` will be used by our fp8 gemm linear layer. Now, when running inference with fp8, we are using the tuned gemm for best performance. + +## NCCL Performance environment variable + +For MI300x, setting environment variable NCCL_MIN_NCHANNELS=112 is expected to improve performance. + diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index e6f63e148cd37..d39cb59abaf83 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -1,16 +1,15 @@ #include #include #include - -namespace py = pybind11; +#include "core/registration.h" // declare templates for front (cpp) and back (cuda) sides of function: // template void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block); -void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, - int64_t rows_per_block) { +void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block) { auto M = in_a.size(0); auto K = in_a.size(1); LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, @@ -21,10 +20,10 @@ void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block); // template -void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, - int64_t rows_per_block) { - int M = in_a.size(0); - int K = in_a.size(1); +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block = 4) { + auto M = in_a.size(0); + auto K = in_a.size(1); // if (N != in_b.numel()) // throw std::invalid_argument("Size mismatch A.numel(): " + // std::to_string(in_a.numel()) @@ -41,10 +40,10 @@ void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, const int N, cudaStream_t stream, const int CuCount); -void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, int64_t N_in, - int64_t CuCount) { - int M = in_a.size(0); - int K = in_a.size(1); +void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount) { + auto M = in_a.size(0); + auto K = in_a.size(1); int N = N_in; wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, at::cuda::getCurrentCUDAStream(), CuCount); @@ -54,9 +53,9 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int solidx); void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, - const int solidx = 0) { - int M = in_a.size(0); - int K = in_a.size(1); + const int64_t solidx = 0) { + auto M = in_a.size(0); + auto K = in_a.size(1); LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(), solidx); @@ -69,7 +68,7 @@ void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, int numAColumns, int numBRows, int numBColumns, int numCRows, int numCColumns, cudaStream_t stream); -void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) { +void MMCustomGPU(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c) { auto matA_sizes{in_a.sizes()}; auto matB_sizes{in_b.sizes()}; auto matO_sizes{out_c.sizes()}; diff --git 
a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index 18679f86e82c1..f7dba39bb55ad 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -2,6 +2,7 @@ #include #include #include +#include "cuda_compat.h" #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ defined(__gfx941__) || defined(__gfx942__)) @@ -17,8 +18,6 @@ #define UNREACHABLE_CODE assert(false); #endif -constexpr int WARP_SIZE = 64; - template __device__ __forceinline__ T loadnt(T* addr) { return __builtin_nontemporal_load(addr); diff --git a/csrc/custom/custom_ops.h b/csrc/custom/custom_ops.h index 33da06fbda538..f6ea892b2ffa5 100644 --- a/csrc/custom/custom_ops.h +++ b/csrc/custom/custom_ops.h @@ -1,14 +1,14 @@ #pragma once #include -void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, - int64_t rows_per_block); +void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block); -void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, - int64_t rows_per_block); +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block); -void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, int64_t N_in, - int64_t CuCount); +void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount); void paged_attention_custom(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, diff --git a/csrc/custom/paged_attention/attention_ll4mi.cu b/csrc/custom/paged_attention/attention_ll4mi.cu index 09560cf0173ac..e78dce4c30de3 100644 --- a/csrc/custom/paged_attention/attention_ll4mi.cu +++ b/csrc/custom/paged_attention/attention_ll4mi.cu @@ -3,6 +3,7 @@ #include #include #include +#include "cuda_compat.h" #include @@ -23,7 +24,6 @@ #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#define WARP_SIZE 64 #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support diff --git a/csrc/custom/torch_bindings.cpp b/csrc/custom/torch_bindings.cpp index a6079f303a9cc..453f446f5d571 100644 --- a/csrc/custom/torch_bindings.cpp +++ b/csrc/custom/torch_bindings.cpp @@ -3,29 +3,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, custom_ops) { custom_ops.def( - "LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> ()" - ); + "LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block=4) -> " + "()"); custom_ops.impl("LLMM1", torch::kCUDA, &LLMM1); custom_ops.def( - "LLMM_Silu(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> ()" - ); + "LLMM_Silu(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) " + "-> ()"); custom_ops.impl("LLMM_Silu", torch::kCUDA, &LLMM_Silu); custom_ops.def( - "paged_attention_custom(Tensor! out, Tensor exp_sums," - " Tensor max_logits, Tensor tmp_out," - " Tensor query, Tensor key_cache," - " Tensor value_cache, int num_kv_heads," - " float scale, Tensor block_tables," - " Tensor context_lens, int block_size," - " int max_context_len," - " Tensor? alibi_slopes," - " str kv_cache_dtype) -> ()" - ); - custom_ops.impl("paged_attention_custom", torch::kCUDA, &paged_attention_custom); + "paged_attention_custom(Tensor! 
out, Tensor exp_sums," + " Tensor max_logits, Tensor tmp_out," + " Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads," + " float scale, Tensor block_tables," + " Tensor context_lens, int block_size," + " int max_context_len," + " Tensor? alibi_slopes," + " str kv_cache_dtype) -> ()"); + custom_ops.impl("paged_attention_custom", torch::kCUDA, + &paged_attention_custom); custom_ops.def( - "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," - " int CuCount) -> ()" - ); + "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," + " int CuCount) -> ()"); custom_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index cd5a9222aa0da..07c950b058788 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -154,16 +154,20 @@ void register_graph_buffers(fptr_t _fa, const std::vector& handles, #ifdef USE_ROCM -void free_meta_buffer(void* buffer) { hipFree(buffer); } +void free_meta_buffer(void* buffer) { CUDACHECK(cudaFree(buffer)); } -std::vector get_meta_buffer_ipc_handle(torch::Tensor inp) { - std::vector data_handle(sizeof(cudaIpcMemHandle_t), 0); - CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)data_handle.data(), +torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) { + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = + torch::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, options); + CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)data_handle.data_ptr(), inp.data_ptr())); + ; return data_handle; } -torch::Tensor allocate_meta_buffer(int size) { +torch::Tensor allocate_meta_buffer(int64_t size) { auto device_index = c10::cuda::current_device(); at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); void* buffer; @@ -181,12 +185,4 @@ torch::Tensor allocate_meta_buffer(int size) { return torch::from_blob(buffer, {size}, free_meta_buffer, options); } -std::vector get_device_bdf(int dev) { - char busIdStr[] = "0000:00:00.0"; - std::vector bdf(sizeof(busIdStr), 0); - CUDACHECK(cudaDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev)); - bdf.resize(bdf.size() - 1); // remove trailing NULL - return bdf; -} - #endif diff --git a/csrc/ops.h b/csrc/ops.h index 77180893568d4..6107a2941bd80 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -237,7 +237,6 @@ std::tuple> get_graph_buffer_ipc_meta( void register_graph_buffers(fptr_t _fa, const std::vector& handles, const std::vector>& offsets); #ifdef USE_ROCM -torch::Tensor allocate_meta_buffer(int size); -std::vector get_meta_buffer_ipc_handle(torch::Tensor inp); -std::vector get_device_bdf(int dev); +torch::Tensor allocate_meta_buffer(int64_t size); +torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp); #endif diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7783acd741f5f..43c6f2d763bef 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -340,7 +340,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { &get_max_shared_memory_per_block_device_attribute); } -#ifndef USE_ROCM TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { // Custom all-reduce kernels custom_ar.def("init_custom_ar", &init_custom_ar); @@ -373,7 +372,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.def("register_graph_buffers", ®ister_graph_buffers); custom_ar.impl("register_graph_buffers", torch::kCPU, 
®ister_graph_buffers); -} +#ifdef USE_ROCM + custom_ar.def("allocate_meta_buffer", &allocate_meta_buffer); + custom_ar.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer); + custom_ar.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle); + custom_ar.impl("get_meta_buffer_ipc_handle", torch::kCPU, + &get_meta_buffer_ipc_handle); #endif +} REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 4e3b4f6c9f7c0..17874bd07e5af 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -65,13 +65,6 @@ To build vllm on ROCm 6.1 for Radeon RX7900 series (gfx1100), you should specify $ DOCKER_BUILDKIT=1 docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm . -To build docker image for vllm on ROCm 5.7, you can specify ``BASE_IMAGE`` as below: - -.. code-block:: console - - $ DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \ - -f Dockerfile.rocm -t vllm-rocm . - To run the above docker image ``vllm-rocm``, use the below command: .. code-block:: console @@ -160,10 +153,13 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases. .. tip:: - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. - - To use CK flash-attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention. - - The ROCm version of pytorch, ideally, should match the ROCm driver version. + - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. + - To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention. + - The ROCm version of PyTorch, ideally, should match the ROCm driver version. .. tip:: - For MI300x (gfx942) users, to achieve optimal performance, please refer to `MI300x tuning guide `_ for performance optimization and tuning tips on system and workflow level. For vLLM, please refer to `vLLM performance optimization `_. 
+ + diff --git a/setup_cython.py b/setup_cython.py deleted file mode 100644 index dca79af61a9f6..0000000000000 --- a/setup_cython.py +++ /dev/null @@ -1,37 +0,0 @@ -import Cython.Compiler.Options -from Cython.Build import cythonize -from setuptools import setup - -Cython.Compiler.Options.annotate = True - -infiles = [] - -infiles += [ - "vllm/engine/llm_engine.py", - "vllm/transformers_utils/detokenizer.py", - "vllm/engine/output_processor/single_step.py", - "vllm/outputs.py", - "vllm/engine/output_processor/stop_checker.py", -] - -infiles += [ - "vllm/core/scheduler.py", - "vllm/sequence.py", - "vllm/core/block_manager_v1.py", -] - -infiles += [ - "vllm/model_executor/layers/sampler.py", - "vllm/sampling_params.py", - "vllm/utils.py", -] - -setup(ext_modules=cythonize(infiles, - annotate=False, - force=True, - compiler_directives={ - 'language_level': "3", - 'infer_types': True - })) - -# example usage: python3 setup_cython.py build_ext --inplace diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py index 198d40a155ccb..4fd88c1b21be1 100644 --- a/tests/kernels/test_awq_triton.py +++ b/tests/kernels/test_awq_triton.py @@ -5,6 +5,7 @@ import pytest import torch +from vllm.model_executor.layers.quantization.awq import torch_awq_dequantize from vllm.model_executor.layers.quantization.awq_triton import ( AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) @@ -57,7 +58,25 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor, scales = scales.repeat_interleave(group_size, dim=0) zeros = zeros.repeat_interleave(group_size, dim=0) - return (iweights - zeros) * scales + return (iweights - zeros) * scales, zeros + + +# input - [N, K] +# qweight - [K, M // 8] +# qzeros - [K // G, M // 8] +# scales - [K // G, M] +# split_k_iters - parallelism along K-dimension, int, power of 2. 
+def awq_gemm_torch(input: torch.Tensor, qweight: torch.Tensor, + scales: torch.Tensor, qzeros: torch.Tensor, + split_k_iters: int) -> torch.Tensor: + input_rows, input_cols = input.shape + qweight_rows, qweight_cols = qweight.shape + scales_rows, scales_cols = scales.shape + print(f"awq_gemm_torch:input_rows = {input_rows} input_cols = {input_cols}" + f" qweight_rows = {qweight_rows} qweight_cols = {qweight_cols}" + f" scales_rows = {scales_rows} scales_cols = {scales_cols}") + weights = torch_awq_dequantize(qweight, scales, qzeros) + return torch.matmul(input, weights) # qweights - [R , C // 8], int32 @@ -101,7 +120,8 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): assert (not torch.any(torch.isinf(iweights_triton)) and not torch.any(torch.isnan(iweights_triton))) - iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size) + iweights_torch = torch_awq_dequantize(qweight, scales, zeros) + print(f"Torch result:iweights_torch = {iweights_torch}") torch.testing.assert_close(iweights_triton, iweights_torch) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index d8ac4be156790..79f94a331fdd8 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -9,7 +9,7 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock -from vllm import envs +import vllm.envs as envs from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.models.mixtral import MixtralMoE @@ -97,14 +97,14 @@ def test_mixtral_moe(dtype: torch.dtype): # pad the weight if using padding if envs.VLLM_MOE_PADDING: - w13_weight = F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", - 0) + vllm_moe.experts.w13_weight = Parameter(F.pad( + vllm_moe.experts.w13_weight, (0, 128), "constant", 0), + requires_grad=False) torch.cuda.empty_cache() - w2_weight = F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0) + vllm_moe.experts.w2_weight = Parameter(F.pad( + vllm_moe.experts.w2_weight, (0, 128), "constant", 0), + requires_grad=False) torch.cuda.empty_cache() - vllm_moe.experts.w13_weight = Parameter(w13_weight, - requires_grad=False) - vllm_moe.experts.w2_weight = Parameter(w2_weight, requires_grad=False) # Run forward passes for both MoE blocks hf_states, _ = hf_moe.forward(hf_inputs) diff --git a/tests/test_utils.py b/tests/test_utils.py index eacf2105c15a0..baf527ca6a0db 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -108,23 +108,6 @@ def dummy(*, old_arg: object = None, new_arg: object = None): dummy(old_arg=1) -def is_rocm62(): - import torch - return isinstance(torch.version.hip, - str) and torch.version.hip.startswith("6.2") - - -def xfail_if_rocm62(function=None, - reason: str = "Tests are not yet ready for ROCm 6.2", - strict: bool = False): - if function: - return pytest.mark.xfail(is_rocm62(), reason=reason, - strict=strict)(function) - else: - assert callable(function) - return pytest.mark.xfail(is_rocm62(), reason=reason, strict=strict) - - def test_get_open_port(): os.environ["VLLM_PORT"] = "5678" # make sure we can get multiple ports, even if the env var is set @@ -193,3 +176,20 @@ def test_missing_required_argument(parser): parser.add_argument('--required-arg', required=True) with pytest.raises(SystemExit): parser.parse_args([]) + + +def is_rocm62(): + import torch + return isinstance(torch.version.hip, + str) and torch.version.hip.startswith("6.2") + + +def xfail_if_rocm62(function=None, + 
reason: str = "Tests are not yet ready for ROCm 6.2", + strict: bool = False): + if function: + return pytest.mark.xfail(is_rocm62(), reason=reason, + strict=strict)(function) + else: + assert callable(function) + return pytest.mark.xfail(is_rocm62(), reason=reason, strict=strict) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a31f56f657162..a838a9f6cede0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -21,8 +21,8 @@ with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 -if is_hip(): - import vllm._custom_C +with contextlib.suppress(ImportError): + import vllm._custom_C # noqa: F401 def hint_on_error(fn): @@ -131,31 +131,6 @@ def paged_attention_v2( blocksparse_block_size, blocksparse_head_sliding_step) -def paged_attention_custom( - out: torch.Tensor, - exp_sum: torch.Tensor, - max_logits: torch.Tensor, - tmp_out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - num_kv_heads: int, - scale: float, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - block_size: int, - max_seq_len: int, - alibi_slopes: Optional[torch.Tensor], - kv_cache_dtype: str, -): - torch.ops._custom_C.paged_attention_custom(out, exp_sum, max_logits, - tmp_out, query, key_cache, - value_cache, num_kv_heads, - scale, block_tables, seq_lens, - block_size, max_seq_len, - alibi_slopes, kv_cache_dtype) - - # pos encoding ops def rotary_embedding( positions: torch.Tensor, @@ -207,6 +182,12 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: + # print(f"awq_dequantize:qweight.shape = {qweight.shape}" + # f"scales = {scales.shape}," + # f"zeros = {zeros.shape}," + # f"split_k_iters = {split_k_iters}," + # f"thx = {thx}" + # f"thy = {thy}") if envs.VLLM_USE_TRITON_AWQ: from vllm.model_executor.layers.quantization.awq_triton import ( awq_dequantize_triton) @@ -217,6 +198,12 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: + # if input.shape[0] > 1: + # print(f"awq_gemm:input.shape = {input.shape}," + # f"qweight = {qweight.shape}," + # f"qzeros = {qzeros.shape}," + # f"scales.shape = {scales.shape}," + # f"split_k_iters = {split_k_iters}") if envs.VLLM_USE_TRITON_AWQ: from vllm.model_executor.layers.quantization.awq_triton import ( awq_gemm_triton) @@ -442,7 +429,7 @@ def scaled_fp8_quant( assert (input.ndim == 2) shape: Union[Tuple[int, int], torch.Size] = input.shape # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \ + out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() \ else torch.float8_e4m3fn if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) @@ -680,19 +667,55 @@ def register_graph_buffers(fa: int, handles: List[str], torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) -def LLMM1(in_a: torch.Tensor, in_b: torch.Tensor, out_c: torch.Tensor, - rows_per_block: int): - torch.ops._custom_C.LLMM1(in_a, in_b, out_c, rows_per_block) +def allocate_meta_buffer(size: int) -> torch.Tensor: + return torch.ops._C_custom_ar.allocate_meta_buffer(size) + + +def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> List[str]: + return torch.ops._C_custom_ar.get_meta_buffer_ipc_handle(inp) -def LLMM_Silu(in_a: 
torch.Tensor, in_b: torch.Tensor, out_c: torch.Tensor, - rows_per_block: int): - torch.ops._custom_C.LLMM_Silu(in_a, in_b, out_c, rows_per_block) +# ROCm custom +def LLMM1(a: torch.Tensor, + b: torch.Tensor, + out: torch.Tensor, + rows_per_block: int = 4) -> None: + torch.ops._custom_C.LLMM1(a, b, out, rows_per_block) + + +def LLMM_Silu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, + rows_per_block: int) -> None: + torch.ops._custom_C.LLMM_Silu(a, b, out, rows_per_block) + + +def paged_attention_custom( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, +) -> None: + torch.ops._custom_C.paged_attention_custom(out, exp_sum, max_logits, + tmp_out, query, key_cache, + value_cache, num_kv_heads, + scale, block_tables, seq_lens, + block_size, max_seq_len, + alibi_slopes, kv_cache_dtype) -def wvSpltK(in_a: torch.Tensor, in_b: torch.Tensor, out_c: torch.Tensor, - N_in: int, CuCount: int): - torch.ops._custom_C.wvSpltK(in_a, in_b, out_c, N_in, CuCount) +def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, + cu_count: int) -> None: + torch.ops._custom_C.wvSpltK(a, b, out, N, cu_count) # temporary fix for https://github.com/vllm-project/vllm/issues/5456 diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 365dcc13f4863..412171296839d 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -307,7 +307,7 @@ def __init__( if self.use_naive_attn: self.attn_func = _sdpa_attention - logger.debug("Using naive attention in ROCmBackend") + logger.debug("Using naive (SDPA) attention in ROCmBackend") def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index f94211116a746..05134872ba39c 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -61,78 +61,58 @@ def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): @triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) +def load_fn(ptrs, offset_first, offset_second, boundary_first, + boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) else: - tensor = tl.load(block_ptr) + tensor = tl.load(ptrs) return tensor @triton.jit def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - actual_seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - 
encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, -): + acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, + stride_bn, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, + philox_seed, batch_philox_offset, encoded_sm_ptrs, block_min, + block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) + k_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, + actual_seqlen_k) if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: # noqa: SIM102 - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. - # check if this masking works for that case. - if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + if MASK_STEPS: # NOQA: SIM102 + if start_n + BLOCK_N == block_max and n_extra_tokens != 0: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if + # not is_modulo_mn. Last step might get wasted but that is okay. + # Check if this masking works for that case. boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) @@ -145,13 +125,18 @@ def _attn_fwd_inner( qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS - and (n_extra_tokens != 0), "zero") - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. 
- qk += bias * 1.44269504089 + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, + BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, + actual_seqlen_k) + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an + # additional scale factor of log2(e) which we must also multiply + # the bias with. + qk += (bias * 1.44269504089) + + # softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) @@ -162,48 +147,32 @@ def _attn_fwd_inner( philox_offset = (batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, + BLOCK_N, actual_seqlen_k) if RETURN_ENCODED_SOFTMAX: tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, - -p).to(encoded_softmax_block_ptr.type.element_ty), - ) + encoded_sm_ptrs, + tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, BLOCK_N)) + encoded_sm_ptrs += BLOCK_N return acc, l_i, m_i @@ -306,60 +275,26 @@ def _attn_fwd_inner( key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], ) @triton.jit -def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - HQ: tl.constexpr, - HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, -): +def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, + stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, + stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, + stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, + stride_ah, cu_seqlens_q, cu_seqlens_k, dropout_p, 
philox_seed, + philox_offset_base, encoded_softmax, alibi_slopes, + HQ: tl.constexpr, HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -389,8 +324,8 @@ def attn_fwd( # If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn - # matrix + # This captures the decrease in n_blocks if we have a rectangular + # attn matrix n_blocks_seqlen = cdiv_fn( (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only @@ -399,29 +334,26 @@ def attn_fwd( # If we have no blocks after adjusting for seqlen deltas, this WG is # part of the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[ + None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = offs_m[:, None] < seqlen_q # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. - # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. + l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + + off_h_q * MAX_SEQLENS_Q + offs_m) + # We store inf to LSE, not -inf because in the bwd pass, we subtract + # this from qk which makes it -inf, such that exp(qk - inf) = 0 for + # these masked blocks. + l = tl.full( # NOQA: E741 + [BLOCK_M], value=float("inf"), dtype=tl.float32) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) + # TODO: Should dropout & return encoded softmax be handled here too? return # If MQA / GQA, set the K and V head offsets appropriately. 
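The hunks around this point replace the `tl.make_block_ptr`/`tl.advance` bookkeeping with raw pointers and explicit masks, but the softmax recurrence inside `_attn_fwd_inner` is unchanged. As a reference for that recurrence, the following is a hedged plain-PyTorch sketch (not the Triton kernel itself) of the online-softmax update performed per `BLOCK_N` tile; the kernel folds a log2(e) factor into `qk_scale` and uses `exp2`, which is mathematically equivalent.

```python
import torch

# Hedged reference sketch (plain PyTorch, not Triton): the per-tile
# online-softmax update used by _attn_fwd_inner, on a single query row.
# m_i tracks the running max, l_i the running sum of exponentials, and
# acc the running weighted sum of V rows.
def online_softmax_av(scores: torch.Tensor, v: torch.Tensor, block_n: int):
    m_i = torch.tensor(float("-inf"))
    l_i = torch.tensor(0.0)
    acc = torch.zeros(v.shape[-1])
    for start in range(0, scores.shape[0], block_n):
        qk = scores[start:start + block_n]
        m_ij = torch.maximum(m_i, qk.max())   # new running max
        p = torch.exp(qk - m_ij)              # kernel uses exp2 with a log2(e) scale
        alpha = torch.exp(m_i - m_ij)         # rescale previously accumulated results
        acc = acc * alpha + p @ v[start:start + block_n]
        l_i = l_i * alpha + p.sum()
        m_i = m_ij
    # acc / l_i is the attention output row; m_i + log(l_i) is its log-sum-exp
    return acc / l_i, m_i + torch.log(l_i)

scores = torch.randn(10)
v = torch.randn(10, 4)
out, lse = online_softmax_av(scores, v, block_n=4)
print(torch.allclose(out, torch.softmax(scores, 0) @ v, atol=1e-6))
```

The kernel stores the base-2 analogue `m_i + log2(l_i)` into `L`, which is also why the early-exit path above writes `+inf` for rows with no valid key blocks: downstream, `exp(qk - inf) = 0` masks those rows out.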
@@ -433,71 +365,49 @@ def attn_fwd( n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N - padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. - q_offset = (off_z * stride_qz + off_h_q * stride_qh + + q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm) - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = (off_z * stride_kz + off_h_k * stride_kh + + q_ptrs = (q_offset + offs_m[:, None] * stride_qm + + offs_d[None, :] * stride_qk) + k_offset = (K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn) - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = (off_z * stride_vz + off_h_k * stride_vh + + k_ptrs = (k_offset + offs_d[:, None] * stride_kk + + offs_n[None, :] * stride_kn) + v_offset = (V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) + v_ptrs = (v_offset + offs_n[:, None] * stride_vk + + offs_d[None, :] * stride_vn) + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[ + None, :] * stride_bn + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) else: - bias_ptr = None + alibi_slope = None + if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base \ - + (off_z * HQ + off_h_q) \ - * seqlen_q * seqlen_k + off_hz = off_z * HQ + off_h_q + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k else: batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + # We can ask to return the dropout mask without actually doing dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. 
if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) + encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k + encoded_sm_ptrs = encoded_sm_base + offs_m[:, + None] * seqlen_k + offs_n[ + None, :] else: - encoded_softmax_block_ptr = 0 + encoded_sm_ptrs = None # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) @@ -506,8 +416,11 @@ def attn_fwd( # have native e^x support in HW. qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, padded_head, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + q = (q * qk_scale).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -525,8 +438,8 @@ def attn_fwd( n_full_blocks = n_blocks - masked_blocks block_min = 0 block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its - # value because there is no masking. Similarly we do not need padding. + # Compute for full blocks. Here we set causal to false unconditionally + # because there is no masking. Similarly we do not need padding. if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N acc, l_i, m_i = _attn_fwd_inner( @@ -534,21 +447,26 @@ def attn_fwd( l_i, m_i, q, - K_block_ptr, - V_block_ptr, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, start_m, seqlen_k, + seqlen_q, dropout_p, philox_seed, batch_philox_offset, - encoded_softmax_block_ptr, + encoded_sm_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ block_min, block_max, 0, 0, 0, - bias_ptr, + alibi_slope, # IS_CAUSAL, .... False, BLOCK_M, @@ -561,41 +479,45 @@ def attn_fwd( False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, - padded_head, - ) + PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) block_min = block_max block_max = n_blocks * BLOCK_N tl.debug_barrier() # Remaining blocks, if any, are full / not masked. 
- if masked_blocks > 0: + if (masked_blocks > 0): offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, n_full_blocks)) + encoded_sm_ptrs += n_full_blocks * BLOCK_N acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, q, - K_block_ptr, - V_block_ptr, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, start_m, seqlen_k, + seqlen_q, dropout_p, philox_seed, batch_philox_offset, - encoded_softmax_block_ptr, + encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, - bias_ptr, + alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, @@ -607,8 +529,8 @@ def attn_fwd( True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, - padded_head, - ) + PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) # epilogue acc = acc / l_i[:, None] if ENABLE_DROPOUT: @@ -621,45 +543,42 @@ def attn_fwd( start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 + if IS_CAUSAL: # NOQA: SIM102 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] >= - out_mask_boundary[None, :]) + out_ptrs_mask = mask_m_offsets[:, + None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. - # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few + # rows. This is only true for the last M block. For others, overflow_size + # will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), + BLOCK_M - overflow_size, + dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. 
- # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = (o_offset + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) def check_args( @@ -772,6 +691,10 @@ def forward( ) else: bias_strides = (0, 0, 0, 0) + alibi_strides = (0, 0) + M = torch.empty((batch, nheads_q, max_seqlens_q), + device=q.device, + dtype=torch.float32) attn_fwd[grid]( q, @@ -779,19 +702,21 @@ def forward( v, bias, sm_scale, - None, + M, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, + *alibi_strides, cu_seqlens_q, cu_seqlens_k, dropout_p=0.0, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, + alibi_slopes=None, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, @@ -800,7 +725,8 @@ def forward( IS_CAUSAL=causal, VARLEN=True, BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, + USE_BIAS=bias is not None, + USE_ALIBI=False, ENABLE_DROPOUT=False, RETURN_ENCODED_SOFTMAX=False, ) diff --git a/vllm/config.py b/vllm/config.py index 0a34dabf57e7c..a29049e8f6d97 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -894,11 +894,12 @@ def _verify_args(self) -> None: if self.use_ray: from vllm.executor import ray_utils ray_utils.assert_ray_available() - if is_hip(): + if not self.disable_custom_all_reduce and self.world_size > 1 and ( + self.pipeline_parallel_size) > 1: self.disable_custom_all_reduce = True logger.info( "Disabled the custom all-reduce kernel because it is not " - "supported on AMD GPUs.") + "supported with pipeline parallelism.") if self.ray_workers_use_nsight and not self.use_ray: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 0ef2cdc1aac4f..5d42d623f8337 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -12,90 +12,18 @@ from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless, is_hip +from vllm.utils import cuda_device_count_stateless try: - if is_hip(): - from amdsmi import (AmdSmiException, amdsmi_get_processor_handles, - amdsmi_init, amdsmi_shut_down, - amdsmi_topo_get_link_type) - else: - import pynvml - - @contextmanager - def _nvml(): - if torch.version.hip: - try: - amdsmi_init() - yield - finally: - amdsmi_shut_down() - else: - try: - pynvml.nvmlInit() - yield - finally: - pynvml.nvmlShutdown() - -except ImportError: - # For AMD GPUs + ops.meta_size() + custom_ar = True +except Exception: + # For CPUs custom_ar = False - pynvml = None - - @contextmanager - def _nvml(): - try: - yield - finally: - pass - logger = init_logger(__name__) -@_nvml() -def _is_full_nvlink(device_ids: List[int], world_size) -> bool: - """ - query if the set of gpus are fully connected by nvlink (1 hop) - Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`, - so it works on real physical 
device ids. - """ - if is_hip(): - # On ROCm, we instead query if GPUs are connected by 1-hop XGMI - handles = [amdsmi_get_processor_handles()[i] for i in device_ids] - for i, handle in enumerate(handles): - for j, peer_handle in enumerate(handles): - if i < j: - try: - link_type = amdsmi_topo_get_link_type( - handle, peer_handle) - # type is 2 for XGMI - if link_type["hops"] != 1 or link_type["type"] != 2: - return False - except AmdSmiException as error: - logger.error("AMD link detection failed.", - exc_info=error) - return False - else: - handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids] - for i, handle in enumerate(handles): - for j, peer_handle in enumerate(handles): - if i < j: - try: - p2p_status = pynvml.nvmlDeviceGetP2PStatus( - handle, peer_handle, - pynvml.NVML_P2P_CAPS_INDEX_NVLINK) - if p2p_status != pynvml.NVML_P2P_STATUS_OK: - return False - except pynvml.NVMLError as error: - logger.error( - "NVLink detection failed. This is normal if your" - " machine has no NVLink equipped.", - exc_info=error) - return False - return True - - def _can_p2p(rank: int, world_size: int) -> bool: for i in range(world_size): if i == rank: @@ -186,14 +114,8 @@ def __init__(self, # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - - if current_platform.is_cuda(): - from vllm.platforms.cuda import CudaPlatform - cuda_platform: CudaPlatform = current_platform - full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids, - world_size) - else: - full_nvlink = _is_full_nvlink(physical_device_ids, world_size) + assert current_platform.is_cuda() or current_platform.is_rocm() + full_nvlink = current_platform.is_full_nvlink(physical_device_ids) if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" @@ -204,7 +126,7 @@ def __init__(self, # this is expensive to compute at the first time # then we cache the result # On AMD GPU, p2p is always enabled between XGMI connected GPUs - if not is_hip() and not _can_p2p(rank, world_size): + if not current_platform.is_rocm() and not _can_p2p(rank, world_size): logger.warning( "Custom allreduce is disabled because your platform lacks " "GPU P2P capability or P2P test failed. To silence this " @@ -216,7 +138,7 @@ def __init__(self, # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate # allreduce results. 
- if is_hip(): + if current_platform.is_rocm(): # meta data buffers need to be "uncached" for signal on MI200 self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size) else: @@ -239,7 +161,7 @@ def __init__(self, self.max_size = max_size self.rank = rank self.world_size = world_size - if is_hip(): + if current_platform.is_rocm(): # _share_cuda_() doesn't accept meta buffer not allocated from # PyTorch cache allocator, use direct HIP call to get IPC handle handle = ops.get_meta_buffer_ipc_handle(self.meta) @@ -271,10 +193,10 @@ def capture(self): self.register_graph_buffers() def _get_ipc_meta(self, inp: torch.Tensor): - if is_hip(): + if current_platform.is_rocm(): # _share_cuda_() doesn't accept meta buffer not allocated from # PyTorch cache allocator, use direct HIP call to get IPC handle - handle = custom_ar.get_meta_buffer_ipc_handle(inp) + handle = ops.get_meta_buffer_ipc_handle(inp) shard_data = ( bytes(handle), # ipc handle to base ptr 0, # offset of base ptr diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f703ac4d0b302..9a28af4451801 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -285,12 +285,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: # Parallel arguments parser.add_argument( '--distributed-executor-backend', - choices=['ray', 'mp', 'torchrun'], + choices=['ray', 'mp'], default=EngineArgs.distributed_executor_backend, help='Backend to use for distributed serving. When more than 1 GPU ' 'is used, on CUDA this will be automatically set to "ray" if ' - 'installed or "mp" (multiprocessing) otherwise. On ROCm, this is ' - 'instead set to torchrun by default.') + 'installed or "mp" (multiprocessing) otherwise.') parser.add_argument( '--worker-use-ray', action='store_true', diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c441c1a1f2dfe..92c02072593e6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -513,9 +513,6 @@ def _get_executor_cls(cls, initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor - elif distributed_executor_backend == "torchrun": - from vllm.executor.torchrun_gpu_executor import TorchrunGPUExecutor - executor_class = TorchrunGPUExecutor elif distributed_executor_backend == "mp": from vllm.executor.multiproc_gpu_executor import ( MultiprocessingGPUExecutor) diff --git a/vllm/entrypoints/fast_sync_llm.py b/vllm/entrypoints/fast_sync_llm.py index 082c35077bffa..fc09f8a953c7f 100644 --- a/vllm/entrypoints/fast_sync_llm.py +++ b/vllm/entrypoints/fast_sync_llm.py @@ -2,13 +2,13 @@ from queue import Empty from typing import Union -from vllm import envs +import vllm.envs as envs from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor from vllm.executor.ray_gpu_executor import RayGPUExecutor -from vllm.inputs.data import PromptInputs, TokensPrompt +from vllm.inputs import PromptInputs, TokensPrompt from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -125,4 +125,4 @@ def run_engine(self): (output.request_id, result, stats)) except Exception as e: logger.error("Error in run_engine: %s", e) - raise e \ No newline at end of file + raise e diff --git a/vllm/entrypoints/launcher.py 
b/vllm/entrypoints/launcher.py index 3598872b65bb0..85ef537519e5b 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -6,7 +6,7 @@ import uvicorn from fastapi import FastAPI, Response -from vllm import envs +import vllm.envs as envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.protocol import AsyncEngineClient from vllm.logger import init_logger diff --git a/vllm/entrypoints/sync_openai/api_server.py b/vllm/entrypoints/sync_openai/api_server.py index 1211dbf61e0e3..4c05742d6a78d 100644 --- a/vllm/entrypoints/sync_openai/api_server.py +++ b/vllm/entrypoints/sync_openai/api_server.py @@ -5,7 +5,7 @@ import time from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import uvicorn from fastapi import FastAPI, Request @@ -15,8 +15,8 @@ from prometheus_client import make_asgi_app import vllm +import vllm.envs as envs from vllm import FastSyncLLM as LLM -from vllm import envs from vllm.config import EngineConfig from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.chat_utils import (_parse_chat_message_content, @@ -63,8 +63,8 @@ def __init__(self): self.llm: LLM self.proc: multiprocessing.Process self.tokenizer = None - self.response_role: str - self.chat_template: str + self.response_role: Optional[str] + self.chat_template: Optional[str] def set_response_role(self, role): self.response_role = role @@ -96,7 +96,8 @@ async def run_main(self): ) self.loop = asyncio.get_event_loop() - self.proc = mp.Process(target=self.llm.run_engine) + self.proc = mp.Process( # type: ignore[attr-defined] + target=self.llm.run_engine) self.t.start() self.proc.start() @@ -173,8 +174,9 @@ async def _check_model(request: Union[CompletionRequest, async def _guided_decode_logits_processor(request, tokenizer): decoding_config = runner.engine_config.decoding_config - guided_decoding_backend = request.guided_decoding_backend \ - or decoding_config.guided_decoding_backend + assert decoding_config is not None + guided_decoding_backend = (request.guided_decoding_backend + or decoding_config.guided_decoding_backend) return await get_guided_decoding_logits_processor(guided_decoding_backend, request, tokenizer) diff --git a/vllm/envs.py b/vllm/envs.py index 3ea64fdf5c185..daebb411020c6 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -444,6 +444,10 @@ def get_default_config_root(): lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + # If set, vLLM will use Triton implementations of AWQ. 
+ "VLLM_USE_TRITON_AWQ": + lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), + # Try to accumulate this many requests before proceeding "VLLM_SYNC_SERVER_ACCUM_REQUESTS": lambda: int(os.getenv("VLLM_SYNC_SERVER_ACCUM_REQUESTS", "1")), @@ -455,10 +459,6 @@ def get_default_config_root(): # Pad the weight for moe kernel or not "VLLM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "1"))), - - # If set, vllm will print verbose logs during installation - "VLLM_USE_TRITON_AWQ": - lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", '1'))), } # end-env-vars-definition diff --git a/vllm/executor/torchrun_gpu_executor.py b/vllm/executor/torchrun_gpu_executor.py deleted file mode 100644 index 506c18c11186f..0000000000000 --- a/vllm/executor/torchrun_gpu_executor.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import List, Optional, Tuple, Union - -import torch - -import vllm.envs as envs -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) -from vllm.distributed import (broadcast_object_list, - tensor_model_parallel_all_gather) -from vllm.executor.executor_base import ExecutorAsyncBase -from vllm.executor.gpu_executor import GPUExecutor -from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput -from vllm.utils import make_async - -logger = init_logger(__name__) - -# A map between the device type (in device config) to its worker module. -DEVICE_TO_WORKER_MODULE_MAP = { - "cuda": "vllm.worker.worker", - "neuron": "vllm.worker.neuron_worker", -} - - -class TorchrunGPUExecutor(GPUExecutor): - - def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig]) -> None: - self.local_rank = envs.LOCAL_RANK - self.rank = envs.RANK - self.is_driver_worker = self.rank == 0 - super().__init__(model_config, cache_config, parallel_config, - scheduler_config, device_config, load_config, - lora_config, vision_language_config, - speculative_config) - - def _init_executor(self): - self.driver_worker = self._create_worker(local_rank=self.local_rank, - rank=self.rank) - self.driver_worker.init_device() - self.driver_worker.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - num_gpu_blocks, num_cpu_blocks = ( - self.driver_worker.determine_num_available_blocks()) - t = torch.tensor( - [[num_gpu_blocks], [num_cpu_blocks]], - device="cuda", - dtype=torch.int32, - ) - output = tensor_model_parallel_all_gather(t) - return (torch.min(output[0]).item(), torch.min(output[1]).item()) - - def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> List[Union[SamplerOutput, PoolerOutput]]: - output = self.driver_worker.execute_model(execute_model_req) - if self.is_driver_worker: - broadcast_object_list([output], src=0) - else: - res = [None] - broadcast_object_list(res, src=0) - output = res[0] - return output - - -class TorchrunGPUExecutorAsync(TorchrunGPUExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest, - ) -> List[Union[SamplerOutput, PoolerOutput]]: - output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req) - if 
self.is_driver_worker: - broadcast_object_list([output], src=0) - else: - res = [None] - broadcast_object_list(res, src=0) - output = res[0] - return output - - async def check_health_async(self) -> None: - # TorchrunGPUExecutor will always be healthy as long as - # it's running. - return diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 760357b77a9f7..56f86a1bfa593 100755 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -8,7 +8,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_marlin_moe, fused_moe, fused_topk, - get_config_file_name, grouped_topk) + get_config_file_name, grouped_topk, invoke_fused_moe_kernel, + moe_align_block_size) __all__ += [ "fused_marlin_moe", diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index df2db7a061546..a15f9e08018ca 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torch.nn.modules import Module -from vllm import envs +import vllm.envs as envs from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 28906c43ef8dc..069181449fe8b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -85,7 +85,7 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x): - if x.shape[0] == 1 and x.shape[1] == 1: + if is_hip() and x.shape[0] == 1 and x.shape[1] == 1: out = torch.empty(x.shape[0], self.gate_up_proj.weight.shape[0] // 2, dtype=x.dtype, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 28525e8ff8811..d3e325d8a613d 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,8 +1,11 @@ import os -from functools import lru_cache -from typing import Tuple +from functools import lru_cache, wraps +from typing import List, Tuple import torch +from amdsmi import (AmdSmiException, amdsmi_get_gpu_board_info, + amdsmi_get_processor_handles, amdsmi_init, + amdsmi_shut_down, amdsmi_topo_get_link_type) from vllm.logger import init_logger @@ -16,6 +19,42 @@ " `spawn` instead.") os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`` +if "HIP_VISIBLE_DEVICES" in os.environ: + val = os.environ["HIP_VISIBLE_DEVICES"] + if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None): + assert val == cuda_val + else: + os.environ["CUDA_VISIBLE_DEVICES"] = val + + +# AMDSMI utils +# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. 
+# the major benefit of using AMDSMI is that it will not initialize CUDA + + +def with_nvml_context(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + amdsmi_init() + try: + return fn(*args, **kwargs) + finally: + amdsmi_shut_down() + + return wrapper + + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + class RocmPlatform(Platform): _enum = PlatformEnum.ROCM @@ -26,6 +65,36 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]: return torch.cuda.get_device_capability(device_id) @staticmethod + @with_nvml_context + def is_full_nvlink(physical_device_ids: List[int]) -> bool: + """ + query if the set of gpus are fully connected by xgmi (1 hop) + """ + # On ROCm, we instead query if GPUs are connected by 1 hop XGMI + handles = [ + amdsmi_get_processor_handles()[i] for i in physical_device_ids + ] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + link_type = amdsmi_topo_get_link_type( + handle, peer_handle) + # type is 2 for XGMI + if link_type["hops"] != 1 or link_type["type"] != 2: + return False + except AmdSmiException as error: + logger.error("AMD 1 hop XGMI detection failed.", + exc_info=error) + return False + return True + + @staticmethod + @with_nvml_context @lru_cache(maxsize=8) def get_device_name(device_id: int = 0) -> str: - return torch.cuda.get_device_name(device_id) + physical_device_id = device_id_to_physical_device_id(device_id) + handle = amdsmi_get_processor_handles()[physical_device_id] + # Note: this may not be exactly the same as the torch device name + # E.g. `AMD Instinct MI300X OAM` vs `AMD Instinct MI300X` + return amdsmi_get_gpu_board_info(handle)["product_name"] diff --git a/vllm/sequence.py b/vllm/sequence.py index 9efbe51a61d72..e7cde87f605a7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -688,7 +688,7 @@ def maybe_set_first_token_time(self, time: float) -> None: # in TPOT, rather than recalculating TTFT (since from the ) # POV of the user, there is simply a long generation delay. if (self.metrics.first_token_time is None - and next(iter(self.seqs)).get_output_len() == 1): + and self.seqs[0].get_output_len() == 1): self.metrics.first_token_time = time def maybe_set_first_scheduled_time(self, time: float) -> None: @@ -818,7 +818,7 @@ def is_finished(self) -> bool: def is_prefill(self) -> bool: # Every sequence should be in the same stage. 
- return next(iter(self.seqs)).is_prefill() + return self.seqs[0].is_prefill() def __repr__(self) -> str: return (f"SequenceGroup(request_id={self.request_id}, " diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f47cf4b24f923..7ed609c3b447c 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -444,12 +444,8 @@ def init_worker_distributed_environment( """Initialize the distributed environment.""" set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - if parallel_config.distributed_executor_backend != "torchrun": - init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank) - else: - init_distributed_environment(parallel_config.world_size, -1, "env://", - local_rank) + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size)