diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 972c62a091aea..6659440135ff4 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -71,13 +71,36 @@ mkdir -p ${HF_CACHE} HF_MOUNT="/root/.cache/huggingface" commands=$@ +echo "Commands:$commands" +#ignore certain kernels tests +if [[ $commands == *" kernels "* ]]; then + commands="${commands} \ + --ignore=kernels/test_attention.py \ + --ignore=kernels/test_attention_selector.py \ + --ignore=kernels/test_blocksparse_attention.py \ + --ignore=kernels/test_causal_conv1d.py \ + --ignore=kernels/test_cutlass.py \ + --ignore=kernels/test_encoder_decoder_attn.py \ + --ignore=kernels/test_flash_attn.py \ + --ignore=kernels/test_flashinfer.py \ + --ignore=kernels/test_int8_quant.py \ + --ignore=kernels/test_machete_gemm.py \ + --ignore=kernels/test_mamba_ssm.py \ + --ignore=kernels/test_marlin_gemm.py \ + --ignore=kernels/test_moe.py \ + --ignore=kernels/test_prefix_prefill.py \ + --ignore=kernels/test_rand.py \ + --ignore=kernels/test_sampler.py" +fi + PARALLEL_JOB_COUNT=8 # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. if [[ $commands == *"--shard-id="* ]]; then for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do #replace shard arguments - commands=${@//"--shard-id= "/"--shard-id=${GPU} "} + commands=${commands//"--shard-id= "/"--shard-id=${GPU} "} commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "} + echo "Shard ${GPU} commands:$commands" docker run \ --device /dev/kfd --device /dev/dri \ --network host \ diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh index a01cf3fe67489..49ae838cf0690 100755 --- a/.buildkite/run-cpu-test-ppc64le.sh +++ b/.buildkite/run-cpu-test-ppc64le.sh @@ -11,8 +11,9 @@ trap remove_docker_container EXIT remove_docker_container # Run the image, setting --shm-size=4g for tensor parallel. 
+source /etc/environment #docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test -docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --name cpu-test cpu-test +docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN=$HF_TOKEN --name cpu-test cpu-test # Run basic model test docker exec cpu-test bash -c " diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index ca9cf15780e25..73ce82c5857ab 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -22,13 +22,17 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " - pip install pytest matplotlib einops transformers_stream_generator - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py \ - --ignore=tests/models/test_oot_registration.py \ - --ignore=tests/models/test_registry.py \ - --ignore=tests/models/test_fp8.py \ - --ignore=tests/models/test_jamba.py \ - --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator + pytest -v -s tests/models/decoder_only/language \ + --ignore=tests/models/test_fp8.py \ + --ignore=tests/models/decoder_only/language/test_jamba.py \ + --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + +# Run compressed-tensor test +docker exec cpu-test bash -c " + pytest -s -v \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token" # online 
inference docker exec cpu-test bash -c " diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d0317b2fc48c9..9b0cb6663a55b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -50,6 +50,7 @@ steps: - tests/worker commands: - pytest -v -s async_engine # Async Engine + - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils @@ -91,7 +92,7 @@ steps: - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/openai - pytest -v -s entrypoints/test_chat_utils.py - + - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" @@ -162,15 +163,6 @@ steps: - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py -- label: Models Test # 1hr10min - source_file_dependencies: - - vllm/ - - tests/models - commands: - - pip install -e ./plugins/vllm_add_dummy_model - - pytest -v -s models/test_oot_registration.py # it needs a clean process - - pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py - - label: torch compile integration test source_file_dependencies: - vllm/ @@ -178,14 +170,6 @@ steps: - pytest -v -s ./compile/test_full_graph.py - pytest -v -s ./compile/test_wrapper.py - -- label: Vision Language Models Test # 42min - #mirror_hardwares: [amd] - source_file_dependencies: - - vllm/ - commands: - - pytest -v -s models -m vlm - - label: Prefix Caching Test # 7min #mirror_hardwares: [amd] source_file_dependencies: @@ -217,7 +201,8 @@ steps: commands: # See 
https://github.com/vllm-project/vllm/issues/5152 - export VLLM_ATTENTION_BACKEND=XFORMERS - - pytest -v -s spec_decode + - pytest -v -s spec_decode/e2e/test_multistep_correctness.py + - pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py - label: LoRA Test %N # 30min each mirror_hardwares: [amd] @@ -228,6 +213,7 @@ steps: parallelism: 4 - label: Kernels Test %N # 30min each + mirror_hardwares: [amd] source_file_dependencies: - csrc/ - vllm/attention @@ -282,6 +268,45 @@ steps: commands: - pytest -v -s tool_use +##### models test ##### + +- label: Basic Models Test # 3min + source_file_dependencies: + - vllm/ + - tests/models + commands: + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s models/test_oot_registration.py # it needs a clean process + - pytest -v -s models/*.py --ignore=models/test_oot_registration.py + +- label: Decoder-only Language Models Test # 1h3min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/language + commands: + - pytest -v -s models/decoder_only/language + +- label: Decoder-only Multi-Modal Models Test # 56min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/audio_language + - tests/models/decoder_only/vision_language + commands: + - pytest -v -s models/decoder_only/audio_language + - pytest -v -s models/decoder_only/vision_language + +- label: Other Models Test # 5min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/embedding/language + - tests/models/encoder_decoder/language + commands: + - pytest -v -s models/embedding/language + - pytest -v -s models/encoder_decoder/language + ##### 1 GPU test ##### ##### multi gpus test ##### @@ -307,11 +332,11 @@ steps: - tests/distributed/ commands: - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 
--rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' - label: Distributed Tests (2 GPUs) # 28min #mirror_hardwares: [amd] @@ -324,11 +349,10 @@ steps: - vllm/model_executor/models/ - tests/distributed/ commands: - - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py - - pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py - - pytest -v -s distributed/test_chunked_prefill_distributed.py - - pytest -v -s distributed/test_multimodal_broadcast.py + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus + # Avoid importing model tests that cause CUDA reinitialization error + - pytest models/encoder_decoder/language/test_bart.py models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py @@ -386,7 +410,18 @@ steps: - vllm/ - tests/weight_loading commands: 
- - bash weight_loading/run_model_weight_loading_test.sh + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU Test - Large Models # optional + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + gpu: a100 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt ##### multi gpus test ##### diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml index d4113da8b5b81..30db1721a9df7 100644 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug report.yml @@ -30,6 +30,15 @@ body: validations: required: true +- type: textarea + attributes: + label: Model Input Dumps + description: | + If you are facing crashing due to illegal memory access or other issues with model execution, vLLM may dump the problematic input of the model. In this case, you will see the message `Error in model execution (input dumped to /tmp/err_xxx.pkl)`. If you see this message, please zip the file (because GitHub doesn't support .pkl file format) and upload it here. This will help us to reproduce the issue and facilitate the debugging process. + placeholder: | + Upload the dumped input file. + validations: + required: false - type: textarea attributes: label: 🐛 Describe the bug diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 262ce8e1530a8..be0afc6305044 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -39,6 +39,16 @@ FIX #xxxx (*link existing issues this PR will resolve*)
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM users understand and use the new features or changes.
  • +

    Adding or changing kernels

    +

    Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

    + +

    Notes for Large Changes

    Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and might not review it.

    diff --git a/CMakeLists.txt b/CMakeLists.txt index 71f160acc4dcc..7a0fa967155bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -208,9 +208,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - # CUTLASS 3.5.1 - GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 + GIT_TAG v3.5.1 GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. + # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE ) FetchContent_MakeAvailable(cutlass) @@ -244,6 +248,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "-gencode arch=compute_90a,code=sm_90a") endif() + # # Machete kernels @@ -307,28 +312,11 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -if(VLLM_GPU_LANG STREQUAL "HIP") - # - # custom extension - # - set(CUSTOM_SRC - "csrc/custom/torch_bindings.cpp" - "csrc/custom/custom_kernels.cu" - "csrc/custom/fused_kernels.cu" - "csrc/custom/custom.cu" - "csrc/custom/paged_attention/attention_ll4mi.cu" - ) - - define_gpu_extension_target( - _custom_C - DESTINATION vllm - LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${CUSTOM_SRC} - COMPILE_FLAGS ${VLLM_GPU_FLAGS} - ARCHITECTURES ${VLLM_GPU_ARCHES} - USE_SABI 3 - WITH_SOABI) -endif() +# If CUTLASS is compiled on NVCC >= 12.5, it by default uses +# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the +# driver API. This causes problems when linking with earlier versions of CUDA. +# Setting this variable sidesteps the issue by calling the driver directly. 
+target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) # # _moe_C extension @@ -354,6 +342,28 @@ define_gpu_extension_target( WITH_SOABI) +if(VLLM_GPU_LANG STREQUAL "HIP") + # + # _rocm_C extension + # + set(VLLM_ROCM_EXT_SRC + "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/attention.cu" + "csrc/rocm/custom_kernels.cu" + "csrc/rocm/fused_kernels.cu" + "csrc/rocm/custom.cu") + + define_gpu_extension_target( + _rocm_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_ROCM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) +endif() + if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") @@ -364,6 +374,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") endif() if(VLLM_GPU_LANG STREQUAL "HIP") - message(STATUS "Enabling custom extension.") - add_dependencies(default _custom_C) + message(STATUS "Enabling rocm extension.") + add_dependencies(default _rocm_C) endif() diff --git a/Dockerfile b/Dockerfile index 0ec6655ed449e..5484be5bc5785 100644 --- a/Dockerfile +++ b/Dockerfile @@ -145,6 +145,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ && apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ && add-apt-repository ppa:deadsnakes/ppa \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \ diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 9a570f988f3db..34b4c95e34ffc 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -2,9 +2,14 @@ FROM ubuntu:22.04 AS cpu-test-1 +ENV CCACHE_DIR=/root/.cache/ccache + +ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache + RUN --mount=type=cache,target=/var/cache/apt \ apt-get update 
-y \ && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html @@ -25,6 +30,19 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip install --upgrade pip && \ pip install -r requirements-build.txt +# install oneDNN +RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git + +RUN --mount=type=cache,target=/root/.cache/ccache \ + cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \ + -DONEDNN_BUILD_DOC=OFF \ + -DONEDNN_BUILD_EXAMPLES=OFF \ + -DONEDNN_BUILD_TESTS=OFF \ + -DONEDNN_BUILD_GRAPH=OFF \ + -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ + -DONEDNN_ENABLE_PRIMITIVE=MATMUL && \ + cmake --build ./oneDNN/build --target install --config Release + FROM cpu-test-1 AS build WORKDIR /workspace/vllm @@ -40,7 +58,6 @@ COPY ./ ./ ARG VLLM_CPU_DISABLE_AVX512 ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} -ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/ccache \ VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ diff --git a/Dockerfile.neuron b/Dockerfile.neuron index caa1b1d6c4424..f0c3479625a70 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -6,7 +6,9 @@ FROM $BASE_IMAGE RUN echo "Base image is $BASE_IMAGE" # Install some basic utilities -RUN apt-get update && apt-get install python3 python3-pip -y +RUN apt-get update \ + && apt-get install python3 python3-pip -y \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 ### Mount Point ### # When launching the container, mount the code directory to /app diff --git a/Dockerfile.openvino b/Dockerfile.openvino index 06ca4638dfeb9..96b9593a2bfa8 100644 --- 
a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -4,7 +4,8 @@ FROM ubuntu:22.04 AS dev RUN apt-get update -y && \ - apt-get install -y python3-pip git + apt-get install -y python3-pip git && \ + apt-get install -y ffmpeg libsm6 libxext6 libgl1 WORKDIR /workspace # copy requirements diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index 16780f8ab950c..3313162bf28e1 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -4,7 +4,7 @@ USER root ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/" -RUN apt-get update -y && apt-get install -y git wget vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential +RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 # Some packages in requirements-cpu are installed here # IBM provides optimized packages for ppc64le processors in the open-ce project for mamba @@ -16,7 +16,7 @@ COPY ./ /workspace/vllm WORKDIR /workspace/vllm # These packages will be in rocketce eventually -RUN pip install -v cmake torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing +RUN pip install -v cmake xformers torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install @@ -25,4 +25,3 @@ WORKDIR /workspace/ RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] - diff --git a/Dockerfile.tpu b/Dockerfile.tpu index 3a11c6721ead9..04cd4d79f4045 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -4,6 +4,9 @@ ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:night FROM $BASE_IMAGE WORKDIR /workspace +# Install some basic utilities +RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 libgl1 + # Install the TPU and Pallas 
dependencies. RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html diff --git a/Dockerfile.xpu b/Dockerfile.xpu index f91baa11a3753..50bbd8f7dad87 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -1,15 +1,22 @@ -FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu20.04 +FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \ echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \ - rm /etc/apt/sources.list.d/intel-graphics.list && \ wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \ echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ chmod 644 /usr/share/keyrings/intel-graphics.gpg RUN apt-get update -y \ -&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip +&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 + +RUN git clone https://github.com/intel/pti-gpu && \ + cd pti-gpu/sdk && \ + mkdir build && \ + cd build && \ + cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \ + make -j && \ + cmake --install . 
--config Release --prefix "/usr/local" COPY ./ /workspace/vllm diff --git a/README.md b/README.md index 9ae30f8d2de55..53749cb36b972 100644 --- a/README.md +++ b/README.md @@ -17,15 +17,16 @@ Easy, fast, and cheap LLM serving for everyone --- -**vLLM & NVIDIA Triton User Meetup (Monday, September 9, 5pm-9pm PT) at Fort Mason, San Francisco** +**vLLM, AMD, Anyscale Meet & Greet at [Ray Summit 2024](http://raysummit.anyscale.com) (Monday, Sept 30th, 5-7pm PT) at Marriott Marquis San Francisco** -We are excited to announce our sixth vLLM Meetup, in collaboration with NVIDIA Triton Team. -Join us to hear the vLLM's recent update about performance. -Register now [here](https://lu.ma/87q3nvnh) and be part of the event! +We are excited to announce our special vLLM event in collaboration with AMD and Anyscale. +Join us to learn more about recent advancements of vLLM on MI300X. +Register [here](https://lu.ma/db5ld9n5) and be a part of the event! --- *Latest News* 🔥 +- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing). - [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing). - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). - [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). 
@@ -130,3 +131,10 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs year={2023} } ``` + +## Contact Us + +* For technical questions and feature requests, please use Github issues or discussions. +* For discussing with fellow users, please use Discord. +* For security disclosures, please use Github's security advisory feature. +* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. \ No newline at end of file diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 97afd301c8f24..a39d1cf842f06 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -10,7 +10,7 @@ from tqdm import tqdm from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs +from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs from vllm.inputs import PromptInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -205,13 +205,11 @@ def run_to_completion(profile_dir: Optional[str] = None): default=None, help=('path to save the pytorch profiler output. 
Can be visualized ' 'with ui.perfetto.dev or Tensorboard.')) - parser.add_argument( - "--device", - type=str, - default="auto", - choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA, OpenVINO and ' - 'CPU.') + parser.add_argument("--device", + type=str, + default="auto", + choices=DEVICE_OPTIONS, + help='device type for vLLM execution') parser.add_argument('--block-size', type=int, default=16, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 94549d84fb4e4..3f531ee82cc94 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -11,7 +11,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -451,13 +451,11 @@ def main(args: argparse.Namespace): 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' 'cuda version greater than 11.8. 
On ROCm (AMD GPU), FP8_E4M3 is ' 'instead supported for common inference criteria.') - parser.add_argument( - "--device", - type=str, - default="auto", - choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA, OpenVINO and ' - 'CPU.') + parser.add_argument("--device", + type=str, + default="auto", + choices=DEVICE_OPTIONS, + help='device type for vLLM execution') parser.add_argument( "--num-scheduler-steps", type=int, diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index ce0d9db3068c1..b0c23fee5b373 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -145,7 +145,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: v_scale, ) else: - ops.paged_attention_custom( + ops.paged_attention_rocm( output, exp_sums, max_logits, @@ -161,6 +161,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, + v_scale, ) else: raise ValueError(f"Invalid version: {version}") diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 3ba3a2b6a93cd..8470e9ea9ebd9 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -1,4 +1,5 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_CXX_STANDARD 17) # # Define environment variables for special configurations @@ -83,12 +84,7 @@ endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") -list(APPEND LIBS "numa") - - -# -# Define extension targets -# +list(APPEND LIBS dnnl numa) # # _C extension @@ -102,6 +98,16 @@ set(VLLM_EXT_SRC "csrc/cpu/pos_encoding.cpp" "csrc/cpu/torch_bindings.cpp") +if (AVX512_FOUND AND NOT AVX512_DISABLED) + set(VLLM_EXT_SRC + "csrc/cpu/quant.cpp" + ${VLLM_EXT_SRC}) +endif() + +# +# Define extension targets +# + define_gpu_extension_target( _C DESTINATION vllm diff --git a/cmake/utils.cmake 
b/cmake/utils.cmake index 69998b45be70a..1ea6d2b0f090e 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -350,6 +350,7 @@ function (define_gpu_extension_target GPU_MOD_NAME) target_include_directories(${GPU_MOD_NAME} PRIVATE csrc ${GPU_INCLUDE_DIRECTORIES}) + # TODO: is torch_python_LIBRARY needed? target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY} ${GPU_LIBRARIES}) diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index f50620a5287d4..5b1d3d6442b2b 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -24,8 +24,8 @@ namespace vec_op { #define CPU_KERNEL_GUARD_OUT(NAME) #else #define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; + RECORD_FUNCTION(#NAME, c10::ArrayRef({})); +#define CPU_KERNEL_GUARD_OUT(NAME) #endif #define FORCE_INLINE __attribute__((always_inline)) inline @@ -106,6 +106,12 @@ struct BF16Vec16 : public Vec { explicit BF16Vec16(const FP32Vec16 &); void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } + + void save(void* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm256_mask_storeu_epi16(ptr, mask, reg); + } }; #ifdef __AVX512F__ @@ -313,8 +319,28 @@ struct FP32Vec16 : public Vec { return FP32Vec16(_mm512_div_ps(reg, b.reg)); } + FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { + return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg))); + } + + FP32Vec16 max(const FP32Vec16& b) const { + return FP32Vec16(_mm512_max_ps(reg, b.reg)); + } + + FP32Vec16 max(const FP32Vec16& b, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg)); + } + + FP32Vec16 abs() const { + return FP32Vec16(_mm512_abs_ps(reg)); + } + 
float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + float reduce_max() const { return _mm512_reduce_max_ps(reg); } + template float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); @@ -323,6 +349,12 @@ struct FP32Vec16 : public Vec { } void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } + + void save(float* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm512_mask_storeu_ps(ptr, mask, reg); + } }; #else struct FP32Vec16 : public Vec { @@ -433,6 +465,32 @@ struct FP32Vec16 : public Vec { }; #endif +#ifdef __AVX512F__ +struct INT8Vec16: public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m128i reg; + int8_t values[VEC_ELEM_NUM]; + }; + + __m128i reg; + + explicit INT8Vec16(const FP32Vec16& vec) : reg( + _mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) + ) {} + + void save(int8_t* ptr) const { + _mm_storeu_epi8(ptr, reg); + } + + void save(int8_t* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm_mask_storeu_epi8(ptr, mask, reg); + } +}; +#endif + template struct VecType { using vec_type = void; }; template using vec_t = typename VecType::vec_type; diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp new file mode 100644 index 0000000000000..024ad4ae43da8 --- /dev/null +++ b/csrc/cpu/dnnl_helper.hpp @@ -0,0 +1,168 @@ +#ifndef DNNL_HELPER_HPP +#define DNNL_HELPER_HPP + +#include + +#include "oneapi/dnnl/dnnl.hpp" + +namespace { +template +struct DNNLType { + static constexpr dnnl::memory::data_type type = + dnnl::memory::data_type::undef; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; +}; + +template <> +struct DNNLType { + static constexpr 
dnnl::memory::data_type type = dnnl::memory::data_type::s32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; +}; + +template +constexpr inline dnnl::memory::data_type get_dnnl_type() { + return DNNLType>::type; +} +}; // namespace + +template +class DNNLPrimitiveHelper { + public: + // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) + // A: [M, K], row-major + // B: [K, N], column-major + // C: [M, N], row-major + // bias: [N], row-major, optional + // a_scales: [MS] + // b_scales: [NS] + // Note: Due to the limitation of oneDNN + // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is + // not supported. + template + static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, + const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, + dnnl_dim_t K, const float* a_scales, + const float* b_scales, dnnl_dim_t MS, + dnnl_dim_t NS) { + auto&& OutputType = get_dnnl_type(); + auto&& BiasType = get_dnnl_type(); + + dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); + dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); + dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); + + dnnl::primitive_attr attr; + if constexpr (!InputNoScale) { + if (MS == 1) { + // per-tensor + attr.set_scales_mask(DNNL_ARG_SRC, 0); + } else { + // per-token + TORCH_CHECK(false, "per-token quantization is unsupported."); + } + } + + if (NS == 1) { + // per-tensor + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + } else { + // per-channel + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); + } + + dnnl::matmul::primitive_desc matmul_pd; + if (bias) { + dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); + matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, + bias_md, c_md, attr); + } else { + matmul_pd = 
dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, + c_md, attr); + } + dnnl::matmul matmul(matmul_pd); + + auto& engine = default_engine(); + + dnnl::memory a_m(a_md, engine, (void*)a); + dnnl::memory b_m(b_md, engine, (void*)b); + dnnl::memory c_m(c_md, engine, (void*)c); + dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, + (void*)a_scales); + dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, + (void*)b_scales); + + auto& stream = default_stream(); + if constexpr (InputNoScale) { + if (bias) { + dnnl::memory::desc bias_md({N}, BiasType, {1}); + dnnl::memory bias_m(bias_md, engine, (void*)bias); + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_BIAS, bias_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } else { + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } + } else { + if (bias) { + dnnl::memory::desc bias_md({N}, BiasType, {1}); + dnnl::memory bias_m(bias_md, engine, (void*)bias); + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_BIAS, bias_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } else { + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } + } + stream.wait(); + } + + private: + static dnnl::engine& default_engine() { + static dnnl::engine engine(dnnl::engine::kind::cpu, 0); + return engine; + } + + static dnnl::stream& default_stream() { + static dnnl::stream stream(default_engine()); + return stream; + } +}; + +#endif diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp new file mode 
100644 index 0000000000000..0cfc19097fded --- /dev/null +++ b/csrc/cpu/quant.cpp @@ -0,0 +1,294 @@ +#include "cpu_types.hpp" +#include "dnnl_helper.hpp" + +namespace { +template +struct KernelVecType { + using load_vec_type = void; + using cvt_vec_type = void; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::FP32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::BF16Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +#ifdef __AVX512F__ +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + + if (j + vec_elem_num == hidden_size) { + elems_int8.save(output + i * hidden_size + j); + } else { + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } + } +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + 
float* scale, const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t max_abs(0.0); + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + max_abs = max_abs.max(elems_fp32.abs()); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == hidden_size) { + max_abs = max_abs.max(elems_fp32.abs()); + } else { + max_abs = max_abs.max(elems_fp32.abs(), hidden_size - j); + } + } + + float scale_val = max_abs.reduce_max() / 127.0f; + scale[i] = scale_val; + const cvt_vec_t inv_scale(1.0 / scale_val); + + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + vec_op::INT8Vec16 elems_int8(elems_fp32); + + if (j + vec_elem_num == hidden_size) { + elems_int8.save(output + i * hidden_size + j); + } else { + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } + } + } +} + +template +void dynamic_output_scale_impl(const float* input, scalar_t* output, + const float* scale, const scalar_t* bias, + const int num_tokens, const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int 
i = 0; i < num_tokens; ++i) { + int j = 0; + cvt_vec_t token_scale_vec(scale[i]); + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + + if (j + vec_elem_num == hidden_size) { + elems_out.save(output + i * hidden_size + j); + } else { + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } + } +} +#else +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") +} + +template +void dynamic_output_scale_impl() { + TORCH_CHECK(false, "dynamic_output_scale_impl requires AVX512 support.") +} +#endif +} // namespace + +void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, // [1] or [M] + const torch::Tensor& b_scales, // [1] or [OC] + const c10::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == 
torch::kInt8, + "int8_scaled_mm only supports INT8 inputs.") + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && + bias->dim() == 1); + } + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "cutlass_scaled_mm", [&] { + if (a_scales.numel() != 1) { + // per-token + // Note: oneDNN doesn't support per-token activation quantization + torch::Tensor tmp_fp32_out = + torch::empty_like(c, ::at::ScalarType::Float); + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), (void*)(0), a.size(0), b.size(1), + a.size(1), (float*)(0), b_scales.data_ptr(), 0, + b_scales.numel()); + if (bias.has_value()) { + dynamic_output_scale_impl( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), bias->data_ptr(), c.size(0), + c.size(1)); + } else { + dynamic_output_scale_impl( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), (scalar_t*)(0), c.size(0), c.size(1)); + } + } else { + // per-tensor + if (bias.has_value()) { + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), c.data_ptr(), + bias->data_ptr(), a.size(0), b.size(1), a.size(1), + a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } else { + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), c.data_ptr(), + (void*)(0), a.size(0), b.size(1), a.size(1), + a_scales.data_ptr(), 
b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } + } + }); +} + +// static-per-tensor quantization. +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + const torch::Tensor& input, // [..., hidden_size] + const torch::Tensor& scale) { + CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scale.numel() == 1); + + const int hidden_size = input.size(-1); + const int num_tokens = input.numel() / hidden_size; + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_impl", [&] { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), num_tokens, hidden_size); + }); +} + +// dynamic-per-token quantization. +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [..., hidden_size] + const torch::Tensor& input, // [..., hidden_size] + torch::Tensor& scale // [..., 1] +) { + CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), num_tokens, hidden_size); + }); +} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index cf7d977da7c1c..b45da1b386b5b 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -4,7 +4,12 @@ #include -void init_cpu_threads_env(const std::string& cpu_ids); +std::string init_cpu_threads_env(const std::string& cpu_ids); + +void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const c10::optional& bias); TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -27,8 +32,8 @@ 
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // PagedAttention V2. ops.def( "paged_attention_v2(" - " Tensor! out, Tensor exp_sums, Tensor max_logits," - " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor! out, Tensor! exp_sums, Tensor! max_logits," + " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," @@ -84,6 +89,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! key, int head_size," " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); + + // Quantization +#ifdef __AVX512F__ + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " + "()"); + ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCPU, + &dynamic_scaled_int8_quant); + // W8A8 GEMM, supporting symmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { @@ -95,8 +122,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Copy the cache blocks from src to dst. cache_ops.def( - "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " - "block_mapping) -> ()"); + "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) 
value_caches, " + "Tensor block_mapping) -> ()"); cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); // Reshape the key and value tensors and cache them. @@ -111,7 +138,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { // CPU utils - utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env); + utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index 5782580baa861..1138a55df2f05 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -5,7 +5,7 @@ #include "cpu_types.hpp" -void init_cpu_threads_env(const std::string& cpu_ids) { +std::string init_cpu_threads_env(const std::string& cpu_ids) { bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); TORCH_CHECK(omp_cpu_mask->size > 0); std::vector omp_cpu_ids; @@ -51,15 +51,40 @@ void init_cpu_threads_env(const std::string& cpu_ids) { torch::set_num_threads((int)omp_cpu_ids.size()); TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads()); TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads()); + + std::vector> thread_core_mapping; + thread_core_mapping.reserve(omp_cpu_ids.size()); + omp_lock_t writelock; + omp_init_lock(&writelock); + #pragma omp parallel for schedule(static, 1) for (size_t i = 0; i < omp_cpu_ids.size(); ++i) { - cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size); - size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size); - CPU_ZERO_S(size, mask); - CPU_SET_S(omp_cpu_ids[i], size, mask); - sched_setaffinity(0, sizeof(cpu_set_t), mask); - CPU_FREE(mask); + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(omp_cpu_ids[i], &mask); + int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask); + if (ret == -1) { + TORCH_CHECK(false, + "sched_setaffinity failed. 
errno: " + std::to_string(errno)); + } + + omp_set_lock(&writelock); + thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]); + omp_unset_lock(&writelock); } + omp_destroy_lock(&writelock); + numa_free_nodemask(omp_cpu_mask); + + std::stringstream ss; + ss << "OMP threads binding of Process " << getpid() << ":\n"; + std::sort(thread_core_mapping.begin(), thread_core_mapping.end(), + [](auto&& a, auto&& b) { return a.second < b.second; }); + for (auto&& item : thread_core_mapping) { + ss << "\t" + << "OMP tid: " << item.first << ", core " << item.second << "\n"; + } + + return ss.str(); } diff --git a/csrc/custom/custom_ops.h b/csrc/custom/custom_ops.h deleted file mode 100644 index f8ab5ee5544df..0000000000000 --- a/csrc/custom/custom_ops.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once -#include - -void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t rows_per_block); - -void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t rows_per_block); - -void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount); - -void paged_attention_custom(torch::Tensor& out, torch::Tensor& exp_sums, - torch::Tensor& max_logits, torch::Tensor& tmp_out, - torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, - double scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, int64_t block_size, - int64_t max_context_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, - double v_scale); diff --git a/csrc/custom/torch_bindings.cpp b/csrc/custom/torch_bindings.cpp deleted file mode 100644 index dc26ac5e57204..0000000000000 --- a/csrc/custom/torch_bindings.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include "core/registration.h" -#include "custom_ops.h" - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, custom_ops) { - custom_ops.def( - "LLMM1(Tensor in_a, Tensor in_b, Tensor! 
out_c, int rows_per_block) -> " - "()"); - custom_ops.impl("LLMM1", torch::kCUDA, &LLMM1); - custom_ops.def( - "LLMM_Silu(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) " - "-> ()"); - custom_ops.impl("LLMM_Silu", torch::kCUDA, &LLMM_Silu); - custom_ops.def( - "paged_attention_custom(Tensor! out, Tensor exp_sums," - " Tensor max_logits, Tensor tmp_out," - " Tensor query, Tensor key_cache," - " Tensor value_cache, int num_kv_heads," - " float scale, Tensor block_tables," - " Tensor context_lens, int block_size," - " int max_context_len," - " Tensor? alibi_slopes," - " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); - custom_ops.impl("paged_attention_custom", torch::kCUDA, - &paged_attention_custom); - custom_ops.def( - "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," - " int CuCount) -> ()"); - custom_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); -} -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 1e170e80d2f70..92184f43c9eb0 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1737,4 +1737,4 @@ torch::Tensor marlin_gemm_moe( moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; -} \ No newline at end of file +} diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 01ba8ff69850d..43d264e0770d6 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -9,4 +9,4 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights); \ No newline at end of file + bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d4d43e2c601b5..8a0e625b43fa1 100644 
--- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " "bool replicate_input, bool apply_weights) -> Tensor"); - m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 10f337f98ccbc..dab0b8c6dbf2f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -65,10 +65,21 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); -void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables); +void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, + int64_t block_size, torch::Tensor& input_tokens, + torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, + torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, + torch::Tensor& block_tables); + +void advance_step_flashinfer( + int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables, + torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, + torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); #ifndef USE_ROCM torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, @@ -134,9 +145,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); +torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, + torch::Tensor& perm, c10::SymInt size_k, + c10::SymInt size_n, 
int64_t num_bits); + torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits); +torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight, + c10::SymInt size_k, c10::SymInt size_n, + int64_t num_bits); + torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n); diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index 0e537ddd6c4cd..a9d08ca0dc14c 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -12,13 +12,11 @@ namespace prepare_inputs { // template -__global__ void advance_step_kernel(int num_seqs, int num_queries, - int block_size, long* input_tokens_ptr, - long const* sampled_token_ids_ptr, - long* input_positions_ptr, - int* seq_lens_ptr, long* slot_mapping_ptr, - int const* block_tables_ptr, - int64_t const block_tables_stride) { +__global__ void advance_step_flashattn_kernel( + int num_seqs, int num_queries, int block_size, long* input_tokens_ptr, + long const* sampled_token_ids_ptr, long* input_positions_ptr, + int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, + int64_t const block_tables_stride) { int num_query_blocks = div_ceil(num_queries, num_threads); if (blockIdx.x >= num_query_blocks) { @@ -79,16 +77,91 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t, } } -void advance_step(int num_seqs, int num_queries, int block_size, - torch::Tensor& input_tokens, // type: long - torch::Tensor& sampled_token_ids, // type: long - torch::Tensor& input_positions, // type: long - torch::Tensor& seq_lens, // type: int - torch::Tensor& slot_mapping, // type: long - torch::Tensor& block_tables) { // type: int +__global__ void advance_step_flashinfer_kernel( + int num_threads, int num_seqs, int num_queries, int block_size, + long* input_tokens_ptr, long const* sampled_token_ids_ptr, + long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, + int const* 
block_tables_ptr, int64_t const block_tables_stride, + int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) { + int num_query_blocks = div_ceil(num_queries, num_threads); + + if (blockIdx.x < num_query_blocks) { + int cur_query_id = blockIdx.x * num_threads + threadIdx.x; + + if (cur_query_id < num_queries) { + // Update input_tokens + input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; + + int seq_len = seq_lens_ptr[cur_query_id]; + int next_seq_len = seq_len + 1; + int next_input_pos = next_seq_len - 1; + + // Update seq_lens + seq_lens_ptr[cur_query_id] = next_seq_len; + // Update input_positions + input_positions_ptr[cur_query_id] = next_input_pos; + + int const* seq_block_tables_ptr = + block_tables_ptr + block_tables_stride * cur_query_id; + + int block_index = next_input_pos / block_size; + int block_offset = next_input_pos % block_size; + + // Update paged_kv_last_page_len + paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1; + + int slot_num = + seq_block_tables_ptr[block_index] * block_size + block_offset; + // Update slot_mapping + slot_mapping_ptr[cur_query_id] = slot_num; + block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size); + } + } +} + +__global__ void advance_step_flashinfer_indptr_kernel( + int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr, + int* block_table_bound_ptr) { + int idx = blockIdx.x * num_threads + threadIdx.x; + + // Update paged_kv_indptr + if (idx < num_queries) { + int sum = 0; + for (int i = 0; i <= idx; ++i) { + sum += block_table_bound_ptr[i]; + } + paged_kv_indptr_ptr[idx + 1] = sum; + } +} + +__global__ void advance_step_flashinfer_indices_kernel( + int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr, + int64_t const block_tables_stride, int* paged_kv_indices_ptr, + int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { + int idx = blockIdx.x * num_threads + threadIdx.x; + int row = idx / block_tables_stride; + int col = idx 
% block_tables_stride; + + if (row < num_queries && col < block_table_bound_ptr[row]) { + paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] = + block_tables_ptr[row * block_tables_stride + col]; + } + // if cudagraph, fill padded seqs with the last valid seq's indptr + if (num_queries < row && row <= num_seqs) { + paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries]; + } +} + +void advance_step_flashattn(int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables) { // type: int if (logging) { - printf("advance_step:\n"); + printf("advance_step_flashattn:\n"); printf(" num_seqs = %d\n", num_seqs); printf(" num_queries = %d\n", num_queries); printf(" block_size = %d\n", block_size); @@ -108,24 +181,126 @@ void advance_step(int num_seqs, int num_queries, int block_size, int blocks; cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - advance_step_kernel<<>>( - num_seqs, num_queries, block_size, + advance_step_flashattn_kernel + <<>>( + num_seqs, num_queries, block_size, + reinterpret_cast(input_tokens.data_ptr()), + reinterpret_cast(sampled_token_ids.data_ptr()), + reinterpret_cast(input_positions.data_ptr()), + reinterpret_cast(seq_lens.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0)); +} + +void advance_step_flashinfer( + int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables, // type: int + torch::Tensor& paged_kv_indices, // type: int + torch::Tensor& paged_kv_indptr, // 
type: int + torch::Tensor& paged_kv_last_page_len, // type: int + torch::Tensor& block_table_bound) { // type: int + + if (logging) { + printf("advance_step_flashinfer:\n"); + printf(" num_seqs = %d\n", num_seqs); + printf(" num_queries = %d\n", num_queries); + printf(" block_size = %d\n", block_size); + printf(" block_tables.stride(0) = %d\n", block_tables.stride(0)); + } + // Verify all tensors + verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); + // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, + // at::kLong); + verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); + verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); + verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); + verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); + + verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt); + verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt); + verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1, + at::kInt); + + verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt); + + int dev = sampled_token_ids.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + + int blocks; + int threads; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); + if (logging) { + printf("launching kernel with %d blocks\n", blocks); + } + + // TODO(will): support arbitrary block_tables stride + if ((blocks * threads) / block_tables.stride(0) < num_queries) { + TORCH_CHECK(false, + "multi-step: not enough threads to map block_table to" + "FlashInfer's paged_kv_indices on GPU. 
Try reducing the number " + "of seqs,", + " increasing the block size or take smaller steps.", + " num_queries = ", num_queries, + " block_tables.stride(0) = ", block_tables.stride(0), + " blocks = ", blocks, " max_threads = ", threads); + } + + advance_step_flashinfer_kernel<<>>( + threads, num_seqs, num_queries, block_size, reinterpret_cast(input_tokens.data_ptr()), reinterpret_cast(sampled_token_ids.data_ptr()), reinterpret_cast(input_positions.data_ptr()), reinterpret_cast(seq_lens.data_ptr()), reinterpret_cast(slot_mapping.data_ptr()), reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0)); + block_tables.stride(0), + reinterpret_cast(paged_kv_last_page_len.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); + + advance_step_flashinfer_indptr_kernel<<>>( + threads, num_seqs, num_queries, + reinterpret_cast(paged_kv_indptr.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); + + advance_step_flashinfer_indices_kernel<<>>( + threads, num_seqs, num_queries, + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0), + reinterpret_cast(paged_kv_indices.data_ptr()), + reinterpret_cast(paged_kv_indptr.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); } } // namespace prepare_inputs -void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables) { - prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens, - sampled_token_ids, input_positions, seq_lens, - slot_mapping, block_tables); +void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, + int64_t block_size, torch::Tensor& input_tokens, + torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, + torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, + torch::Tensor& block_tables) { + 
prepare_inputs::advance_step_flashattn( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables); +} + +void advance_step_flashinfer( + int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables, + torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, + torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) { + prepare_inputs::advance_step_flashinfer( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices, + paged_kv_indptr, paged_kv_last_page_len, block_table_bound); } \ No newline at end of file diff --git a/csrc/quantization/gguf/dequantize.cuh b/csrc/quantization/gguf/dequantize.cuh index 2069fba759ea0..c012262e49015 100644 --- a/csrc/quantization/gguf/dequantize.cuh +++ b/csrc/quantization/gguf/dequantize.cuh @@ -353,18 +353,47 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq1_s * x = (const block_iq1_s *) vx; - const int tid = threadIdx.x; - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const float delta = x[i].qh[ib] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA; + const float d = __half2float(x[i].d) * (2*((x[i].qh[ib] >> 12) & 7) + 1); + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + y[j] = __float2half(d * (q[j] + delta)); + } +} + +template +static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq1_m * x = (const block_iq1_m *) vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; - const int i8 = 4*ib+il; - uint8_t h = x[i].scales[i8/2] >> 4*(i8%2); - const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5))); - const float d = __half2float(x[i].d) * (2*(h & 7) + 1); - for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j]); + const uint16_t * sc = (const uint16_t *)x[i].scales; + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4); + const float d = __half2float(scale.f16) * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1); + const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? 
-1 - IQ1M_DELTA : -1 + IQ1M_DELTA; + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + y[j] = __float2half(d * (q[j] + delta)); + } } template @@ -475,6 +504,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c dequantize_block_iq1_s<<>>(vx, y); } +template +static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq1_m<<>>(vx, y); +} + template static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; @@ -525,6 +560,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) { return dequantize_row_iq2_s_cuda; case 23: return dequantize_row_iq4_xs_cuda; + case 29: + return dequantize_row_iq1_m_cuda; default: return nullptr; } diff --git a/csrc/quantization/gguf/ggml-common.h b/csrc/quantization/gguf/ggml-common.h index d7989d84bf68e..fba94fd1d157b 100644 --- a/csrc/quantization/gguf/ggml-common.h +++ b/csrc/quantization/gguf/ggml-common.h @@ -149,14 +149,30 @@ typedef struct { uint8_t scales[IQ3S_N_SCALE]; } block_iq3_s; +// 1.5625 bpw #define QR1_S 8 #define QI1_S (QK_K / (4*QR1_S)) typedef struct { half d; - uint8_t qs[QK_K/8]; - uint8_t scales[QK_K/16]; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; } block_iq1_s; +// 1.75 bpw +#define QR1_M 8 +#define QI1_M (QK_K / (4*QR1_M)) +typedef struct { + uint8_t qs[QK_K/8]; // grid index, low 8 bits + uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8) + uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64) +} block_iq1_m; + +// Used by IQ1_M quants +typedef union { + half f16; + uint16_t u16; +} iq1m_scale_t; + #define QK4_NL 32 #define QR4_NL 2 #define 
QI4_NL (QK4_NL / (4*QR4_NL)) @@ -733,135 +749,265 @@ static const __device__ uint32_t iq3xs_grid[512] = { 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, }; -static const __device__ uint64_t iq1s_grid[512] = { - 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, - 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, - 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100, - 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00, - 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101, - 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100, - 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00, - 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff, - 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000, - 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000, - 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001, - 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff, - 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01, - 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001, - 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00, - 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001, - 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100, - 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000, - 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000, - 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000, - 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff, - 
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff, - 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01, - 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100, - 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff, - 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000, - 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101, - 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff, - 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff, - 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001, - 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01, - 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101, - 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100, - 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00, - 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001, - 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff, - 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000, - 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000, - 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100, - 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100, - 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01, - 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff, - 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101, - 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000, - 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff, - 0x00ff000000000000, 
0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000, - 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff, - 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00, - 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101, - 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000, - 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000, - 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000, - 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100, - 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000, - 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001, - 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff, - 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000, - 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000, - 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000, - 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000, - 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff, - 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000, - 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, - 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01, - 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100, - 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000, - 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00, - 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100, - 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000, - 0x000000010000ff00, 0x00000001000000ff, 
0x0000000100000000, 0x0000000100000001, - 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00, - 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff, - 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100, - 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff, - 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000, - 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff, - 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff, - 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00, - 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001, - 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001, - 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01, - 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000, - 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101, - 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00, - 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100, - 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101, - 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101, - 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000, - 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff, - 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff, - 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101, - 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, - 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101, - 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 
0x01ffff0000000001, - 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff, - 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff, - 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01, - 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff, - 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100, - 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001, - 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00, - 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff, - 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff, - 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000, - 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000, - 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101, - 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001, - 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000, - 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101, - 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000, - 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, - 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000, - 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100, - 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000, - 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000, - 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100, - 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff, - 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff, - 
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00, - 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101, - 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000, - 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00, - 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000, - 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff, - 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101, - 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff, - 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00, - 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff, +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f +static const __device__ uint64_t iq1s_grid_gpu[2048] = { + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 
0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 
0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 
0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 
0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 
0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 
0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 
0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 
0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 
0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 
0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 
0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 
0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 
0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, }; static const __device__ uint8_t ksigns_iq2xs[128] = { diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 966d9992b25fd..37e4de4e14dd3 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -166,6 +166,11 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream); break; + case 29: + mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(), + (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; } return Y; } diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh index ef2ea072392d2..b221ae7896138 100644 --- a/csrc/quantization/gguf/mmvq.cuh +++ b/csrc/quantization/gguf/mmvq.cuh @@ -157,6 +157,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * <<>>(vx, vy, dst, ncols, nrows); } +static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh index 78c749d3f3bc1..ff339753bcbb5 100644 --- a/csrc/quantization/gguf/vecdotq.cuh +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -1,5 +1,18 @@ // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh // and 
https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu +static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) { + const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment + + int x32 = x16[2*i32 + 0] << 0; + x32 |= x16[2*i32 + 1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) { + return ((const int *) x)[i32]; // assume at least 4 byte alignment +} + static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment int x32 = 0; @@ -1658,28 +1671,76 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 const block_iq1_s * bq1 = (const block_iq1_s *) vbq; - const int ib32 = iqs; - int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; - const uint8_t h1 = bq1->scales[2*ib32+0]; - const uint8_t h2 = bq1->scales[2*ib32+1]; - const int * q8 = (const int *)bq8_1[ib32].qs; - const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); - const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); - const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5))); - const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1))); - for (int j = 0; j < 2; ++j) { - sumi1 = __dp4a(q8[j+0], grid1[j], sumi1); - sumi2 = __dp4a(q8[j+2], grid2[j], sumi2); - sumi3 = __dp4a(q8[j+4], grid3[j], sumi3); - sumi4 = __dp4a(q8[j+6], grid4[j], sumi4); - } - const float d = __half2float(bq1->d) * __low2float(bq8_1[ib32].ds); - return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) + - sumi3 * (2*(h2 & 7) + 1) + sumi4 * 
(2*((h2 >> 4) & 7) + 1)); -#endif + const int qs_packed = get_int_b2(bq1->qs, iqs); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bq1->qh[iqs]; + + int sumi = 0; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + + const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1); + + sumi = __dp4a(grid0, u0, sumi); + sumi = __dp4a(grid1, u1, sumi); + } + + const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); + const float2 ds = __half22float2(bq8_1[iqs].ds); + return d1q * (ds.x*sumi + ds.y*delta); +} + +static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_iq1_m * bq1 = (const block_iq1_m *) vbq; + + const int qs_packed = get_int_b4(bq1->qs, iqs); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + int sumi[2] = {0}; + float sumf[2] = {0.0f}; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2)); + + const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + + const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1); + + sumi[l0/4] = __dp4a(grid0, u0, sumi[l0/4]); + sumi[l0/4] = __dp4a(grid1, u1, sumi[l0/4]); + + const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08); + int sumy = 0; + sumy = __dp4a(u0, 0x01010101, sumy); + sumy = __dp4a(u1, 0x01010101, sumy); + sumf[l0/4] += delta*sumy; + } + + const uint16_t * sc = (const uint16_t *) bq1->scales; + + iq1m_scale_t scale; + scale.u16 = 
(sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000); + const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds); + + const int tmp = sc[iqs/2] >> (6*(iqs%2)); + const int sc0 = 2*((tmp >> 0) & 0x07) + 1; + const int sc1 = 2*((tmp >> 3) & 0x07) + 1; + return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1); } static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values, diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu index c58216d8e00c5..de8d9ef2ee63e 100644 --- a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu @@ -267,3 +267,15 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, } #endif + +torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight, + c10::SymInt size_k, c10::SymInt size_n, + int64_t num_bits) { + int const pack_factor = 32 / num_bits; + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + return torch::empty_symint( + {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, + options); +} diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index c71b1bf573263..70d48de12ab05 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -342,3 +342,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, } #endif + +torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, + torch::Tensor& perm, c10::SymInt size_k, + c10::SymInt size_n, int64_t num_bits) { + int const pack_factor = 32 / num_bits; + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + return torch::empty_symint( + {size_k / marlin::tile_size, size_n * 
marlin::tile_size / pack_factor}, + options); +} diff --git a/csrc/custom/paged_attention/attention_ll4mi.cu b/csrc/rocm/attention.cu similarity index 94% rename from csrc/custom/paged_attention/attention_ll4mi.cu rename to csrc/rocm/attention.cu index b38ec30dfcdc1..eb7c278435ab9 100644 --- a/csrc/custom/paged_attention/attention_ll4mi.cu +++ b/csrc/rocm/attention.cu @@ -1,4 +1,19 @@ -// TODO: add license terms +/* + * Copyright (c) 2024, The vLLM team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + #include #include #include @@ -6,8 +21,8 @@ #include "cuda_compat.h" #include -#include "../../attention/dtype_fp8.cuh" -#include "../../quantization/fp8/amd/quant_utils.cuh" +#include "../attention/dtype_fp8.cuh" +#include "../quantization/fp8/amd/quant_utils.cuh" #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ defined(__gfx941__) || defined(__gfx942__)) @@ -51,8 +66,6 @@ using _B8x8 = uint2; ////// Non temporal load stores /////// - #if 1 - template __device__ __forceinline__ T load(T* addr) { return addr[0]; @@ -63,83 +76,6 @@ __device__ __forceinline__ void store(T value, T* addr) { addr[0] = value; } - #else - -template -__device__ __forceinline__ T load(const T* addr) { - return __builtin_nontemporal_load(addr); -} - -template <> -__device__ __forceinline__ float2 load(const float2* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast(&result); - return ret[0]; -} - -template <> -__device__ __forceinline__ float4 load(const float4* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result1 = __builtin_nontemporal_load(addr_alias); - auto result2 = __builtin_nontemporal_load(addr_alias + 1); - float4 ret{}; - auto ret_alias = reinterpret_cast(&result1); - ret.x = ret_alias->x; - ret.y = ret_alias->y; - ret_alias = reinterpret_cast(&result2); - ret.z = ret_alias->x; - ret.w = ret_alias->y; - return ret; -} - -template <> -__device__ __forceinline__ __half load(const __half* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast<__half*>(&result); - return ret[0]; -} - -template <> -__device__ __forceinline__ __half2 load(const __half2* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast<__half2*>(&result); - return ret[0]; -} - -template <> -__device__ __forceinline__ 
vllm::Half4_ load(const vllm::Half4_* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast(&result); - return ret[0]; -} - -template <> -__device__ __forceinline__ vllm::Half8_ load(const vllm::Half8_* addr) { - auto addr_alias{reinterpret_cast(addr)}; - auto result1 = __builtin_nontemporal_load(addr_alias); - auto result2 = __builtin_nontemporal_load(addr_alias + 1); - vllm::Half8_ ret{}; - auto ret_alias = reinterpret_cast(&result1); - ret.x = ret_alias->x; - ret.y = ret_alias->y; - ret_alias = reinterpret_cast(&result2); - ret.z = ret_alias->x; - ret.w = ret_alias->y; - return ret; -} - -//// Not using nontemporal stores for now -template -__device__ __forceinline__ void store(T value, T* addr) { - return __builtin_nontemporal_store(value, addr); -} - - #endif - template __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, const _B16x4& inpB, @@ -673,7 +609,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } else { // warp in context - // iterate across heads #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { @@ -1136,7 +1071,7 @@ void paged_attention_custom_launcher( break; \ } -void paged_attention_custom( +void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] diff --git a/csrc/custom/custom.cu b/csrc/rocm/custom.cu similarity index 100% rename from csrc/custom/custom.cu rename to csrc/rocm/custom.cu diff --git a/csrc/custom/custom_kernels.cu b/csrc/rocm/custom_kernels.cu similarity index 100% rename from csrc/custom/custom_kernels.cu rename to csrc/rocm/custom_kernels.cu diff --git a/csrc/custom/fused_kernels.cu b/csrc/rocm/fused_kernels.cu similarity index 100% rename from csrc/custom/fused_kernels.cu rename to csrc/rocm/fused_kernels.cu diff --git 
a/csrc/rocm/ops.h b/csrc/rocm/ops.h new file mode 100644 index 0000000000000..18c72f937f90a --- /dev/null +++ b/csrc/rocm/ops.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block); + +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block); + +void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount); + +void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, + double scale, torch::Tensor& block_tables, + torch::Tensor& context_lens, int64_t block_size, + int64_t max_context_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, + double v_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp new file mode 100644 index 0000000000000..2efa03e87e214 --- /dev/null +++ b/csrc/rocm/torch_bindings.cpp @@ -0,0 +1,46 @@ +#include "core/registration.h" +#include "rocm/ops.h" + +// Note on op signatures: +// The X_meta signatures are for the meta functions corresponding to op X. +// They must be kept in sync with the signature for X. Generally, only +// functions that return Tensors require a meta function. +// +// See the following links for detailed docs on op registration and function +// schemas. +// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { + // vLLM custom ops for rocm + rocm_ops.def( + "LLMM1(Tensor in_a, Tensor in_b, Tensor! 
out_c, int rows_per_block) -> " + "()"); + rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1); + rocm_ops.def( + "LLMM_Silu(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) " + "-> ()"); + rocm_ops.impl("LLMM_Silu", torch::kCUDA, &LLMM_Silu); + + // Custom attention op + // Compute the attention between an input query and the cached + // keys/values using PagedAttention. + rocm_ops.def( + "paged_attention(Tensor! out, Tensor exp_sums," + " Tensor max_logits, Tensor tmp_out," + " Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads," + " float scale, Tensor block_tables," + " Tensor context_lens, int block_size," + " int max_context_len," + " Tensor? alibi_slopes," + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); + rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); + rocm_ops.def( + "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," + " int CuCount) -> ()"); + rocm_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 21408e03fc340..51b03df5d5976 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -36,8 +36,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // PagedAttention V2. ops.def( "paged_attention_v2(" - " Tensor! out, Tensor exp_sums, Tensor max_logits," - " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor! out, Tensor! exp_sums, Tensor! max_logits," + " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? 
alibi_slopes," @@ -77,8 +77,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); // prepare_inputs advance_step - ops.def("advance_step", &advance_step); - ops.impl("advance_step", torch::kCUDA, &advance_step); + ops.def( + "advance_step_flashattn(int num_seqs, int num_queries, int block_size, " + "Tensor! input_tokens, Tensor sampled_token_ids, " + "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, " + "Tensor block_tables) -> ()"); + ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn); + + ops.def( + "advance_step_flashinfer(" + " int num_seqs, int num_queries, int block_size," + " Tensor! input_tokens, Tensor sampled_token_ids," + " Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping," + " Tensor block_tables, Tensor! paged_kv_indices," + " Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len," + " Tensor! block_table_bounds" + ") -> ()"); + ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. @@ -130,27 +145,56 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantization ops #ifndef USE_ROCM // Quantized GEMM for AQLM. - ops.def("aqlm_gemm", &aqlm_gemm); + ops.def( + "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, " + "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) " + "-> Tensor"); ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); // Decompression method for AQLM. - ops.def("aqlm_dequant", &aqlm_dequant); + ops.def( + "aqlm_dequant(Tensor codes, Tensor codebooks, " + "int[] codebook_partition_sizes) -> Tensor"); ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); // Quantized GEMM for AWQ. 
- ops.def("awq_gemm", &awq_gemm); + ops.def( + "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " + "Tensor _zeros, int split_k_iters) -> Tensor"); ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); // Dequantization for AWQ. - ops.def("awq_dequantize", &awq_dequantize); + ops.def( + "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, " + "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor"); ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); + // Note about marlin kernel 'workspace' arguments: + // Technically these should be mutable since they are modified by the kernel. + // But since they are set back to zero once the kernel is finished we can + // hand wave and say that they have no net effect. + // + // The reason to mark 'workspace' as immutable is so that they don't interfere + // with using ScalarType arguments in the ops. If they are marked as mutable, + // pytorch throws an assert in + // 'torch._higher_order_ops._register_effectful_op' that prevents these + // kernels from being torch.compile'd. + // See the following document for more info on custom types and ops that use + // custom types: + // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA + // Marlin (Dense) Optimized Quantized GEMM for GPTQ. - ops.def("marlin_gemm", &marlin_gemm); + ops.def( + "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " + "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor"); ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. 
- ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); + ops.def( + "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, " + "Tensor b_scales, Tensor workspace, " + "__torch__.torch.classes._core_C.ScalarType b_q_type, " + "int size_m, int size_n, int size_k) -> Tensor"); ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. @@ -169,35 +213,55 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); // gptq_marlin Optimized Quantized GEMM for GPTQ. - ops.def("gptq_marlin_gemm", &gptq_marlin_gemm); + ops.def( + "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " + "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, " + "__torch__.torch.classes._core_C.ScalarType b_q_type, " + "int size_m, int size_n, int size_k, bool is_k_full, " + "bool has_zp, bool use_fp32_reduce) -> Tensor"); ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); // gptq_marlin repack from GPTQ. - ops.def("gptq_marlin_repack", &gptq_marlin_repack); + ops.def( + "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " + "SymInt size_k, SymInt size_n, int num_bits) -> Tensor"); ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); + ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta); // awq_marlin repack from AWQ. - ops.def("awq_marlin_repack", &awq_marlin_repack); + ops.def( + "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " + "SymInt size_n, int num_bits) -> Tensor"); ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack); + ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta); // Dequantization for GGML. - ops.def("ggml_dequantize", &ggml_dequantize); + ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor"); ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize); // mmvq kernel for GGML. 
- ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8); + ops.def( + "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) " + "-> Tensor"); ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8); // mmq kernel for GGML. - ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8); + ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor"); ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. - ops.def("fp8_marlin_gemm", &fp8_marlin_gemm); + ops.def( + "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " + "Tensor! workspace, int num_bits, int size_m, int size_n, " + "int size_k) -> Tensor"); ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); // marlin_qqq_gemm for QQQ. - ops.def("marlin_qqq_gemm", &marlin_qqq_gemm); + ops.def( + "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " + "Tensor s_tok, Tensor s_ch, Tensor s_group, " + "Tensor! workspace, int size_m, int size_n, " + "int size_k) -> Tensor"); ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm); // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column @@ -219,16 +283,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Check if cutlass scaled_mm is supported for CUDA devices of the given // capability - ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); - ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA, - &cutlass_scaled_mm_supports_fp8); + ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"); + ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); + // Mamba selective scan kernel ops.def( "selective_scan_fwd(Tensor! u, Tensor! delta," "Tensor! A, Tensor! B, Tensor! C," "Tensor? D_, Tensor? z_, Tensor? delta_bias_," "bool delta_softplus," - "Tensor? index_, Tensor? x) -> Tensor[]"); + "Tensor? index_, Tensor(a! -> *)? 
x) -> Tensor(a)[]"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.def( @@ -250,7 +314,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #endif // Quantized GEMM for GPTQ. - ops.def("gptq_gemm", &gptq_gemm); + // Note: even though the C++ inferred schema is correct for this op, it seems + // to prevent the meta function registry. + ops.def( + "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, " + "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) " + "-> Tensor"); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); // Post processing for GPTQ. @@ -270,8 +339,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute dynamic-per-token FP8 quantized tensor and scaling factor. ops.def( - "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! " - "scale, Tensor? scale_ub) -> " + "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, " + "Tensor! scale, Tensor? scale_ub) -> " "()"); ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, &dynamic_per_token_scaled_fp8_quant); @@ -308,8 +377,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Copy the cache blocks from src to dst. cache_ops.def( - "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " - "block_mapping) -> ()"); + "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, " + "Tensor block_mapping) -> ()"); cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks); // Reshape the key and value tensors and cache them. @@ -334,8 +403,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Convert the key and value cache to fp8 data type. cache_ops.def( - "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str " - "kv_cache_dtype) -> ()"); + "convert_fp8(Tensor! 
dst_cache, Tensor src_cache, float scale, " + "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); } @@ -343,23 +412,27 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { // Cuda utils // Gets the specified device attribute. - cuda_utils.def("get_device_attribute", &get_device_attribute); - cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute); + cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int"); + cuda_utils.impl("get_device_attribute", &get_device_attribute); // Gets the maximum shared memory per block device attribute. - cuda_utils.def("get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute); + cuda_utils.def( + "get_max_shared_memory_per_block_device_attribute(int device_id) -> int"); cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", - torch::kCUDA, &get_max_shared_memory_per_block_device_attribute); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { // Custom all-reduce kernels - custom_ar.def("init_custom_ar", &init_custom_ar); + custom_ar.def( + "init_custom_ar(Tensor meta, Tensor rank_data, " + "str[] handles, int[] offsets, int rank, " + "bool full_nvlink) -> int"); custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); - custom_ar.def("should_custom_ar", &should_custom_ar); + custom_ar.def( + "should_custom_ar(Tensor inp, int max_size, int world_size, " + "bool full_nvlink) -> bool"); custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar); custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! 
out) -> ()"); @@ -371,21 +444,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg); custom_ar.def("dispose", &dispose); - custom_ar.impl("dispose", torch::kCPU, &dispose); - custom_ar.def("meta_size", &meta_size); - custom_ar.impl("meta_size", torch::kCPU, &meta_size); - custom_ar.def("register_buffer", &register_buffer); + custom_ar.def( + "register_buffer(int fa, Tensor t, str[] handles, " + "int[] offsets) -> ()"); custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer); custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); - custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU, - &get_graph_buffer_ipc_meta); - custom_ar.def("register_graph_buffers", &register_graph_buffers); - custom_ar.impl("register_graph_buffers", torch::kCPU, - &register_graph_buffers); #ifdef USE_ROCM custom_ar.def("allocate_meta_buffer", &allocate_meta_buffer); custom_ar.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer); diff --git a/docs/source/community/meetups.rst b/docs/source/community/meetups.rst index 3b01b109ebf2c..a3962e96e7913 100644 --- a/docs/source/community/meetups.rst +++ b/docs/source/community/meetups.rst @@ -5,6 +5,7 @@ vLLM Meetups We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- `The sixth vLLM meetup `__, with NVIDIA, September 9th 2024. `[Slides] `__ - `The fifth vLLM meetup `__, with AWS, July 24th 2024. `[Slides] `__ - `The fourth vLLM meetup `__, with Cloudflare and BentoML, June 11th 2024. `[Slides] `__ - `The third vLLM meetup `__, with Roblox, April 2nd 2024. 
`[Slides] `__ diff --git a/docs/source/conf.py b/docs/source/conf.py index b4f5b4ab9d569..8435129e752e1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -99,6 +99,7 @@ def setup(app): "aiohttp", "compressed_tensors", "cpuinfo", + "cv2", "torch", "transformers", "psutil", diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 7fc469e06844f..816e0a29ef28b 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -59,6 +59,20 @@ Build from source $ pip install wheel packaging ninja "setuptools>=49.4.0" numpy $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu +- Third, build and install oneDNN library from source: + +.. code-block:: console + + $ git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git + $ cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \ + -DONEDNN_BUILD_DOC=OFF \ + -DONEDNN_BUILD_EXAMPLES=OFF \ + -DONEDNN_BUILD_TESTS=OFF \ + -DONEDNN_BUILD_GRAPH=OFF \ + -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ + -DONEDNN_ENABLE_PRIMITIVE=MATMUL + $ cmake --build ./oneDNN/build --target install --config Release + - Finally, build and install vLLM CPU backend: .. code-block:: console diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index f0e54c29fcad7..50a761b49490c 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -26,6 +26,10 @@ You can install vLLM using pip: $ # Install vLLM with CUDA 12.1. $ pip install vllm +.. note:: + + Although we recommend using ``conda`` to create and manage Python environments, it is highly recommended to use ``pip`` to install vLLM. This is because ``pip`` can install ``torch`` with separate library packages like ``NCCL``, while ``conda`` installs ``torch`` with statically linked ``NCCL``. 
This can cause issues when vLLM tries to use ``NCCL``. See `this issue `_ for more details. + .. note:: As of now, vLLM's binaries are compiled with CUDA 12.1 and public PyTorch release versions by default. @@ -34,7 +38,7 @@ You can install vLLM using pip: .. code-block:: console $ # Install vLLM with CUDA 11.8. - $ export VLLM_VERSION=0.4.0 + $ export VLLM_VERSION=0.6.1.post1 $ export PYTHON_VERSION=310 $ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 @@ -48,7 +52,7 @@ You can install vLLM using pip: .. code-block:: console - $ export VLLM_VERSION=0.5.4 # vLLM's main branch version is currently set to latest released tag + $ export VLLM_VERSION=0.6.1.post1 # vLLM's main branch version is currently set to latest released tag $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl $ # You can also access a specific commit $ # export VLLM_COMMIT=... @@ -80,11 +84,11 @@ You can also build and install vLLM from source: .. tip:: - Building from source requires quite a lot compilation. If you are building from source for multiple times, it is beneficial to cache the compilation results. For example, you can install `ccache `_ via either `conda install ccache` or `apt install ccache` . As long as `which ccache` command can find the `ccache` binary, it will be used automatically by the build system. After the first build, the subsequent builds will be much faster. + Building from source requires quite a lot compilation. If you are building from source for multiple times, it is beneficial to cache the compilation results. For example, you can install `ccache `_ via either ``conda install ccache`` or ``apt install ccache`` . 
As long as ``which ccache`` command can find the ``ccache`` binary, it will be used automatically by the build system. After the first build, the subsequent builds will be much faster. .. tip:: To avoid your system being overloaded, you can limit the number of compilation jobs - to be run simultaneously, via the environment variable `MAX_JOBS`. For example: + to be run simultaneously, via the environment variable ``MAX_JOBS``. For example: .. code-block:: console @@ -99,7 +103,7 @@ You can also build and install vLLM from source: $ # Use `--ipc=host` to make sure the shared memory is large enough. $ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3 - If you don't want to use docker, it is recommended to have a full installation of CUDA Toolkit. You can download and install it from `the official website `_. After installation, set the environment variable `CUDA_HOME` to the installation path of CUDA Toolkit, and make sure that the `nvcc` compiler is in your `PATH`, e.g.: + If you don't want to use docker, it is recommended to have a full installation of CUDA Toolkit. You can download and install it from `the official website `_. After installation, set the environment variable ``CUDA_HOME`` to the installation path of CUDA Toolkit, and make sure that the ``nvcc`` compiler is in your ``PATH``, e.g.: .. code-block:: console diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 1bb3a448f2c92..3dcc242803752 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -107,6 +107,10 @@ Decoder-only Language Models - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. - + * - :code:`MiniCPM3ForCausalLM` + - MiniCPM3 + - :code:`openbmb/MiniCPM3-4B`, etc. + - * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. 
@@ -227,6 +231,11 @@ Multimodal Language Models - Image\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - + * - :code:`LlavaNextVideoForConditionalGeneration` + - LLaVA-NeXT-Video + - Video + - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. (see note) + - * - :code:`MiniCPMV` - MiniCPM-V - Image\ :sup:`+` @@ -242,11 +251,21 @@ Multimodal Language Models - Image\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - + * - :code:`PixtralForConditionalGeneration` + - Pixtral + - Image\ :sup:`+` + - :code:`mistralai/Pixtral-12B-2409` + - * - :code:`QWenLMHeadModel` - Qwen-VL - - Image\ :sup:`E` + - Image\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - + * - :code:`Qwen2VLForConditionalGeneration` + - Qwen2-VL (see note) + - Image\ :sup:`+` / Video\ :sup:`+` + - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. + - * - :code:`UltravoxModel` - Ultravox - Audio\ :sup:`E+` @@ -260,6 +279,14 @@ Multimodal Language Models For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 +.. note:: + For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. + This can be installed by running the following command: + + .. code-block:: bash + + pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830 + ---- If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. @@ -319,7 +346,7 @@ Note that, as an inference engine, vLLM does not introduce new models. 
Therefore We have the following levels of testing for models: -1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `test_models.py `_ and `test_big_models.py `_ for the models that have passed this test. +1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `models tests `_ for the models that have passed this test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. 3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests `_ and `examples `_ for the models that have passed this test. 4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. 
diff --git a/examples/fp8/quantizer/README.md b/examples/fp8/quantizer/README.md index 0b6944f688b49..d0895e97dc341 100644 --- a/examples/fp8/quantizer/README.md +++ b/examples/fp8/quantizer/README.md @@ -1,6 +1,6 @@ ### Quantizer Utilities -`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM: -`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py` +`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported +from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py) ### Prerequisite diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py new file mode 100644 index 0000000000000..c12ff7021cf51 --- /dev/null +++ b/examples/offline_inference_pixtral.py @@ -0,0 +1,165 @@ +# ruff: noqa +import argparse + +from vllm import LLM +from vllm.sampling_params import SamplingParams + +# This script is an offline demo for running Pixtral. +# +# If you want to run a server/client setup, please follow this code: +# +# - Server: +# +# ```bash +# vllm serve mistralai/Pixtral-12B-2409 --tokenizer-mode mistral --limit-mm-per-prompt 'image=4' --max-model-len 16384 +# ``` +# +# - Client: +# +# ```bash +# curl --location 'http://:8000/v1/chat/completions' \ +# --header 'Content-Type: application/json' \ +# --header 'Authorization: Bearer token' \ +# --data '{ +# "model": "mistralai/Pixtral-12B-2409", +# "messages": [ +# { +# "role": "user", +# "content": [ +# {"type" : "text", "text": "Describe this image in detail please."}, +# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}}, +# {"type" : "text", "text": "and this one as well. 
Answer in French."}, +# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}} +# ] +# } +# ] +# }' +# ``` +# +# Usage: +# python demo.py simple +# python demo.py advanced + + +def run_simple_demo(): + model_name = "mistralai/Pixtral-12B-2409" + sampling_params = SamplingParams(max_tokens=8192) + + # Lower max_num_seqs or max_model_len on low-VRAM GPUs. + llm = LLM(model=model_name, tokenizer_mode="mistral") + + prompt = "Describe this image in one sentence." + image_url = "https://picsum.photos/id/237/200/300" + + messages = [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": prompt + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + ], + }, + ] + outputs = llm.chat(messages, sampling_params=sampling_params) + + print(outputs[0].outputs[0].text) + + +def run_advanced_demo(): + model_name = "mistralai/Pixtral-12B-2409" + max_img_per_msg = 5 + max_tokens_per_img = 4096 + + sampling_params = SamplingParams(max_tokens=8192, temperature=0.7) + llm = LLM( + model=model_name, + tokenizer_mode="mistral", + limit_mm_per_prompt={"image": max_img_per_msg}, + max_model_len=max_img_per_msg * max_tokens_per_img, + ) + + prompt = "Describe the following image." 
+ + url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png" + url_2 = "https://picsum.photos/seed/picsum/200/300" + url_3 = "https://picsum.photos/id/32/512/512" + + messages = [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": prompt + }, + { + "type": "image_url", + "image_url": { + "url": url_1 + } + }, + { + "type": "image_url", + "image_url": { + "url": url_2 + } + }, + ], + }, + { + "role": "assistant", + "content": "The images show nature.", + }, + { + "role": "user", + "content": "More details please and answer only in French!.", + }, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": url_3 + } + }, + ], + }, + ] + + outputs = llm.chat(messages=messages, sampling_params=sampling_params) + print(outputs[0].outputs[0].text) + + +def main(): + parser = argparse.ArgumentParser( + description="Run a demo in simple or advanced mode.") + + parser.add_argument( + "mode", + choices=["simple", "advanced"], + help="Specify the demo mode: 'simple' or 'advanced'", + ) + + args = parser.parse_args() + + if args.mode == "simple": + print("Running simple demo...") + run_simple_demo() + elif args.mode == "advanced": + print("Running advanced demo...") + run_advanced_demo() + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index aa1580343aee7..464eaf334e3de 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -9,12 +9,9 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset from vllm.utils import FlexibleArgumentParser -# Input image and question -image = ImageAsset("cherry_blossom").pil_image.convert("RGB") -question = "What is the content of this image?" 
- # LLaVA-1.5 def run_llava(question): @@ -30,7 +27,16 @@ def run_llava(question): def run_llava_next(question): prompt = f"[INST] \n{question} [/INST]" - llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf") + llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192) + stop_token_ids = None + return llm, prompt, stop_token_ids + + +# LlaVA-NeXT-Video +# Currently only support for video input +def run_llava_next_video(question): + prompt = f"USER: