diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 0412c5f37952d..e29eb78a9f945 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -2,8 +2,11 @@ import sys import zipfile -# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB -VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250)) +# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 300 MiB +# Note that we have 400 MiB quota, please use it wisely. +# See https://github.com/pypi/support/issues/3792 . +# Please also sync the value with the one in Dockerfile. +VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 300)) def print_top_10_largest_files(zip_file): diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 708e548727cf5..679abf1814aa5 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -1,5 +1,6 @@ steps: - label: "Wait for container to be ready" + key: wait-for-container-image agents: queue: A100 plugins: @@ -10,12 +11,11 @@ steps: command: - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh - - wait - - label: "A100" # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: A100 + depends_on: wait-for-container-image plugins: - kubernetes: podSpec: @@ -49,6 +49,7 @@ steps: # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: H200 + depends_on: wait-for-container-image plugins: - docker#v5.12.0: image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT @@ -73,7 +74,7 @@ steps: # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: H100 - depends_on: block-h100 + depends_on: wait-for-container-image plugins: - docker#v5.12.0: image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT diff --git a/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh b/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh index 686f70dbece6c..69b6b146b3549 100644 --- a/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh +++ b/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh @@ -43,7 +43,7 @@ main() { - # The figures should be genereated by a separate process outside the CI/CD pipeline + # The figures should be generated by a separate process outside the CI/CD pipeline # # generate figures # python3 -m pip install tabulate pandas matplotlib diff --git a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh index 3f38cf5137535..32bd34c431c89 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh @@ -301,6 +301,104 @@ run_serving_tests() { kill_gpu_processes } +run_genai_perf_tests() { + # run genai-perf tests + + # $1: a json file specifying genai-perf test cases + local genai_perf_test_file + genai_perf_test_file=$1 + + # Iterate over genai-perf tests + jq -c '.[]' "$genai_perf_test_file" | while read -r params; do + # get the test name, and append the GPU type back to it. + test_name=$(echo "$params" | jq -r '.test_name') + + # if TEST_SELECTOR is set, only run the test cases that match the selector + if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then + echo "Skip test case $test_name." 
+ continue + fi + + # prepend the current serving engine to the test name + test_name=${CURRENT_LLM_SERVING_ENGINE}_${test_name} + + # get common parameters + common_params=$(echo "$params" | jq -r '.common_parameters') + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + dataset_name=$(echo "$common_params" | jq -r '.dataset_name') + dataset_path=$(echo "$common_params" | jq -r '.dataset_path') + port=$(echo "$common_params" | jq -r '.port') + num_prompts=$(echo "$common_params" | jq -r '.num_prompts') + reuse_server=$(echo "$common_params" | jq -r '.reuse_server') + + # get client and server arguments + server_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_server_parameters") + qps_list=$(echo "$params" | jq -r '.qps_list') + qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') + echo "Running over qps list $qps_list" + + # check if there is enough GPU to run the test + if [[ $gpu_count -lt $tp ]]; then + echo "Required num-shard $tp but only $gpu_count GPU found. Skip testcase $test_name." + continue + fi + + if [[ $reuse_server == "true" ]]; then + echo "Reuse previous server for test case $test_name" + else + kill_gpu_processes + bash "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh" \ + "$server_params" "$common_params" + fi + + if wait_for_server; then + echo "" + echo "$CURRENT_LLM_SERVING_ENGINE server is up and running." + else + echo "" + echo "$CURRENT_LLM_SERVING_ENGINE failed to start within the timeout period." + break + fi + + # iterate over different QPS + for qps in $qps_list; do + # remove the surrounding single quote from qps + if [[ "$qps" == *"inf"* ]]; then + echo "qps was $qps" + qps=$num_prompts + echo "now qps is $qps" + fi + + new_test_name=$test_name"_qps_"$qps + backend=$CURRENT_LLM_SERVING_ENGINE + + if [[ "$backend" == *"vllm"* ]]; then + backend="vllm" + fi + #TODO: add output dir. 
+ client_command="genai-perf profile \ + -m $model \ + --service-kind openai \ + --backend vllm \ + --endpoint-type chat \ + --streaming \ + --url localhost:$port \ + --request-rate $qps \ + --num-prompts $num_prompts \ + " + + echo "Client command: $client_command" + + eval "$client_command" + + #TODO: process/record outputs + done + done + + kill_gpu_processes + +} prepare_dataset() { @@ -328,12 +426,17 @@ main() { pip install -U transformers + pip install -r requirements-dev.txt + which genai-perf + # check storage df -h ensure_installed wget ensure_installed curl ensure_installed jq + # genai-perf dependency + ensure_installed libb64-0d prepare_dataset @@ -345,6 +448,10 @@ main() { # run the test run_serving_tests "$BENCHMARK_ROOT/tests/nightly-tests.json" + # run genai-perf tests + run_genai_perf_tests "$BENCHMARK_ROOT/tests/genai-perf-tests.json" + mv artifacts/ $RESULTS_FOLDER/ + # upload benchmark results to buildkite python3 -m pip install tabulate pandas python3 "$BENCHMARK_ROOT/scripts/summary-nightly-results.py" diff --git a/.buildkite/nightly-benchmarks/tests/genai-perf-tests.json b/.buildkite/nightly-benchmarks/tests/genai-perf-tests.json new file mode 100644 index 0000000000000..edbe9f2df0ce0 --- /dev/null +++ b/.buildkite/nightly-benchmarks/tests/genai-perf-tests.json @@ -0,0 +1,23 @@ +[ + { + "test_name": "llama8B_tp1_genai_perf", + "qps_list": [4,8,16,32], + "common_parameters": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "tp": 1, + "port": 8000, + "num_prompts": 500, + "reuse_server": false + }, + "vllm_server_parameters": { + "disable_log_stats": "", + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" + }, + "genai_perf_input_parameters": { + } + } +] \ No newline at end of file diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 51618a2955fb1..829414bf8a3ba 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -56,6 +56,11 @@ steps: env: DOCKER_BUILDKIT: "1" + - input: "Provide Release version here" + fields: + - text: "What is the release version?" + key: "release-version" + - block: "Build CPU release image" key: block-cpu-release-image-build depends_on: ~ @@ -66,7 +71,7 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ." - - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --progress plain -f Dockerfile.cpu ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: DOCKER_BUILDKIT: "1" diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 4f1729d46dae2..e19ace782feb5 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -9,36 +9,33 @@ CORE_RANGE=${CORE_RANGE:-48-95} NUMA_NODE=${NUMA_NODE:-1} # Try building the docker image -numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build -t cpu-test -f Dockerfile.cpu . 
-numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu . +numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build -t cpu-test-"$BUILDKITE_BUILD_NUMBER" -f Dockerfile.cpu . +numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 -f Dockerfile.cpu . # Setup cleanup -remove_docker_container() { docker rm -f cpu-test-"$NUMA_NODE" cpu-test-avx2-"$NUMA_NODE" || true; } +remove_docker_container() { set -e; docker rm -f cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" || true; } trap remove_docker_container EXIT remove_docker_container # Run the image, setting --shm-size=4g for tensor parallel. docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ - --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test + --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER" docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ - --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2-"$NUMA_NODE" cpu-test-avx2 + --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 function cpu_tests() { set -e export NUMA_NODE=$2 # offline inference - docker exec cpu-test-avx2-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c " set -e - python3 examples/offline_inference.py" + python3 examples/offline_inference/basic.py" # Run basic model test - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " set -e - pip install pytest pytest-asyncio \ - decord einops librosa peft Pillow sentence-transformers soundfile \ - transformers_stream_generator matplotlib datamodel_code_generator - pip install torchvision --index-url https://download.pytorch.org/whl/cpu + pip install -r vllm/requirements-test.txt pytest -v -s tests/models/decoder_only/language -m cpu_model pytest -v -s tests/models/embedding/language -m cpu_model pytest -v -s tests/models/encoder_decoder/language -m cpu_model @@ -46,26 +43,26 @@ function cpu_tests() { pytest -v -s tests/models/decoder_only/vision_language -m cpu_model" # Run compressed-tensor test - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " set -e pytest -s -v \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" # Run AWQ test - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " set -e pytest -s -v \ tests/quantization/test_ipex_quant.py" # Run chunked-prefill and prefix-cache test - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " set -e pytest -s 
-v -k cpu_model \ tests/basic_correctness/test_chunked_prefill.py" - # online inference - docker exec cpu-test-"$NUMA_NODE" bash -c " + # online serving + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " set -e export VLLM_CPU_KVCACHE_SPACE=10 export VLLM_CPU_OMP_THREADS_BIND=$1 @@ -78,8 +75,14 @@ function cpu_tests() { --num-prompts 20 \ --endpoint /v1/completions \ --tokenizer facebook/opt-125m" + + # Run multi-lora tests + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + set -e + pytest -s -v \ + tests/lora/test_qwen2vl.py" } -# All of CPU tests are expected to be finished less than 25 mins. +# All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh index 4fc6d089cc666..3e4e409466b8a 100644 --- a/.buildkite/run-gh200-test.sh +++ b/.buildkite/run-gh200-test.sh @@ -24,5 +24,5 @@ remove_docker_container # Run the image and test offline inference docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c ' - python3 examples/offline_inference.py + python3 examples/offline_inference/basic.py ' diff --git a/.buildkite/run-hpu-test.sh b/.buildkite/run-hpu-test.sh index fa4f74fca7a11..1edcb1d2669e9 100644 --- a/.buildkite/run-hpu-test.sh +++ b/.buildkite/run-hpu-test.sh @@ -8,9 +8,17 @@ set -ex docker build -t hpu-test-env -f Dockerfile.hpu . # Setup cleanup +# certain versions of HPU software stack have a bug that can +# override the exit code of the script, so we need to use +# separate remove_docker_container and remove_docker_container_and_exit +# functions, while other platforms only need one remove_docker_container +# function. +EXITCODE=1 remove_docker_container() { docker rm -f hpu-test || true; } -trap remove_docker_container EXIT +remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; } +trap remove_docker_container_and_exit EXIT remove_docker_container # Run the image and launch offline inference -docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py \ No newline at end of file +docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic.py +EXITCODE=$? diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh index 9259391aaed49..1ad77cf50f612 100644 --- a/.buildkite/run-neuron-test.sh +++ b/.buildkite/run-neuron-test.sh @@ -3,6 +3,18 @@ # This script build the Neuron docker image and run the API server inside the container. # It serves a sanity check for compilation and basic model usage. 
set -e +set -v + +image_name="neuron/vllm-ci" +container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" + +HF_CACHE="$(realpath ~)/huggingface" +mkdir -p "${HF_CACHE}" +HF_MOUNT="/root/.cache/huggingface" + +NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache" +mkdir -p "${NEURON_COMPILE_CACHE_URL}" +NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache" # Try building the docker image aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com @@ -13,41 +25,33 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then last_build=$(cat /tmp/neuron-docker-build-timestamp) current_time=$(date +%s) if [ $((current_time - last_build)) -gt 86400 ]; then - docker system prune -f + # Remove dangling images (those that are not tagged and not used by any container) + docker image prune -f + # Remove unused volumes / force the system prune for old images as well. + docker volume prune -f && docker system prune -f + # Remove huggingface model artifacts and compiler cache + rm -rf "${HF_MOUNT:?}/*" + rm -rf "${NEURON_COMPILE_CACHE_MOUNT:?}/*" echo "$current_time" > /tmp/neuron-docker-build-timestamp fi else date "+%s" > /tmp/neuron-docker-build-timestamp fi -docker build -t neuron -f Dockerfile.neuron . +docker build -t "${image_name}" -f Dockerfile.neuron . # Setup cleanup -remove_docker_container() { docker rm -f neuron || true; } +remove_docker_container() { + docker image rm -f "${image_name}" || true; +} trap remove_docker_container EXIT -remove_docker_container # Run the image -docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \ - --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 & - -# Wait for the server to start -wait_for_server_to_start() { - timeout=300 - counter=0 - - while [ "$(curl -s -o /dev/null -w '%{http_code}' localhost:8000/health)" != "200" ]; do - sleep 1 - counter=$((counter + 1)) - if [ $counter -ge $timeout ]; then - echo "Timeout after $timeout seconds" - break - fi - done -} -wait_for_server_to_start - -# Test a simple prompt -curl -X POST -H "Content-Type: application/json" \ - localhost:8000/generate \ - -d '{"prompt": "San Francisco is a"}' +docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \ + -v "${HF_CACHE}:${HF_MOUNT}" \ + -e "HF_HOME=${HF_MOUNT}" \ + -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \ + -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \ + --name "${container_name}" \ + ${image_name} \ + /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/ -v --capture=tee-sys" diff --git a/.buildkite/run-openvino-test.sh b/.buildkite/run-openvino-test.sh index 6b12f424fd828..6159b21ff8206 100755 --- a/.buildkite/run-openvino-test.sh +++ b/.buildkite/run-openvino-test.sh @@ -13,4 +13,4 @@ trap remove_docker_container EXIT remove_docker_container # Run the image and launch offline inference -docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference.py +docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference/basic.py diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh old mode 100644 new 
mode 100755 index 770dad6ffa3a1..650af0fac4c61 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -14,4 +14,13 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" +docker run --privileged --net host --shm-size=16G -it \ + -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \ + vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ + && python3 -m pip install pytest \ + && python3 -m pip install lm_eval[api]==0.4.4 \ + && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \ + && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ + && python3 /workspace/vllm/tests/tpu/test_compilation.py \ + && python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ + && python3 /workspace/vllm/examples/offline_inference/tpu.py" diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh index e0a12afbe7320..4d344e58db8ac 100644 --- a/.buildkite/run-xpu-test.sh +++ b/.buildkite/run-xpu-test.sh @@ -14,6 +14,6 @@ remove_docker_container # Run the image and test offline inference/tensor parallel docker run --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test sh -c ' - python3 examples/offline_inference.py - python3 examples/offline_inference_cli.py -tp 2 + python3 examples/offline_inference/basic.py + python3 examples/offline_inference/cli.py -tp 2 ' diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b563c96343f92..d5d02fdeb7f4b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -38,7 +38,7 @@ steps: - pip install -r requirements-docs.txt - SPHINXOPTS=\"-W\" make html # Check API reference (if it fails, you may have missing mock imports) - - grep \"sig sig-object py\" build/html/dev/sampling_params.html + - grep \"sig sig-object py\" build/html/api/inference_params.html - label: Async Engine, Inputs, Utils, Worker Test # 24min fast_check: true @@ -76,7 +76,9 @@ steps: - tests/basic_correctness/test_basic_correctness - tests/basic_correctness/test_cpu_offload - tests/basic_correctness/test_preemption + - tests/basic_correctness/test_cumem.py commands: + - pytest -v -s basic_correctness/test_cumem.py - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py @@ -106,14 +108,12 @@ steps: source_file_dependencies: - vllm/ commands: - - pip install -e ./plugins/vllm_add_dummy_model - - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py 
--ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py - - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -127,11 +127,15 @@ steps: - tests/distributed - tests/spec_decode/e2e/test_integration_dist_tp4 - tests/compile + - examples/offline_inference/rlhf.py commands: - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py + # TODO: create a dedicated test section for multi-GPU example tests + # when we have multiple distributed example tests + - python3 ../examples/offline_inference/rlhf.py - label: Metrics, Tracing Test # 10min num_gpus: 2 @@ -179,7 +183,16 @@ steps: - vllm/ - tests/v1 commands: - - VLLM_USE_V1=1 pytest -v -s v1 + # split the test to avoid interference + - VLLM_USE_V1=1 pytest -v -s v1/core + - VLLM_USE_V1=1 pytest -v -s v1/engine + - VLLM_USE_V1=1 pytest -v -s v1/sample + - VLLM_USE_V1=1 pytest -v -s v1/worker + - VLLM_USE_V1=1 pytest -v -s v1/test_stats.py + - VLLM_USE_V1=1 pytest -v -s v1/test_utils.py + # TODO: accuracy does not match, whether setting + # VLLM_USE_FLASHINFER_SAMPLER or not on H100. 
+ - VLLM_USE_V1=1 pytest -v -s v1/e2e - label: Examples Test # 25min working_dir: "/vllm-workspace/examples" @@ -189,19 +202,19 @@ steps: - examples/ commands: - pip install tensorizer # for tensorizer test - - python3 offline_inference.py - - python3 cpu_offload.py - - python3 offline_inference_chat.py - - python3 offline_inference_with_prefix.py - - python3 llm_engine_example.py - - python3 offline_inference_vision_language.py - - python3 offline_inference_vision_language_multi_image.py - - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - - python3 offline_inference_encoder_decoder.py - - python3 offline_inference_classification.py - - python3 offline_inference_embedding.py - - python3 offline_inference_scoring.py - - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2 + - python3 offline_inference/basic.py + - python3 offline_inference/cpu_offload.py + - python3 offline_inference/chat.py + - python3 offline_inference/prefix_caching.py + - python3 offline_inference/llm_engine_example.py + - python3 offline_inference/vision_language.py + - python3 offline_inference/vision_language_multi_image.py + - python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - python3 offline_inference/encoder_decoder.py + - python3 offline_inference/classification.py + - python3 offline_inference/embedding.py + - python3 offline_inference/scoring.py + - python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min mirror_hardwares: [amd] @@ -216,6 +229,7 @@ steps: - vllm/model_executor/layers - vllm/sampling_metadata.py - tests/samplers + - tests/conftest.py commands: - pytest -v -s samplers - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers @@ -231,20 +245,22 @@ steps: - pytest -v -s test_logits_processor.py - pytest -v -s model_executor/test_guided_processors.py -- label: Speculative decoding tests # 30min +- label: Speculative decoding tests # 40min source_file_dependencies: - vllm/spec_decode - tests/spec_decode + - vllm/model_executor/models/eagle.py commands: - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py + - pytest -v -s spec_decode/e2e/test_eagle_correctness.py - label: LoRA Test %N # 15min each mirror_hardwares: [amd] source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py parallelism: 4 - label: "PyTorch Fullgraph Smoke Test" # 9min @@ -333,8 +349,6 @@ steps: - vllm/ - tests/models commands: - - pip install -e ./plugins/vllm_add_dummy_model - - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s 
models/test_registry.py - pytest -v -s models/test_initialization.py @@ -360,23 +374,26 @@ steps: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' -- label: Multi-Modal Models Test (Standard) # 28min +- label: Multi-Modal Models Test (Standard) # 40min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/models/decoder_only/audio_language - tests/models/decoder_only/vision_language - tests/models/embedding/vision_language + - tests/models/encoder_decoder/audio_language - tests/models/encoder_decoder/vision_language commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s models/embedding/vision_language -m core_model + - pytest -v -s models/encoder_decoder/audio_language -m core_model - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model -- label: Multi-Modal Models Test (Extended) 1 # 1h16m +- label: Multi-Modal Models Test (Extended) 1 # 48m optional: true source_file_dependencies: - vllm/ @@ -459,7 +476,10 @@ steps: - vllm/worker/worker_base.py - vllm/worker/worker.py - vllm/worker/model_runner.py + - entrypoints/llm/test_collective_rpc.py commands: + - pytest -v -s entrypoints/llm/test_collective_rpc.py + - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' @@ -468,12 +488,31 @@ steps: - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' - - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - - pip install -e ./plugins/vllm_add_dummy_model - - pytest -v -s distributed/test_distributed_oot.py + # this test fails consistently. 
+ # TODO: investigate and fix + # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py +- label: Plugin Tests (2 GPUs) # 40min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + fast_check: true + source_file_dependencies: + - vllm/plugins/ + - tests/plugins/ + commands: + # begin platform plugin tests, all the code in-between runs on dummy platform + - pip install -e ./plugins/vllm_add_dummy_platform + - pytest -v -s plugins_tests/test_platform_plugins.py + - pip uninstall vllm_add_dummy_platform -y + # end platform plugin tests + # other tests continue here: + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s distributed/test_distributed_oot.py + - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s models/test_oot_registration.py # it needs a clean process + - label: Multi-step Tests (4 GPUs) # 36min working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -489,7 +528,9 @@ steps: - vllm/engine - tests/multi_step commands: - - pytest -v -s multi_step/test_correctness_async_llm.py + # this test is quite flaky + # TODO: investigate and fix. + # - pytest -v -s multi_step/test_correctness_async_llm.py - pytest -v -s multi_step/test_correctness_llm.py - label: Pipeline Parallelism Test # 45min @@ -520,6 +561,7 @@ steps: # requires multi-GPU testing for validation. - pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py + - pytest -v -s -x lora/test_minicpmv_tp.py - label: Weight Loading Multiple GPU Test # 33min diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 3cb91fc0f8232..bc324d8b988b1 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,32 +2,35 @@ # for more info about CODEOWNERS file # This lists cover the "core" components of vLLM that require careful review -/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth +/vllm/model_executor/guided_decoding @mgoin +/vllm/multimodal @DarkLight1337 @ywang96 CMakeLists.txt @tlrmchlsmth # vLLM V1 -/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic +/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill 
@ywang96 @comaniac @alexm-redhat # Test ownership -/tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo +/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo /tests/test_inputs.py @DarkLight1337 @ywang96 -/tests/entrypoints @DarkLight1337 @robertgshaw2-neuralmagic @simon-mo +/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo /tests/models @DarkLight1337 @ywang96 /tests/multimodal @DarkLight1337 @ywang96 /tests/prefix_caching @comaniac @KuntaiDu /tests/spec_decode @njhill @LiuXiaoxuanPKU /tests/kernels @tlrmchlsmth @WoosukKwon -/tests/quantization @mgoin @robertgshaw2-neuralmagic +/tests/quantization @mgoin @robertgshaw2-redhat /.buildkite/lm-eval-harness @mgoin @simon-mo /tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_same_node.py @youkaichao -/tests/multi_step @alexm-neuralmagic @comaniac +/tests/multi_step @alexm-redhat @comaniac /tests/weight_loading @mgoin @youkaichao /tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac diff --git a/.github/mergify.yml b/.github/mergify.yml index ca4bd7ee2b87f..43bc5ce623d3c 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -35,6 +35,43 @@ pull_request_rules: add: - frontend +- name: label-structured-output + description: Automatically apply structured-output label + conditions: + - or: + - files~=^vllm/model_executor/guided_decoding/ + - files=tests/model_executor/test_guided_processors.py + - files=tests/entrypoints/llm/test_guided_generate.py + - files=benchmarks/benchmark_serving_guided.py + - files=benchmarks/benchmark_guided.py + actions: + label: + add: + - structured-output + +- name: label-speculative-decoding + description: Automatically apply speculative-decoding label + conditions: + - or: + - files~=^vllm/spec_decode/ + - files=vllm/model_executor/layers/spec_decode_base_sampler.py + - files~=^tests/spec_decode/ + actions: + label: + add: + - speculative-decoding + +- name: label-v1 + description: Automatically apply v1 label + conditions: + - or: + - files~=^vllm/v1/ + - files~=^tests/v1/ + actions: + label: + add: + - v1 + - name: ping author on conflicts and add 'needs-rebase' label conditions: - conflict diff --git a/.github/workflows/actionlint.yml b/.github/workflows/actionlint.yml deleted file mode 100644 index 0226cf0ca00e9..0000000000000 --- a/.github/workflows/actionlint.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Lint GitHub Actions workflows -on: - push: - branches: - - "main" - paths: - - '.github/workflows/*.ya?ml' - - '.github/workflows/actionlint.*' - - '.github/workflows/matchers/actionlint.json' - pull_request: - branches: - - "main" - paths: - - '.github/workflows/*.ya?ml' - - '.github/workflows/actionlint.*' - - '.github/workflows/matchers/actionlint.json' - -env: - LC_ALL: en_US.UTF-8 - -defaults: - run: - shell: bash - -permissions: - contents: read - -jobs: - actionlint: - runs-on: ubuntu-latest - steps: - - name: "Checkout" - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - - name: "Run actionlint" - run: | - echo "::add-matcher::.github/workflows/matchers/actionlint.json" - tools/actionlint.sh -color diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml deleted file mode 100644 index 68149d2dc019f..0000000000000 --- a/.github/workflows/clang-format.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: clang-format - -on: - # Trigger the workflow on push or pull request, - # but only for the main 
branch - push: - branches: - - main - paths: - - '**/*.h' - - '**/*.cpp' - - '**/*.cu' - - '**/*.cuh' - - '.github/workflows/clang-format.yml' - pull_request: - branches: - - main - paths: - - '**/*.h' - - '**/*.cpp' - - '**/*.cu' - - '**/*.cuh' - - '.github/workflows/clang-format.yml' - -jobs: - clang-format: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.11"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install clang-format==18.1.5 - - name: Running clang-format - run: | - EXCLUDES=( - 'csrc/moe/topk_softmax_kernels.cu' - 'csrc/quantization/gguf/ggml-common.h' - 'csrc/quantization/gguf/dequantize.cuh' - 'csrc/quantization/gguf/vecdotq.cuh' - 'csrc/quantization/gguf/mmq.cuh' - 'csrc/quantization/gguf/mmvq.cuh' - ) - find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ - | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ - | xargs clang-format --dry-run --Werror diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml deleted file mode 100644 index 68887adaae54b..0000000000000 --- a/.github/workflows/codespell.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: codespell - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - paths: - - "**/*.py" - - "**/*.md" - - "**/*.rst" - - pyproject.toml - - requirements-lint.txt - - .github/workflows/codespell.yml - pull_request: - branches: - - main - paths: - - "**/*.py" - - "**/*.md" - - "**/*.rst" - - pyproject.toml - - requirements-lint.txt - - .github/workflows/codespell.yml - -jobs: - codespell: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.12"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements-lint.txt - - name: Spelling check with codespell - run: | - codespell --toml pyproject.toml diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml index ab6f6e5d2060d..556b60d2fca12 100644 --- a/.github/workflows/lint-and-deploy.yaml +++ b/.github/workflows/lint-and-deploy.yaml @@ -27,7 +27,7 @@ jobs: version: v3.10.1 - name: Run chart-testing (lint) - run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/chart-helm --charts examples/chart-helm + run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm - name: Setup minio run: | @@ -64,7 +64,8 @@ jobs: run: | export AWS_ACCESS_KEY_ID=minioadmin export AWS_SECRET_ACCESS_KEY=minioadmin - helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/chart-helm -f examples/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set 
resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" + sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & + helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" - name: curl test run: | diff --git a/.github/workflows/matchers/ruff.json b/.github/workflows/matchers/ruff.json deleted file mode 100644 index f6d4479ee1996..0000000000000 --- a/.github/workflows/matchers/ruff.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "problemMatcher": [ - { - "owner": "ruff", - "pattern": [ - { - "regexp": "^(.+?):(\\d+):(\\d+): (\\w+): (.+)$", - "file": 1, - "line": 2, - "column": 3, - "code": 4, - "message": 5 - } - ] - } - ] - } diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml deleted file mode 100644 index 73eeacf1fa562..0000000000000 --- a/.github/workflows/mypy.yaml +++ /dev/null @@ -1,51 +0,0 @@ -name: mypy - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - paths: - - '**/*.py' - - '.github/workflows/mypy.yaml' - - 'tools/mypy.sh' - - 'pyproject.toml' - pull_request: - branches: - - main - # This workflow is only relevant when one of the following files changes. - # However, we have github configured to expect and require this workflow - # to run and pass before github with auto-merge a pull request. Until github - # allows more flexible auto-merge policy, we can just run this on every PR. - # It doesn't take that long to run, anyway. 
- #paths: - # - '**/*.py' - # - '.github/workflows/mypy.yaml' - # - 'tools/mypy.sh' - # - 'pyproject.toml' - -jobs: - mypy: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install mypy==1.11.1 - pip install types-setuptools - pip install types-PyYAML - pip install types-requests - pip install types-setuptools - - name: Mypy - run: | - echo "::add-matcher::.github/workflows/matchers/mypy.json" - tools/mypy.sh 1 ${{ matrix.python-version }} diff --git a/.github/workflows/png-lint.yml b/.github/workflows/png-lint.yml deleted file mode 100644 index 4932af943a07b..0000000000000 --- a/.github/workflows/png-lint.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: Lint PNG exports from excalidraw -on: - push: - branches: - - "main" - paths: - - '*.excalidraw.png' - - '.github/workflows/png-lint.yml' - pull_request: - branches: - - "main" - paths: - - '*.excalidraw.png' - - '.github/workflows/png-lint.yml' - -env: - LC_ALL: en_US.UTF-8 - -defaults: - run: - shell: bash - -permissions: - contents: read - -jobs: - actionlint: - runs-on: ubuntu-latest - steps: - - name: "Checkout" - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - - name: "Run png-lint.sh to check excalidraw exported images" - run: | - tools/png-lint.sh diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000000000..06564969dc778 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,19 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: "3.12" + - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 + with: + extra_args: --all-files --hook-stage manual diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml deleted file mode 100644 index 7266cc378cfb0..0000000000000 --- a/.github/workflows/ruff.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: ruff - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - paths: - - "**/*.py" - - pyproject.toml - - requirements-lint.txt - - .github/workflows/matchers/ruff.json - - .github/workflows/ruff.yml - pull_request: - branches: - - main - # This workflow is only relevant when one of the following files changes. - # However, we have github configured to expect and require this workflow - # to run and pass before github with auto-merge a pull request. Until github - # allows more flexible auto-merge policy, we can just run this on every PR. - # It doesn't take that long to run, anyway. 
- #paths: - # - "**/*.py" - # - pyproject.toml - # - requirements-lint.txt - # - .github/workflows/matchers/ruff.json - # - .github/workflows/ruff.yml - -jobs: - ruff: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.12"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements-lint.txt - - name: Analysing the code with ruff - run: | - echo "::add-matcher::.github/workflows/matchers/ruff.json" - ruff check --output-format github . - - name: Run isort - run: | - isort . --check-only diff --git a/.github/workflows/shellcheck.yml b/.github/workflows/shellcheck.yml deleted file mode 100644 index 4b1587e373e17..0000000000000 --- a/.github/workflows/shellcheck.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: Lint shell scripts -on: - push: - branches: - - "main" - paths: - - '**/*.sh' - - '.github/workflows/shellcheck.yml' - pull_request: - branches: - - "main" - paths: - - '**/*.sh' - - '.github/workflows/shellcheck.yml' - -env: - LC_ALL: en_US.UTF-8 - -defaults: - run: - shell: bash - -permissions: - contents: read - -jobs: - shellcheck: - runs-on: ubuntu-latest - steps: - - name: "Checkout" - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - - name: "Check shell scripts" - run: | - tools/shellcheck.sh diff --git a/.github/workflows/sphinx-lint.yml b/.github/workflows/sphinx-lint.yml deleted file mode 100644 index e0bb24276a653..0000000000000 --- a/.github/workflows/sphinx-lint.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: Lint documentation - -on: - push: - branches: - - main - paths: - - "docs/**" - pull_request: - branches: - - main - paths: - - "docs/**" - -jobs: - sphinx-lint: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.12"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements-lint.txt - - name: Linting docs - run: tools/sphinx-lint.sh diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml deleted file mode 100644 index ff441f94435ad..0000000000000 --- a/.github/workflows/yapf.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: yapf - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - paths: - - "**/*.py" - - .github/workflows/yapf.yml - pull_request: - branches: - - main - paths: - - "**/*.py" - - .github/workflows/yapf.yml - -jobs: - yapf: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.12"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install yapf==0.32.0 - pip install toml==0.10.2 - - name: Running yapf - run: | - yapf --diff --recursive . 
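[Editor's note] The per-tool lint workflows deleted above (actionlint, clang-format, codespell, mypy, png-lint, ruff, shellcheck, sphinx-lint, yapf) are consolidated into the pre-commit workflow added earlier in this diff plus the .pre-commit-config.yaml added below. As a minimal local sketch of the equivalent developer workflow — assuming pre-commit is installed from PyPI; the authoritative instructions live in the vLLM contributing docs, not in this diff:

    # one-time setup: install the hook runner and register the hooks
    # defined in .pre-commit-config.yaml
    pip install pre-commit
    pre-commit install

    # run every default-stage hook against the whole tree; CI additionally
    # passes --hook-stage manual (see pre-commit.yml above) so the
    # per-Python-version mypy hooks also run there
    pre-commit run --all-files

    # as the "suggestion" hook in the config below notes, hooks can be
    # bypassed for a single commit with:
    git commit --no-verify
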
diff --git a/.gitignore b/.gitignore index bb7e4d5b244a8..89dab8f13bab1 100644 --- a/.gitignore +++ b/.gitignore @@ -79,10 +79,7 @@ instance/ # Sphinx documentation docs/_build/ -docs/source/getting_started/examples/*.rst -!**/*.template.rst -docs/source/getting_started/examples/*.md -!**/*.template.md +docs/source/getting_started/examples/ # PyBuilder .pybuilder/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000..ae518e1902f53 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,106 @@ +default_stages: + - pre-commit # Run locally + - manual # Run in CI +repos: +- repo: https://github.com/google/yapf + rev: v0.43.0 + hooks: + - id: yapf + args: [--in-place, --verbose] + additional_dependencies: [toml] # TODO: Remove when yapf is upgraded +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.3 + hooks: + - id: ruff + args: [--output-format, github] +- repo: https://github.com/codespell-project/codespell + rev: v2.4.0 + hooks: + - id: codespell + exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*' +- repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v19.1.7 + hooks: + - id: clang-format + exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))' + types_or: [c++, cuda] + args: [--style=file, --verbose] +- repo: https://github.com/jackdewinter/pymarkdown + rev: v0.9.27 + hooks: + - id: pymarkdown + files: docs/.* +- repo: https://github.com/rhysd/actionlint + rev: v1.7.7 + hooks: + - id: actionlint +- repo: local + hooks: + - id: mypy-local + name: Run mypy for local Python installation + entry: tools/mypy.sh 0 "local" + language: python + types: [python] + additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests] + stages: [pre-commit] # Don't run in CI + - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.9 + entry: tools/mypy.sh 1 "3.9" + language: python + types: [python] + additional_dependencies: *mypy_deps + stages: [manual] # Only run in CI + - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.10 + entry: tools/mypy.sh 1 "3.10" + language: python + types: [python] + additional_dependencies: *mypy_deps + stages: [manual] # Only run in CI + - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.11 + entry: tools/mypy.sh 1 "3.11" + language: python + types: [python] + additional_dependencies: *mypy_deps + stages: [manual] # Only run in CI + - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.12 + entry: tools/mypy.sh 1 "3.12" + language: python + types: [python] + additional_dependencies: *mypy_deps + stages: [manual] # Only run in CI + - id: shellcheck + name: Lint shell scripts + entry: tools/shellcheck.sh + language: script + types: [shell] + - id: png-lint + name: Lint PNG exports from excalidraw + entry: tools/png-lint.sh + language: script + types: [png] + - id: signoff-commit + name: Sign-off Commit + entry: bash + args: + - -c + - | + if ! 
grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" .git/COMMIT_EDITMSG; then + printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> .git/COMMIT_EDITMSG + fi + language: system + verbose: true + stages: [commit-msg] + - id: suggestion + name: Suggestion + entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."' + language: system + verbose: true + pass_filenames: false + diff --git a/CMakeLists.txt b/CMakeLists.txt old mode 100644 new mode 100755 index 83c8033434f3b..c823c9ff895c3 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) # Suppress potential warnings about unused manually-specified variables set(ignoreMe "${VLLM_PYTHON_PATH}") -# Prevent installation of dependencies (cutlass) by default. -install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) - # # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. @@ -181,6 +178,31 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") # Define other extension targets # +# +# cumem_allocator extension +# + +set(VLLM_CUMEM_EXT_SRC + "csrc/cumem_allocator.cpp") + +set_gencode_flags_for_srcs( + SRCS "${VLLM_CUMEM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + message(STATUS "Enabling cumem allocator extension.") + # link against cuda driver library + list(APPEND CUMEM_LIBS cuda) + define_gpu_extension_target( + cumem_allocator + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_CUMEM_EXT_SRC} + LIBRARIES ${CUMEM_LIBS} + USE_SABI 3.8 + WITH_SOABI) +endif() + # # _C extension # @@ -223,13 +245,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227 + GIT_TAG v3.7.0 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW FALSE + GIT_SHALLOW TRUE ) endif() FetchContent_MakeAvailable(cutlass) @@ -253,7 +275,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Only build Marlin kernels if we are building for at least some compatible archs. # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS}) + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") if (MARLIN_ARCHS) set(MARLIN_SRCS "csrc/quantization/fp8/fp8_marlin.cu" @@ -274,10 +296,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require - # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). - cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") + # CUDA 12.0 or later (and only work on Hopper, 9.0a for now). 
+ cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + set(SRCS + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") @@ -329,7 +356,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # 2:4 Sparse Kernels # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor - # require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now). + # require CUDA 12.2 or later (and only work on Hopper, 9.0a for now). if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu" "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") @@ -510,7 +537,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") endif() # vllm-flash-attn currently only supported on CUDA -if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda") +if (NOT VLLM_GPU_LANG STREQUAL "CUDA") return() endif () @@ -533,7 +560,7 @@ endif() # They should be identical but if they aren't, this is a massive footgun. # # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. -# To only install vllm-flash-attn, use --component vllm_flash_attn_c. +# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). # If no component is specified, vllm-flash-attn is still installed. # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. @@ -545,43 +572,41 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) endif() if(VLLM_FLASH_ATTN_SRC_DIR) - FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR}) + FetchContent_Declare( + vllm-flash-attn SOURCE_DIR + ${VLLM_FLASH_ATTN_SRC_DIR} + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb + GIT_TAG d4e09037abf588af1ec47d0e966b237ee376876c GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn ) endif() -# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization. 
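To make the install-component comment above concrete, here is a hedged sketch of installing only the vllm-flash-attn files from an already-configured build tree. The component names `_vllm_fa2_C` and `_vllm_fa3_C` are the ones defined in this diff; the build directory and install prefix are placeholders, and the configure step is omitted:

```bash
# Install only the FA2 flavor of vllm-flash-attn...
cmake --install build --component _vllm_fa2_C --prefix /tmp/vllm-install
# ...or only the FA3 flavor.
cmake --install build --component _vllm_fa3_C --prefix /tmp/vllm-install

# With no --component, everything is installed, vllm-flash-attn included.
cmake --install build --prefix /tmp/vllm-install
```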
-set(VLLM_PARENT_BUILD ON) - -# Ensure the vllm/vllm_flash_attn directory exists before installation -install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c) - -# Make sure vllm-flash-attn install rules are nested under vllm/ -install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c) -install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) -install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c) # Fetch the vllm-flash-attn library FetchContent_MakeAvailable(vllm-flash-attn) message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") -# Restore the install prefix -install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) -install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c) +# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in +# case only one is built, in the case both are built redundant work is done) +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa2_C + FILES_MATCHING PATTERN "*.py" +) -# Copy over the vllm-flash-attn python files install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm/vllm_flash_attn - COMPONENT vllm_flash_attn_c - FILES_MATCHING PATTERN "*.py" + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa3_C + FILES_MATCHING PATTERN "*.py" ) # Nothing after vllm-flash-attn, see comment about macros above diff --git a/Dockerfile b/Dockerfile index 153bff9cf565f..0b9f74e08dc68 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,8 +2,8 @@ # to run the OpenAI compatible server. 
# Please update any changes made here to -# docs/source/dev/dockerfile/dockerfile.md and -# docs/source/assets/dev/dockerfile-stages-dependency.png +# docs/source/contributing/dockerfile/dockerfile.md and +# docs/source/assets/contributing/dockerfile-stages-dependency.png ARG CUDA_VERSION=12.4.1 #################### BASE BUILD IMAGE #################### @@ -52,7 +52,7 @@ WORKDIR /workspace # after this step RUN --mount=type=cache,target=/root/.cache/pip \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \ + python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \ fi COPY requirements-common.txt requirements-common.txt @@ -126,8 +126,8 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ # Check the size of the wheel if RUN_WHEEL_CHECK is true COPY .buildkite/check-wheel-size.py check-wheel-size.py -# Default max size of the wheel is 250MB -ARG VLLM_MAX_SIZE_MB=250 +# sync the default value with .buildkite/check-wheel-size.py +ARG VLLM_MAX_SIZE_MB=300 ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB ARG RUN_WHEEL_CHECK=true RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ @@ -149,7 +149,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \ #################### vLLM installation IMAGE #################### # image with vLLM installed -FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base +# TODO: Restore to base image after FlashInfer AOT wheel fixed +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base ARG CUDA_VERSION=12.4.1 ARG PYTHON_VERSION=3.12 WORKDIR /vllm-workspace @@ -194,12 +195,30 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose +# How to build this FlashInfer wheel: +# $ export FLASHINFER_ENABLE_AOT=1 +# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+ +# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' +# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive +# $ cd flashinfer +# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4 +# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose + RUN --mount=type=cache,target=/root/.cache/pip \ . /etc/environment && \ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \ + python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \ fi COPY examples examples + +# Although we build Flashinfer with AOT mode, there's still +# some issues w.r.t. JIT compilation. Therefore we need to +# install build dependencies for JIT compilation. 
+# TODO: Remove this once FlashInfer AOT wheel is fixed +COPY requirements-build.txt requirements-build.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install -r requirements-build.txt + #################### vLLM installation IMAGE #################### #################### TEST IMAGE #################### @@ -234,8 +253,8 @@ RUN mv vllm test_docs/ #################### TEST IMAGE #################### #################### OPENAI API SERVER #################### -# openai api server alternative -FROM vllm-base AS vllm-openai +# base openai image with additional requirements, for any subsequent openai-style images +FROM vllm-base AS vllm-openai-base # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ @@ -247,5 +266,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \ ENV VLLM_USAGE_SOURCE production-docker-image +# define sagemaker first, so it is not default from `docker build` +FROM vllm-openai-base AS vllm-sagemaker + +COPY examples/online_serving/sagemaker-entrypoint.sh . +RUN chmod +x sagemaker-entrypoint.sh +ENTRYPOINT ["./sagemaker-entrypoint.sh"] + +FROM vllm-openai-base AS vllm-openai + ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] #################### OPENAI API SERVER #################### diff --git a/Dockerfile.cpu b/Dockerfile.cpu index f163edc27cba8..ebe226cf6d148 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -26,10 +26,10 @@ RUN pip install intel_extension_for_pytorch==2.5.0 WORKDIR /workspace -COPY requirements-build.txt requirements-build.txt ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \ pip install --upgrade pip && \ pip install -r requirements-build.txt @@ -37,9 +37,9 @@ FROM cpu-test-1 AS build WORKDIR /workspace/vllm -COPY requirements-common.txt requirements-common.txt -COPY requirements-cpu.txt requirements-cpu.txt RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \ + --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \ pip install -v -r requirements-cpu.txt COPY . . 
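For reference, the restructured Dockerfile above can be driven as follows; the stage names (`vllm-openai`, `vllm-sagemaker`) and build args (`VLLM_MAX_SIZE_MB`, `RUN_WHEEL_CHECK`) come from the diff, while the image tags are placeholders:

```bash
# Default OpenAI-compatible server image (vllm-openai remains the last stage,
# so it is also what a plain `docker build .` produces).
DOCKER_BUILDKIT=1 docker build --target vllm-openai -t vllm/vllm-openai:dev .

# SageMaker variant, which differs only in its entrypoint script.
DOCKER_BUILDKIT=1 docker build --target vllm-sagemaker -t vllm/vllm-sagemaker:dev .

# Relax or skip the wheel-size gate for a local, oversized experiment.
DOCKER_BUILDKIT=1 docker build --build-arg VLLM_MAX_SIZE_MB=400 -t vllm/vllm-openai:local .
DOCKER_BUILDKIT=1 docker build --build-arg RUN_WHEEL_CHECK=false -t vllm/vllm-openai:nocheck .
```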
diff --git a/Dockerfile.hpu b/Dockerfile.hpu index 87e0c1a6a934e..66cf68c32f2ca 100644 --- a/Dockerfile.hpu +++ b/Dockerfile.hpu @@ -1,4 +1,4 @@ -FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest +FROM vault.habana.ai/gaudi-docker/1.19.1/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest COPY ./ /workspace/vllm diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 77162bc82de62..e9cb82889decd 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -1,6 +1,6 @@ # default base image # https://gallery.ecr.aws/neuron/pytorch-inference-neuronx -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04" +ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.21.0-ubuntu22.04" FROM $BASE_IMAGE @@ -15,16 +15,17 @@ RUN apt-get update && \ ffmpeg libsm6 libxext6 libgl1 ### Mount Point ### -# When launching the container, mount the code directory to /app -ARG APP_MOUNT=/app +# When launching the container, mount the code directory to /workspace +ARG APP_MOUNT=/workspace VOLUME [ ${APP_MOUNT} ] WORKDIR ${APP_MOUNT}/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas -RUN python3 -m pip install sentencepiece transformers==4.36.2 -U +RUN python3 -m pip install sentencepiece transformers==4.45.2 -U RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install pytest COPY . . 
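The Gaudi and Neuron Dockerfiles above mainly bump their pinned base images and dependencies, so building them is unchanged; a usage sketch with placeholder tags (runtime device flags for real Gaudi/Neuron hardware are omitted):

```bash
# Gaudi (HPU) image on the 1.19.1 / PyTorch 2.5.1 Habana base.
docker build -f Dockerfile.hpu -t vllm-hpu:dev .

# Neuron image on the SDK 2.21.0 / PyTorch 2.5.1 base; per the Dockerfile's own
# comment, the code directory is now mounted under /workspace rather than /app.
docker build -f Dockerfile.neuron -t vllm-neuron:dev .
docker run --rm -it -v "$PWD":/workspace/vllm vllm-neuron:dev
```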
ARG GIT_REPO_CHECK=0 @@ -42,4 +43,7 @@ RUN --mount=type=bind,source=.git,target=.git \ # install development dependencies (for testing) RUN python3 -m pip install -e tests/vllm_test_utils +# overwrite entrypoint to run bash script +RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py + CMD ["/bin/bash"] diff --git a/Dockerfile.openvino b/Dockerfile.openvino index 8bd188ffde408..32bcbfa9cc168 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -14,6 +14,7 @@ ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi +RUN python3 -m pip install -U pip # install build requirements RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/requirements-build.txt # build vLLM with OpenVINO backend diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index 971248577983f..c4c1f3e357972 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -4,12 +4,12 @@ USER root ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/" -RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 +RUN apt-get update -y && apt-get install -y git wget kmod curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev # Some packages in requirements-cpu are installed here # IBM provides optimized packages for ppc64le processors in the open-ce project for mamba # Currently these may not be available for venv or pip directly -RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 torchvision-cpu=0.16.2 rust && micromamba clean --all --yes +RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 rust && micromamba clean --all --yes COPY ./ /workspace/vllm @@ -18,11 +18,9 @@ ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi -# These packages will be in rocketce eventually RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ + RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ - torch==2.3.1 \ -r requirements-cpu.txt \ xformers uvloop==0.20.0 diff --git a/Dockerfile.rocm b/Dockerfile.rocm index e733994f8c33e..14c522afd7f9e 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,174 +1,119 @@ -# Default ROCm 6.2 base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0" +# default base image +ARG REMOTE_VLLM="0" +ARG USE_CYTHON="0" +ARG BUILD_RPD="1" +ARG COMMON_WORKDIR=/app +ARG BASE_IMAGE=rocm/vllm-dev:base -# Default ROCm ARCHes to build vLLM for. 
-ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" +FROM ${BASE_IMAGE} AS base -# Whether to install CK-based flash-attention -# If 0, will not install flash-attention -ARG BUILD_FA="1" -ARG FA_GFX_ARCHS="gfx90a;gfx942" -ARG FA_BRANCH="3cea2fb" - -# Whether to build triton on rocm -ARG BUILD_TRITON="1" -ARG TRITON_BRANCH="e192dba" - -### Base image build stage -FROM $BASE_IMAGE AS base - -# Import arg(s) defined before this build stage -ARG PYTORCH_ROCM_ARCH +ARG ARG_PYTORCH_ROCM_ARCH +ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}} # Install some basic utilities -RUN apt-get update && apt-get install python3 python3-pip -y -RUN apt-get update && apt-get install -y \ - curl \ - ca-certificates \ - sudo \ - git \ - bzip2 \ - libx11-6 \ - build-essential \ - wget \ - unzip \ - tmux \ - ccache \ - && rm -rf /var/lib/apt/lists/* - -# When launching the container, mount the code directory to /vllm-workspace -ARG APP_MOUNT=/vllm-workspace -WORKDIR ${APP_MOUNT} - -RUN python3 -m pip install --upgrade pip -# Remove sccache so it doesn't interfere with ccache -# TODO: implement sccache support across components +RUN apt-get update -q -y && apt-get install -q -y \ + sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev +# Remove sccache +RUN python3 -m pip install --upgrade pip && pip install setuptools_scm RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" +ARG COMMON_WORKDIR +WORKDIR ${COMMON_WORKDIR} + + +# ----------------------- +# vLLM fetch stages +FROM base AS fetch_vllm_0 +ONBUILD COPY ./ vllm/ +FROM base AS fetch_vllm_1 +ARG VLLM_REPO="https://github.com/vllm-project/vllm.git" +ARG VLLM_BRANCH="main" +ONBUILD RUN git clone ${VLLM_REPO} \ + && cd vllm \ + && git checkout ${VLLM_BRANCH} +FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm + +# ----------------------- +# vLLM build stages +FROM fetch_vllm AS build_vllm +ARG USE_CYTHON +# Build vLLM +RUN cd vllm \ + && python3 -m pip install -r requirements-rocm.txt \ + && python3 setup.py clean --all \ + && if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \ + && python3 setup.py bdist_wheel --dist-dir=dist +FROM scratch AS export_vllm +ARG COMMON_WORKDIR +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl / +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt / +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite + +# ----------------------- +# Test vLLM image +FROM base AS test + +RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/* + +# Install vLLM +RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ + cd /install \ + && pip install -U -r requirements-rocm.txt \ + && pip uninstall -y vllm \ + && pip install *.whl + +WORKDIR /vllm-workspace +ARG COMMON_WORKDIR +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace + +# install development dependencies (for testing) +RUN cd /vllm-workspace \ + && rm -rf vllm \ + && python3 -m pip install -e tests/vllm_test_utils \ + && python3 -m pip install lm-eval[api]==0.4.4 \ + && python3 -m pip install pytest-shard + +# ----------------------- +# Final vLLM image +FROM base AS final -# Install torch == 2.6.0 on ROCm -RUN --mount=type=cache,target=/root/.cache/pip \ - case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.2"*) 
\ - python3 -m pip uninstall -y torch torchvision \ - && python3 -m pip install --pre \ - torch==2.6.0.dev20241113+rocm6.2 \ - 'setuptools-scm>=8' \ - torchvision==0.20.0.dev20241113+rocm6.2 \ - --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \ +RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/* +# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. +# Manually remove it so that later steps of numpy upgrade can continue +RUN case "$(which python3)" in \ + *"/opt/conda/envs/py_3.9"*) \ + rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \ *) ;; esac -ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer -ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: -ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: -ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: - -ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} -ENV CCACHE_DIR=/root/.cache/ccache - - -### AMD-SMI build stage -FROM base AS build_amdsmi -# Build amdsmi wheel always -RUN cd /opt/rocm/share/amd_smi \ - && python3 -m pip wheel . --wheel-dir=/install - - -### Flash-Attention wheel build stage -FROM base AS build_fa -ARG BUILD_FA -ARG FA_GFX_ARCHS -ARG FA_BRANCH -# Build ROCm flash-attention wheel if `BUILD_FA = 1` -RUN --mount=type=cache,target=${CCACHE_DIR} \ - if [ "$BUILD_FA" = "1" ]; then \ - mkdir -p libs \ - && cd libs \ - && git clone https://github.com/ROCm/flash-attention.git \ - && cd flash-attention \ - && git checkout "${FA_BRANCH}" \ - && git submodule update --init \ - && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ - # Create an empty directory otherwise as later build stages expect one - else mkdir -p /install; \ - fi - - -### Triton wheel build stage -FROM base AS build_triton -ARG BUILD_TRITON -ARG TRITON_BRANCH -# Build triton wheel if `BUILD_TRITON = 1` -RUN --mount=type=cache,target=${CCACHE_DIR} \ - if [ "$BUILD_TRITON" = "1" ]; then \ - mkdir -p libs \ - && cd libs \ - && python3 -m pip install ninja cmake wheel pybind11 \ - && git clone https://github.com/OpenAI/triton.git \ - && cd triton \ - && git checkout "${TRITON_BRANCH}" \ - && cd python \ - && python3 setup.py bdist_wheel --dist-dir=/install; \ - # Create an empty directory otherwise as later build stages expect one - else mkdir -p /install; \ - fi - - -### Final vLLM build stage -FROM base AS final -# Import the vLLM development directory from the build context -COPY . . 
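The rewritten Dockerfile.rocm is now a thin multi-stage file on top of the prebuilt `rocm/vllm-dev:base` image, steered entirely by build args. A hedged usage sketch; the arg and stage names (`REMOTE_VLLM`, `VLLM_BRANCH`, `USE_CYTHON`, `test`, `export_vllm`) come from the diff, while the tags and output directory are placeholders:

```bash
# Build from the local checkout (REMOTE_VLLM defaults to 0).
docker build -f Dockerfile.rocm -t vllm-rocm:dev .

# Build from a remote branch instead of the local tree, with Cython enabled.
docker build -f Dockerfile.rocm \
  --build-arg REMOTE_VLLM=1 \
  --build-arg VLLM_BRANCH=main \
  --build-arg USE_CYTHON=1 \
  -t vllm-rocm:main .

# Build the test image, or export just the wheel plus benchmarks/tests/examples.
docker build -f Dockerfile.rocm --target test -t vllm-rocm:test .
docker build -f Dockerfile.rocm --target export_vllm --output type=local,dest=./rocm-dist .
```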
-ARG GIT_REPO_CHECK=0 -RUN --mount=type=bind,source=.git,target=.git \ - if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi +RUN python3 -m pip install --upgrade huggingface-hub[cli] +ARG BUILD_RPD +RUN if [ ${BUILD_RPD} -eq "1" ]; then \ + git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git \ + && cd rocmProfileData/rpd_tracer \ + && pip install -r requirements.txt && cd ../ \ + && make && make install \ + && cd hipMarker && python3 setup.py install ; fi -RUN python3 -m pip install --upgrade pip +# Install vLLM +RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ + cd /install \ + && pip install -U -r requirements-rocm.txt \ + && pip uninstall -y vllm \ + && pip install *.whl -# Package upgrades for useful functionality or to avoid dependency issues -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard +ARG COMMON_WORKDIR +# Copy over the benchmark scripts as well +COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks +COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples -# Workaround for ray >= 2.10.0 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 -# Silences the HF Tokenizers warning ENV TOKENIZERS_PARALLELISM=false -RUN --mount=type=cache,target=${CCACHE_DIR} \ - --mount=type=bind,source=.git,target=.git \ - --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install -Ur requirements-rocm.txt \ - && python3 setup.py clean --all \ - && python3 setup.py develop - -# Copy amdsmi wheel into final image -RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \ - mkdir -p libs \ - && cp /install/*.whl libs \ - # Preemptively uninstall to avoid same-version no-installs - && python3 -m pip uninstall -y amdsmi; - -# Copy triton wheel(s) into final image if they were built -RUN --mount=type=bind,from=build_triton,src=/install,target=/install \ - mkdir -p libs \ - && if ls /install/*.whl; then \ - cp /install/*.whl libs \ - # Preemptively uninstall to avoid same-version no-installs - && python3 -m pip uninstall -y triton; fi - -# Copy flash-attn wheel(s) into final image if they were built -RUN --mount=type=bind,from=build_fa,src=/install,target=/install \ - mkdir -p libs \ - && if ls /install/*.whl; then \ - cp /install/*.whl libs \ - # Preemptively uninstall to avoid same-version no-installs - && python3 -m pip uninstall -y flash-attn; fi - -# Install wheels that were built to the final image -RUN --mount=type=cache,target=/root/.cache/pip \ - if ls libs/*.whl; then \ - python3 -m pip install libs/*.whl; fi - -# install development dependencies (for testing) -RUN python3 -m pip install -e tests/vllm_test_utils +# Performance environment variable. 
+ENV HIP_FORCE_DEV_KERNARG=1 CMD ["/bin/bash"] + diff --git a/Dockerfile.rocm_base b/Dockerfile.rocm_base new file mode 100644 index 0000000000000..5bbe98b0c2204 --- /dev/null +++ b/Dockerfile.rocm_base @@ -0,0 +1,158 @@ +ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete +ARG HIPBLASLT_BRANCH="4d40e36" +ARG HIPBLAS_COMMON_BRANCH="7c1566b" +ARG LEGACY_HIPBLASLT_OPTION= +ARG RCCL_BRANCH="648a58d" +ARG RCCL_REPO="https://github.com/ROCm/rccl" +ARG TRITON_BRANCH="e5be006" +ARG TRITON_REPO="https://github.com/triton-lang/triton.git" +ARG PYTORCH_BRANCH="8d4926e" +ARG PYTORCH_VISION_BRANCH="v0.19.1" +ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" +ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" +ARG FA_BRANCH="b7d29fb" +ARG FA_REPO="https://github.com/ROCm/flash-attention.git" + +FROM ${BASE_IMAGE} AS base + +ENV PATH=/opt/rocm/llvm/bin:$PATH +ENV ROCM_PATH=/opt/rocm +ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib: +ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942 +ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} + +ARG PYTHON_VERSION=3.12 + +RUN mkdir -p /app +WORKDIR /app +ENV DEBIAN_FRONTEND=noninteractive + +# Install Python and other dependencies +RUN apt-get update -y \ + && apt-get install -y software-properties-common git curl sudo vim less \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update -y \ + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ + python${PYTHON_VERSION}-lib2to3 python-is-python3 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version && python3 -m pip --version + +RUN pip install -U packaging cmake ninja wheel setuptools pybind11 Cython + +FROM base AS build_hipblaslt +ARG HIPBLASLT_BRANCH +ARG HIPBLAS_COMMON_BRANCH +# Set to "--legacy_hipblas_direct" for ROCm<=6.2 +ARG LEGACY_HIPBLASLT_OPTION +RUN git clone https://github.com/ROCm/hipBLAS-common.git +RUN cd hipBLAS-common \ + && git checkout ${HIPBLAS_COMMON_BRANCH} \ + && mkdir build \ + && cd build \ + && cmake .. \ + && make package \ + && dpkg -i ./*.deb +RUN git clone https://github.com/ROCm/hipBLASLt +RUN cd hipBLASLt \ + && git checkout ${HIPBLASLT_BRANCH} \ + && ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \ + && cd build/release \ + && make package +RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install + +FROM base AS build_rccl +ARG RCCL_BRANCH +ARG RCCL_REPO +RUN git clone ${RCCL_REPO} +RUN cd rccl \ + && git checkout ${RCCL_BRANCH} \ + && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH} +RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install + +FROM base AS build_triton +ARG TRITON_BRANCH +ARG TRITON_REPO +RUN git clone ${TRITON_REPO} +RUN cd triton \ + && git checkout ${TRITON_BRANCH} \ + && cd python \ + && python3 setup.py bdist_wheel --dist-dir=dist +RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install + +FROM base AS build_amdsmi +RUN cd /opt/rocm/share/amd_smi \ + && pip wheel . 
--wheel-dir=dist +RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install + +FROM base AS build_pytorch +ARG PYTORCH_BRANCH +ARG PYTORCH_VISION_BRANCH +ARG PYTORCH_REPO +ARG PYTORCH_VISION_REPO +ARG FA_BRANCH +ARG FA_REPO +RUN git clone ${PYTORCH_REPO} pytorch +RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \ + pip install -r requirements.txt && git submodule update --init --recursive \ + && python3 tools/amd_build/build_amd.py \ + && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl +RUN git clone ${PYTORCH_VISION_REPO} vision +RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \ + && python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl +RUN git clone ${FA_REPO} +RUN cd flash-attention \ + && git checkout ${FA_BRANCH} \ + && git submodule update --init \ + && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist +RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ + && cp /app/vision/dist/*.whl /app/install \ + && cp /app/flash-attention/dist/*.whl /app/install + +FROM base AS final +RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \ + dpkg -i /install/*deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status +RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \ + dpkg -i /install/*deb \ + && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ + && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status +RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + pip install /install/*.whl + +ARG BASE_IMAGE +ARG HIPBLASLT_BRANCH +ARG LEGACY_HIPBLASLT_OPTION +ARG RCCL_BRANCH +ARG RCCL_REPO +ARG TRITON_BRANCH +ARG TRITON_REPO +ARG PYTORCH_BRANCH +ARG PYTORCH_VISION_BRANCH +ARG PYTORCH_REPO +ARG PYTORCH_VISION_REPO +ARG FA_BRANCH +ARG FA_REPO +RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ + && echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \ + && echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \ + && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \ + && echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \ + && echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \ + && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \ + && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \ + && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \ + && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \ + && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ + && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ + && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ + && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt diff --git a/Dockerfile.tpu b/Dockerfile.tpu index b617932a85b47..e268b39476665 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -1,4 +1,4 @@ -ARG NIGHTLY_DATE="20241017" +ARG NIGHTLY_DATE="20250124" ARG 
BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE diff --git a/Dockerfile.ubi b/Dockerfile.ubi index 2e5ad2967c534..54ba9ceb8a78b 100644 --- a/Dockerfile.ubi +++ b/Dockerfile.ubi @@ -42,13 +42,15 @@ FROM python-install as cuda-base RUN curl -Lo /etc/yum.repos.d/cuda-rhel9.repo \ https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo +ENV CUDA_HOME="/usr/local/cuda" \ + PATH="${CUDA_HOME}/bin:${PATH}" +ENV LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${CUDA_HOME}/lib64/stubs/:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH}" RUN microdnf install -y \ cuda-nvcc-12-4 cuda-nvtx-12-4 cuda-libraries-devel-12-4 && \ - microdnf clean all + microdnf clean all && \ + ln -s ${CUDA_HOME}/lib64/stubs/libcuda.so /usr/lib64/ + -ENV CUDA_HOME="/usr/local/cuda" \ - PATH="${CUDA_HOME}/bin:${PATH}" \ - LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH}" ## Python cuda base ################################################################# FROM cuda-base AS python-cuda-base diff --git a/README.md b/README.md index f83c9d759b359..5fd30f2b1b9d7 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ Easy, fast, and cheap LLM serving for everyone --- *Latest News* 🔥 +- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). +- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing). - [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! - [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! @@ -34,10 +36,12 @@ Easy, fast, and cheap LLM serving for everyone ## About vLLM is a fast and easy-to-use library for LLM inference and serving. +Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evloved into a community-driven project with contributions from both academia and industry. + vLLM is fast with: - State-of-the-art serving throughput -- Efficient management of attention key and value memory with **PagedAttention** +- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html) - Continuous batching of incoming requests - Fast model execution with CUDA/HIP graph - Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8. 
@@ -68,16 +72,16 @@ Find the full list of supported models [here](https://docs.vllm.ai/en/latest/mod ## Getting Started -Install vLLM with `pip` or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): +Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source): ```bash pip install vllm ``` -Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. -- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) -- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) -- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) +Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more. +- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation/index.html) +- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html) +- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html) ## Contributing @@ -90,28 +94,33 @@ vLLM is a community project. Our compute resources for development and testing a - +Cash Donations: - a16z +- Dropbox +- Sequoia Capital +- Skywork AI +- ZhenFund + +Compute Resources: - AMD - Anyscale - AWS - Crusoe Cloud - Databricks - DeepInfra -- Dropbox - Google Cloud - Lambda Lab - Nebius +- Novita AI - NVIDIA - Replicate - Roblox - RunPod -- Sequoia Capital -- Skywork AI - Trainy - UC Berkeley - UC San Diego -- ZhenFund + +Slack Sponsor: Anyscale We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. diff --git a/SECURITY.md b/SECURITY.md index ad3f1f16ab560..47196a1f1221e 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -4,7 +4,7 @@ If you believe you have found a security vulnerability in vLLM, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. -Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). +Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). Reports will then be triaged by the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). 
--- diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index b67849038cf0d..0612e8778aca5 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -22,6 +22,7 @@ class RequestFuncInput: prompt_len: int output_len: int model: str + model_name: Optional[str] = None best_of: int = 1 logprobs: Optional[int] = None extra_body: Optional[dict] = None @@ -34,6 +35,7 @@ class RequestFuncOutput: generated_text: str = "" success: bool = False latency: float = 0.0 + output_tokens: int = 0 ttft: float = 0.0 # Time to first token itl: List[float] = field( default_factory=list) # List of inter-token latencies @@ -49,7 +51,8 @@ async def async_request_tgi( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: params = { "best_of": request_func_input.best_of, "max_new_tokens": request_func_input.output_len, @@ -78,7 +81,7 @@ async def async_request_tgi( continue chunk_bytes = chunk_bytes.decode("utf-8") - #NOTE: Sometimes TGI returns a ping response without + # NOTE: Sometimes TGI returns a ping response without # any data, we should skip it. if chunk_bytes.startswith(":"): continue @@ -121,7 +124,8 @@ async def async_request_trt_llm( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: assert request_func_input.best_of == 1 payload = { "accumulate_tokens": True, @@ -155,7 +159,7 @@ async def async_request_trt_llm( timestamp = time.perf_counter() # First token if ttft == 0.0: - ttft = time.perf_counter() - st + ttft = timestamp - st output.ttft = ttft # Decoding phase @@ -185,7 +189,8 @@ async def async_request_deepspeed_mii( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: assert request_func_input.best_of == 1 payload = { @@ -233,17 +238,23 @@ async def async_request_openai_completions( ("completions", "profile") ), "OpenAI Completions API URL must end with 'completions' or 'profile'." 
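One behavioral change buried in the client code above is worth spelling out: every `aiohttp.ClientSession` is now created with `trust_env=True`, so the benchmark honors the standard proxy (and netrc) environment variables instead of ignoring them. A hedged sketch; the proxy host is a placeholder and the flags are the serving benchmark's existing CLI:

```bash
# Route benchmark traffic through a proxy, but keep localhost direct.
export HTTPS_PROXY=http://proxy.example.internal:3128
export NO_PROXY=localhost,127.0.0.1

python benchmarks/benchmark_serving.py \
  --backend vllm \
  --model meta-llama/Llama-2-7b-chat-hf \
  --dataset-name sharegpt \
  --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-prompts 200
```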
- async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: payload = { - "model": request_func_input.model, + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "best_of": request_func_input.best_of, "max_tokens": request_func_input.output_len, "logprobs": request_func_input.logprobs, "stream": True, - "ignore_eos": request_func_input.ignore_eos, + "stream_options": { + "include_usage": True, + }, } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos if request_func_input.extra_body: payload.update(request_func_input.extra_body) headers = { @@ -254,7 +265,6 @@ async def async_request_openai_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: @@ -269,15 +279,16 @@ async def async_request_openai_completions( chunk = chunk_bytes.decode("utf-8").removeprefix( "data: ") - if chunk == "[DONE]": - latency = time.perf_counter() - st - else: + if chunk != "[DONE]": data = json.loads(chunk) # NOTE: Some completion API might have a last # usage summary response without a token so we # want to check a token was generated - if data["choices"][0]["text"]: + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") timestamp = time.perf_counter() # First token if not first_chunk_received: @@ -291,7 +302,10 @@ async def async_request_openai_completions( most_recent_timestamp) most_recent_timestamp = timestamp - generated_text += data["choices"][0]["text"] + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") if first_chunk_received: output.success = True else: @@ -300,7 +314,7 @@ async def async_request_openai_completions( "Never received a valid chunk to calculate TTFT." "This response will be marked as failed!") output.generated_text = generated_text - output.latency = latency + output.latency = most_recent_timestamp - st else: output.error = response.reason or "" output.success = False @@ -323,12 +337,14 @@ async def async_request_openai_chat_completions( "chat/completions" ), "OpenAI Chat Completions API URL must end with 'chat/completions'." 
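The completions client above now asks the server for a trailing usage chunk (`stream_options.include_usage`) and prefers the server-reported `completion_tokens` over re-tokenizing the generated text. A minimal way to see that final chunk by hand, assuming a vLLM OpenAI-compatible server is already listening on localhost:8000 and using a placeholder model name:

```bash
curl -sN http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "meta-llama/Llama-2-7b-chat-hf",
        "prompt": "San Francisco is a",
        "max_tokens": 16,
        "temperature": 0.0,
        "stream": true,
        "stream_options": {"include_usage": true}
      }'
```

With `include_usage` set, the last data event before `[DONE]` is expected to carry an empty `choices` list and a `usage` object; its `completion_tokens` field is what populates `output_tokens` in the code above.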
- async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: content = [{"type": "text", "text": request_func_input.prompt}] if request_func_input.multi_modal_content: content.append(request_func_input.multi_modal_content) payload = { - "model": request_func_input.model, + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, "messages": [ { "role": "user", @@ -338,8 +354,12 @@ async def async_request_openai_chat_completions( "temperature": 0.0, "max_completion_tokens": request_func_input.output_len, "stream": True, - "ignore_eos": request_func_input.ignore_eos, + "stream_options": { + "include_usage": True, + }, } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos if request_func_input.extra_body: payload.update(request_func_input.extra_body) headers = { @@ -365,17 +385,15 @@ async def async_request_openai_chat_completions( chunk = chunk_bytes.decode("utf-8").removeprefix( "data: ") - if chunk == "[DONE]": - latency = time.perf_counter() - st - else: + if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) - delta = data["choices"][0]["delta"] - if delta.get("content", None): + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") # First token if ttft == 0.0: - ttft = time.perf_counter() - st + ttft = timestamp - st output.ttft = ttft # Decoding phase @@ -383,13 +401,16 @@ async def async_request_openai_chat_completions( output.itl.append(timestamp - most_recent_timestamp) - generated_text += delta["content"] + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") most_recent_timestamp = timestamp output.generated_text = generated_text output.success = True - output.latency = latency + output.latency = most_recent_timestamp - st else: output.error = response.reason or "" output.success = False @@ -417,14 +438,35 @@ def get_model(pretrained_model_name_or_path: str) -> str: def get_tokenizer( - pretrained_model_name_or_path: str, trust_remote_code: bool + pretrained_model_name_or_path: str, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: if pretrained_model_name_or_path is not None and not os.path.exists( pretrained_model_name_or_path): pretrained_model_name_or_path = get_model( pretrained_model_name_or_path) - return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, - trust_remote_code=trust_remote_code) + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError( + "Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + if tokenizer_mode == "mistral": + try: + from vllm.transformers_utils.tokenizer import MistralTokenizer + except ImportError as e: + raise ImportError("MistralTokenizer requires vllm package.\n" + "Please install it with `pip install vllm` " + "to use mistral tokenizer mode.") from e + return MistralTokenizer.from_pretrained( + str(pretrained_model_name_or_path)) + else: + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + **kwargs, + ) ASYNC_REQUEST_FUNCS = { diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 0a14aedd5feba..77c4f6aa927e4 100644 --- a/benchmarks/benchmark_latency.py +++ 
b/benchmarks/benchmark_latency.py @@ -13,6 +13,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType +from vllm.sampling_params import BeamSearchParams from vllm.utils import FlexibleArgumentParser @@ -40,6 +41,20 @@ def main(args: argparse.Namespace): "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] + def llm_generate(): + if not args.use_beam_search: + llm.generate(dummy_prompts, + sampling_params=sampling_params, + use_tqdm=False) + else: + llm.beam_search( + dummy_prompts, + BeamSearchParams( + beam_width=args.n, + max_tokens=args.output_len, + ignore_eos=True, + )) + def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: with torch.profiler.profile( @@ -49,15 +64,11 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_prompts, - sampling_params=sampling_params, - use_tqdm=False) - print(p.key_averages()) + llm_generate() + print(p.key_averages().table(sort_by="self_cuda_time_total")) else: start_time = time.perf_counter() - llm.generate(dummy_prompts, - sampling_params=sampling_params, - use_tqdm=False) + llm_generate() end_time = time.perf_counter() latency = end_time - start_time return latency diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py new file mode 100644 index 0000000000000..0b8fba38156f1 --- /dev/null +++ b/benchmarks/benchmark_long_document_qa_throughput.py @@ -0,0 +1,183 @@ +""" +Offline benchmark to test the long document QA throughput. + +Example usage: + # This workload samples 8 different prompts with a default input + # length of 20000 tokens, then replicates each prompt 2 times + # in random order. + python benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --repeat-count 2 + +Commandline arguments: + --num-documents: The number of documents to sample prompts from. + + --document-length: The length of each document in tokens. + (Optional, default: 20000) + + --output-len: The number of tokens to generate for each prompt. + (Optional, default: 10) + + --repeat-count: The number of times to repeat each prompt. + (Optional, default: 2) + + --repeat-mode: The mode to repeat prompts. The supported modes are: + - 'random': shuffle the prompts randomly. (Default) + - 'tile': the entire prompt list is repeated in sequence. (Potentially + lowest cache hit) + - 'interleave': each prompt is repeated consecutively before + moving to the next element. (Highest cache hit) + + --shuffle-seed: Random seed when the repeat mode is "random". + (Optional, default: 0) + +In the meantime, it also supports all the vLLM engine args to initialize the +LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more +details. +""" + +import dataclasses +import random +import time + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def test_long_document_qa(llm=None, sampling_params=None, prompts=None): + """ + Test long document QA with the given prompts and sampling parameters. + Print the time spent in processing all the prompts. + + Args: + llm: The language model used for generating responses. + sampling_params: Sampling parameter used to generate the response. 
+ prompts: A list of prompt strings to be processed by the LLM. + """ + start_time = time.time() + llm.generate(prompts, sampling_params=sampling_params) + end_time = time.time() + print(f"Time to execute all requests: {end_time - start_time:.4f} secs") + + +def repeat_prompts(prompts, repeat_count, mode: str): + """ + Repeat each prompt in the list for a specified number of times. + The order of prompts in the output list depends on the mode. + + Args: + prompts: A list of prompts to be repeated. + repeat_count: The number of times each prompt is repeated. + mode: The mode of repetition. Supported modes are: + - 'random': Shuffle the prompts randomly after repetition. + - 'tile': Repeat the entire prompt list in sequence. + Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3]. + - 'interleave': Repeat each prompt consecutively before moving to + the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3]. + + Returns: + A list of repeated prompts in the specified order. + + Raises: + ValueError: If an invalid mode is provided. + """ + print("Repeat mode: ", mode) + if mode == 'random': + repeated_prompts = prompts * repeat_count + random.shuffle(repeated_prompts) + return repeated_prompts + elif mode == 'tile': + return prompts * repeat_count + elif mode == 'interleave': + repeated_prompts = [] + for prompt in prompts: + repeated_prompts.extend([prompt] * repeat_count) + return repeated_prompts + else: + raise ValueError(f"Invalid mode: {mode}, only support " + "'random', 'tile', 'interleave'") + + +def main(args): + random.seed(args.shuffle_seed) + + # Prepare the prompts: + # we append the document id at the beginning to avoid any of the document + # being the prefix of other documents + prompts = [ + str(i) + ' '.join(['hi'] * args.document_length) + for i in range(args.num_documents) + ] + + prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) + + warmup_prompts = [ + "This is warm up request " + str(i) + \ + ' '.join(['hi'] * args.document_length) + for i in range(args.num_documents)] + + # Create the LLM engine + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**dataclasses.asdict(engine_args)) + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + + print("------warm up------") + test_long_document_qa( + llm=llm, + prompts=warmup_prompts, + sampling_params=sampling_params, + ) + + print("------start generating------") + test_long_document_qa( + llm=llm, + prompts=prompts, + sampling_params=sampling_params, + ) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description= + 'Benchmark the performance with or without automatic prefix caching.') + + parser.add_argument( + '--document-length', + type=int, + # Roughly the number of tokens for a system paper, + # excluding images + default=20000, + help='Range of input lengths for sampling prompts,' + 'specified as "min:max" (e.g., "128:256").') + + parser.add_argument('--num-documents', + type=int, + default=8, + help='Range of input lengths for sampling prompts,' + 'specified as "min:max" (e.g., "128:256").') + + parser.add_argument('--output-len', type=int, default=10) + + parser.add_argument('--repeat-count', + type=int, + default=2, + help='Number of times to repeat each prompt') + + parser.add_argument("--repeat-mode", + type=str, + default='random', + help='The mode to repeat prompts. The supported ' + 'modes are "random", "tile", and "interleave". 
' + 'See repeat_prompts() in the source code for details.') + + parser.add_argument("--shuffle-seed", + type=int, + default=0, + help='Random seed when the repeat mode is "random"') + + parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 5e9381f712e10..3ab421a89c935 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -10,7 +10,8 @@ --model meta-llama/Llama-2-7b-chat-hf \ --enable-prefix-caching \ --num-prompts 1 \ - --repeat-count 100 + --repeat-count 100 \ + --input-length-range 128:256 ShareGPT example usage: # This command samples 20 prompts with input lengths diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 4eb0e1f8ac903..8b3212831e7e0 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -25,6 +25,7 @@ import argparse import asyncio import base64 +import gc import io import json import os @@ -199,7 +200,7 @@ def sample_sonnet_requests( return sampled_requests -def sample_mmmu_pro_vision_requests( +def sample_vision_arena_requests( dataset, num_requests: int, tokenizer: PreTrainedTokenizerBase, @@ -211,13 +212,7 @@ def sample_mmmu_pro_vision_requests( if len(sampled_requests) == num_requests: break - # MMMU-Pro vision direct prompt - # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5 - prompt = ( - "Answer with the option letter from the given choices directly. " - "The last line of your response should be of the following " - "format: 'Answer: $LETTER' (without quotes) where LETTER is one of " - "options.") + prompt = data["turns"][0][0]['content'] prompt_token_ids = tokenizer(prompt).input_ids if fixed_output_len is None: @@ -229,10 +224,10 @@ def sample_mmmu_pro_vision_requests( output_len = fixed_output_len assert isinstance( - data["image"], + data["images"][0], Image), ("Input image format must be `PIL.Image.Image`, " f"given {type(data['image'])}.") - image: Image = data["image"] + image: Image = data["images"][0] image = image.convert("RGB") image_data = io.BytesIO() image.save(image_data, format='JPEG') @@ -251,7 +246,7 @@ def sample_mmmu_pro_vision_requests( def sample_hf_requests( dataset_path: str, - dataset_subset: str, + dataset_subset: Optional[str], dataset_split: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, @@ -259,19 +254,17 @@ def sample_hf_requests( fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: - # Special case for MMMU-Pro vision dataset - if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision': - assert dataset_split == "test" + # Special case for vision_arena dataset + if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \ + and dataset_subset is None: + assert dataset_split == "train" dataset = load_dataset(dataset_path, name=dataset_subset, split=dataset_split, streaming=True) - assert "image" in dataset.features, ( - "MMMU/MMMU_Pro vision dataset must have 'image' column.") - filter_func = lambda x: isinstance(x["image"], Image) - dataset = dataset.shuffle(seed=random_seed).filter(filter_func) - return sample_mmmu_pro_vision_requests(dataset, num_requests, - tokenizer, fixed_output_len) + dataset = dataset.shuffle(seed=random_seed) + return sample_vision_arena_requests(dataset, num_requests, tokenizer, + fixed_output_len) dataset = load_dataset(dataset_path, 
name=dataset_subset, @@ -423,7 +416,7 @@ def calculate_metrics( tokenizer: PreTrainedTokenizerBase, selected_percentile_metrics: List[str], selected_percentiles: List[float], - gootput_config_dict: Dict[str, float], + goodput_config_dict: Dict[str, float], ) -> Tuple[BenchmarkMetrics, List[int]]: actual_output_lens: List[int] = [] total_input = 0 @@ -436,19 +429,23 @@ def calculate_metrics( e2els: List[float] = [] for i in range(len(outputs)): if outputs[i].success: - # We use the tokenizer to count the number of output tokens for all - # serving backends instead of looking at len(outputs[i].itl) since - # multiple output tokens may be bundled together - # Note : this may inflate the output token count slightly - output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + output_len = outputs[i].output_tokens + + if output_len is None: + # We use the tokenizer to count the number of output tokens + # for some serving backends instead of looking at + # len(outputs[i].itl) since multiple output tokens may be + # bundled together + # Note : this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, + add_special_tokens=False).input_ids) actual_output_lens.append(output_len) total_input += input_requests[i][1] tpot = 0 if output_len > 1: - tpot = (outputs[i].latency - outputs[i].ttft) / (output_len - - 1) + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) tpots.append(tpot) # Note: if output_len <= 1, we regard tpot as 0 for goodput all_tpots.append(tpot) @@ -459,21 +456,21 @@ def calculate_metrics( else: actual_output_lens.append(0) - if gootput_config_dict: + if goodput_config_dict: valid_metrics = [] slo_values = [] - if "ttft" in gootput_config_dict: + if "ttft" in goodput_config_dict: valid_metrics.append(ttfts) - slo_values.append(gootput_config_dict["ttft"] / + slo_values.append(goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION) - if "tpot" in gootput_config_dict: + if "tpot" in goodput_config_dict: valid_metrics.append(all_tpots) - slo_values.append(gootput_config_dict["tpot"] / + slo_values.append(goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION) - if "e2el" in gootput_config_dict: + if "e2el" in goodput_config_dict: valid_metrics.append(e2els) - slo_values.append(gootput_config_dict["e2el"] / + slo_values.append(goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION) for req_metric in zip(*valid_metrics): @@ -525,6 +522,7 @@ async def benchmark( api_url: str, base_url: str, model_id: str, + model_name: str, tokenizer: PreTrainedTokenizerBase, input_requests: List[Tuple[str, int, int]], logprobs: Optional[int], @@ -536,7 +534,7 @@ async def benchmark( selected_percentile_metrics: List[str], selected_percentiles: List[str], ignore_eos: bool, - gootput_config_dict: Dict[str, float], + goodput_config_dict: Dict[str, float], max_concurrency: Optional[int], ): if backend in ASYNC_REQUEST_FUNCS: @@ -553,6 +551,7 @@ async def benchmark( "Multi-modal content is only supported on 'openai-chat' backend.") test_input = RequestFuncInput( model=model_id, + model_name=model_name, prompt=test_prompt, api_url=api_url, prompt_len=test_prompt_len, @@ -573,6 +572,7 @@ async def benchmark( if profile: print("Starting profiler...") profile_input = RequestFuncInput(model=model_id, + model_name=model_name, prompt=test_prompt, api_url=base_url + "/start_profile", prompt_len=test_prompt_len, @@ -616,6 +616,7 @@ async def 
limited_request_func(request_func_input, pbar): async for request in get_request(input_requests, request_rate, burstiness): prompt, prompt_len, output_len, mm_content = request request_func_input = RequestFuncInput(model=model_id, + model_name=model_name, prompt=prompt, api_url=api_url, prompt_len=prompt_len, @@ -657,7 +658,7 @@ async def limited_request_func(request_func_input, pbar): tokenizer=tokenizer, selected_percentile_metrics=selected_percentile_metrics, selected_percentiles=selected_percentiles, - gootput_config_dict=gootput_config_dict, + goodput_config_dict=goodput_config_dict, ) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) @@ -669,7 +670,7 @@ async def limited_request_func(request_func_input, pbar): metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) - if gootput_config_dict: + if goodput_config_dict: print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", @@ -684,7 +685,7 @@ async def limited_request_func(request_func_input, pbar): "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, "request_goodput:": - metrics.request_goodput if gootput_config_dict else None, + metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], @@ -740,11 +741,11 @@ def process_one_metric( def check_goodput_args(args): # Check and parse goodput arguments - gootput_config_dict = {} + goodput_config_dict = {} VALID_NAMES = ["ttft", "tpot", "e2el"] if args.goodput: - gootput_config_dict = parse_goodput(args.goodput) - for slo_name, slo_val in gootput_config_dict.items(): + goodput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in goodput_config_dict.items(): if slo_name not in VALID_NAMES: raise ValueError( f"Invalid metric name found, {slo_name}: {slo_val}. " @@ -755,22 +756,22 @@ def check_goodput_args(args): f"Invalid value found, {slo_name}: {slo_val}. " "The service level objective value should be " "non-negative.") - return gootput_config_dict + return goodput_config_dict def parse_goodput(slo_pairs): - gootput_config_dict = {} + goodput_config_dict = {} try: for slo_pair in slo_pairs: slo_name, slo_val = slo_pair.split(":") - gootput_config_dict[slo_name] = float(slo_val) + goodput_config_dict[slo_name] = float(slo_val) except ValueError as err: raise argparse.ArgumentTypeError( "Invalid format found for service level objectives. " "Specify service level objectives for goodput as \"KEY:VALUE\" " "pairs, where the key is a metric name, and the value is a " "number in milliseconds.") from err - return gootput_config_dict + return goodput_config_dict def main(args: argparse.Namespace): @@ -780,6 +781,7 @@ def main(args: argparse.Namespace): backend = args.backend model_id = args.model + model_name = args.served_model_name tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_mode = args.tokenizer_mode @@ -869,7 +871,11 @@ def main(args: argparse.Namespace): else: raise ValueError(f"Unknown dataset: {args.dataset_name}") - gootput_config_dict = check_goodput_args(args) + goodput_config_dict = check_goodput_args(args) + + # Avoid GC processing "static" data - reduce pause times. 
+ gc.collect() + gc.freeze() benchmark_result = asyncio.run( benchmark( @@ -877,6 +883,7 @@ def main(args: argparse.Namespace): api_url=api_url, base_url=base_url, model_id=model_id, + model_name=model_name, tokenizer=tokenizer, input_requests=input_requests, logprobs=args.logprobs, @@ -890,7 +897,7 @@ def main(args: argparse.Namespace): float(p) for p in args.metric_percentiles.split(",") ], ignore_eos=args.ignore_eos, - gootput_config_dict=gootput_config_dict, + goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, )) @@ -919,8 +926,8 @@ def main(args: argparse.Namespace): ) # Traffic - result_json["request_rate"] = ( - args.request_rate if args.request_rate < float("inf") else "inf") + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency @@ -1222,5 +1229,12 @@ def main(args: argparse.Namespace): 'always use the slow tokenizer. \n* ' '"mistral" will always use the `mistral_common` tokenizer.') + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ") + args = parser.parse_args() main(args) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index d0353bc8cb42a..b87496ca3b2b4 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -3,7 +3,7 @@ import itertools import pickle as pkl import time -from typing import Callable, Iterable, List, Tuple +from typing import Callable, Iterable, List, Optional, Tuple import torch import torch.utils.benchmark as TBenchmark @@ -12,6 +12,8 @@ from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + w8a8_block_fp8_matmul) from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) @@ -38,8 +40,15 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, ).blocked_autorange(min_run_time=min_run_time) -def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_int8( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: + """Benchmark INT8-based kernels.""" assert dtype == torch.int8 a, b = make_rand_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) @@ -48,155 +57,132 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) + bench_fns = { + "pytorch_bf16_bf16_bf16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), + "cutlass_i8_i8_bf16_scaled_mm": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), + "cutlass_i8_i8_bf16_scaled_mm_bias": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, + bias), + "cutlass_i8_i8_bf16_scaled_mm_azp": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. 
+ bfloat16, azp_adj), + "cutlass_i8_i8_bf16_scaled_mm_azp_bias": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj, None, bias), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj, azp), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": + lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. + bfloat16, azp_adj, azp, bias), + } + timers = [] - # pytorch impl - bfloat16 - timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16), - b.to(dtype=torch.bfloat16))) - - # pytorch impl - float16 - timers.append( - bench_fn(label, sub_label, - "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, - a.to(dtype=torch.float16), b.to(dtype=torch.float16))) - - # cutlass impl - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) - - # cutlass with bias - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, - bias)) - - # cutlass with azp per-tensor - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj)) - - # cutlass with azp per-tensor + bias - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj, None, bias)) - - # cutlass with azp per-token - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj, azp)) - - # cutlass with azp per-token + bias - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj, azp, bias)) + for name, fn in bench_fns.items(): + # If bench_kernels is None, run all. Otherwise, run only exact matches. 
+ if bench_kernels is None or name in bench_kernels: + print(f"Running {name}") + timers.append(bench_fn(label, sub_label, name, fn)) return timers -def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_fp8( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: + """Benchmark FP8-based kernels.""" assert dtype == torch.float8_e4m3fn a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) + a_cont = a.contiguous() scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + block_scale_a = torch.rand((m, k // 128), + device="cuda", + dtype=torch.float32) + block_scale_b = torch.rand((k // 128, n // 128), + device="cuda", + dtype=torch.float32) + block_scale_a_M_major = block_scale_a.t().contiguous().t() + block_scale_b_K_major = block_scale_b.t().contiguous().t() bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) - timers = [] + print(m, k, n) + + bench_fns = { + "pytorch_bf16_bf16_bf16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": + lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), + "pytorch_fp8_fp8_fp16_scaled_mm": + lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16), + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": + lambda: torch._scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.float16, + use_fast_accum=True), + "pytorch_fp8_fp8_bf16_scaled_mm": + lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16), + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": + lambda: torch._scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True), + "cutlass_fp8_fp8_bf16_scaled_mm": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), + "cutlass_fp8_fp8_fp16_scaled_mm": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16), + "cutlass_fp8_fp8_bf16_scaled_mm_bias": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, + bias), + "cutlass_fp8_fp8_fp16_scaled_mm_bias": + lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16, + bias.to(dtype=torch.float16)), + "triton_fp8_fp8_fp16_scaled_mm_blockwise": + lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a, + block_scale_b.t(), (128, 128)), + "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": + lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major, + block_scale_b_K_major, torch.float16), + } - # pytorch impl w. 
bf16 - timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"))) - - # pytorch impl: bf16 output, without fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16)) - - # pytorch impl: bf16 output, with fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True)) - - # pytorch impl: fp16 output, without fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16)) - - # pytorch impl: fp16 output, with fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16, - use_fast_accum=True)) - - # cutlass impl: bf16 output - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) - # cutlass impl: fp16 output - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16)) - - # cutlass impl: bf16 output, with bias - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, - bias)) - - # cutlass impl: fp16 output, with bias - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16, - bias.to(dtype=torch.float16))) + timers = [] + for name, fn in bench_fns.items(): + # If bench_kernels is None, run all. Otherwise, run only exact matches. 
+ if bench_kernels is None or name in bench_kernels: + print(f"Running {name}") + timers.append(bench_fn(label, sub_label, name, fn)) return timers -def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench(dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: if dtype == torch.int8: - return bench_int8(dtype, m, k, n, label, sub_label) + return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) if dtype == torch.float8_e4m3fn: - return bench_fp8(dtype, m, k, n, label, sub_label) + return bench_fp8(dtype, m, k, n, label, sub_label, bench_kernels) raise ValueError("unsupported type") @@ -207,18 +193,22 @@ def print_timers(timers: Iterable[TMeasurement]): def run(dtype: torch.dtype, - MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + MKNs: Iterable[Tuple[int, int, int]], + bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: - timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})") + timers = bench(dtype, + m, + k, + n, + f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + bench_kernels=bench_kernels) print_timers(timers) results.extend(timers) - return results -# output makers def make_output(data: Iterable[TMeasurement], MKNs: Iterable[Tuple[int, int, int]], base_description: str, @@ -232,15 +222,11 @@ def make_output(data: Iterable[TMeasurement], pkl.dump(data, f) -# argparse runners - - def run_square_bench(args): dim_sizes = list( range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args.dtype, MKNs) - + data = run(args.dtype, MKNs, bench_kernels=args.kernels) make_output(data, MKNs, f"square_bench-{args.dtype}") @@ -251,8 +237,7 @@ def run_range_bench(args): Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes MKNs = list(zip(Ms, Ks, Ns)) - data = run(args.dtype, MKNs) - + data = run(args.dtype, MKNs, bench_kernels=args.kernels) make_output(data, MKNs, f"range_bench-{args.dtype}") @@ -278,7 +263,7 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: for k, n in KNs: MKNs.append((m, k, n)) - data = run(args.dtype, MKNs) + data = run(args.dtype, MKNs, bench_kernels=args.kernels) model_bench_data.append(data) # Print all results @@ -328,6 +313,15 @@ def to_torch_dtype(dt): type=to_torch_dtype, required=True, help="Available options are ['int8', 'fp8']") + parser.add_argument( + "--kernels", + nargs="+", + type=str, + default=None, + help= + "Exact names of the kernels to benchmark. If not set, runs all kernels." 
+ ) + subparsers = parser.add_subparsers(dest="cmd") square_parser = subparsers.add_parser("square_bench") @@ -362,4 +356,4 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() - args.func(args) \ No newline at end of file + args.func(args) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py new file mode 100644 index 0000000000000..e1f613e1da509 --- /dev/null +++ b/benchmarks/kernels/benchmark_lora.py @@ -0,0 +1,1147 @@ +import argparse +import copy +import json +import pickle +import time +from dataclasses import dataclass +from enum import Enum, auto +from itertools import product +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import ArgPool, Bench, CudaGraphBenchParams +from weight_shapes import WEIGHT_SHAPES + +from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand +from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice +from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink +from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand +from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_TP_SIZES = [1] +DEFAULT_BATCH_SIZES = [ + 1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 768, 896, 1024, + 2048, 3072, 4096, 5120, 6144, 7168, 8192 +] +DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384] +DEFAULT_LORA_RANKS = [16] +DEFAULT_NUM_LORAS = [1, 2, 3, 4] +DEFAULT_SORT_BY_LORA_IDS = [False, True] +DEFAULT_SEQ_LENGTHS = [1] +DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False] + + +# Utilities +def dtype_to_str(dtype: torch.dtype): + if dtype == torch.float16: + return "f16" + if dtype == torch.bfloat16: + return "bf16" + if dtype == torch.float32: + return "f32" + raise ValueError(f"Unsupported dtype {dtype}") + + +def make_rand_lora_weight_tensor(k: int, + n: int, + num_loras: int, + dtype: torch.dtype, + device: str = "cuda") -> torch.Tensor: + + # LoRA weights column major + return torch.rand((num_loras, n, k), dtype=dtype).to(device) + + +def make_rand_tensors( + a_shape: Tuple[int], + b_shape: Tuple[int], + c_shape: Tuple[int], + a_dtype: torch.dtype, + b_dtype: torch.dtype, + c_dtype: torch.dtype, + num_slices: int, + device: str = "cuda", +) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: + """ + Make LoRA input/output matrices. + """ + A = torch.rand(a_shape, dtype=a_dtype).to(device) + + # LoRA weights column major + Bs = [ + torch.rand(b_shape, dtype=b_dtype).to(device) + for _ in range(num_slices) + ] + + C = torch.zeros(c_shape, dtype=c_dtype).to(device) + return A, Bs, C + + +def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, + sort_by_lora_id: bool, + device: str) -> torch.Tensor: + """ + All prompts are mapped to a Lora ID in range [0, num_active_loras). + where 0 refers to first lora, 1 refers to second lora and so on. + """ + assert num_active_loras > 0 + + if not sort_by_lora_id: + return torch.randint(0, + num_active_loras, (num_prompts, ), + dtype=torch.long) + + # Divide LoRAs equally and in order. 
+ part_size = num_prompts // num_active_loras + part_size = max(part_size, 1) + + lora_id = 0 + prompt_lora_mapping = [] + while len(prompt_lora_mapping) < num_prompts: + prompt_lora_mapping.extend([lora_id] * part_size) + lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id + return torch.tensor(prompt_lora_mapping[:num_prompts], + dtype=torch.long, + device=device) + + +def make_token_lora_mapping(num_tokens: int, num_prompts: int, + prompt_lora_mapping: torch.Tensor, + seq_len_tensor: torch.Tensor, device: str): + """ + Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor + """ + assert prompt_lora_mapping.shape[0] == num_prompts + + # token to lora index mapping + token_lora_mapping = [0] * num_tokens + current_offset = 0 + for b_id in range(num_prompts): + lora_index = prompt_lora_mapping[b_id].item() + s = current_offset + e = s + seq_len_tensor[b_id].item() + token_lora_mapping[s:e] = [lora_index] * (e - s) + current_offset += seq_len_tensor[b_id].item() + + return torch.tensor(token_lora_mapping, dtype=torch.long, device=device) + + +def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, + lora_weights: List[torch.Tensor], + seq_lens_cpu: torch.Tensor, + prompt_lora_mapping_cpu: torch.Tensor, scaling: float, + add_inputs: Optional[bool]): + """ + Torch group gemm reference implementation to test correctness of + benchmarking operations. + """ + batches = seq_lens_cpu.size(0) + out_list = [] + current_offset = 0 + for lora_index, b_length in zip(range(batches), seq_lens_cpu): + x = input[current_offset:b_length + current_offset, :] + current_offset += b_length + w = lora_weights[prompt_lora_mapping_cpu[lora_index]] + result = torch.nn.functional.linear(x, w) + result *= scaling + out_list.append(result) + torch.cat(out_list, dim=0) + + cat_result = torch.cat(out_list, dim=0) + + if add_inputs: + ref_out += cat_result + else: + ref_out.copy_(cat_result) + + +class OpType(Enum): + """ + LoRA Ops to benchmark and its properties. 
+ """ + SGMV_SHRINK = auto() + BGMV_SHRINK = auto() + SGMV_EXPAND = auto() + BGMV_EXPAND = auto() + BGMV_EXPAND_SLICE = auto() + + @staticmethod + def from_str(s: str) -> "OpType": + if s.lower() == 'sgmv_shrink': + return OpType.SGMV_SHRINK + if s.lower() == 'sgmv_expand': + return OpType.SGMV_EXPAND + if s.lower() == 'bgmv_shrink': + return OpType.BGMV_SHRINK + if s.lower() == 'bgmv_expand': + return OpType.BGMV_EXPAND + if s.lower() == "bgmv_expand_slice": + return OpType.BGMV_EXPAND_SLICE + raise ValueError(f"Unrecognized str {s} to convert to OpType") + + def is_shrink_fn(self) -> bool: + return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK] + + def is_expand_fn(self) -> bool: + return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND] + + def is_prefill_op(self) -> bool: + return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND] + + def is_decode_op(self) -> bool: + return self in [ + OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE + ] + + def is_expand_slice_fn(self) -> bool: + return self in [OpType.BGMV_EXPAND_SLICE] + + def num_slices(self) -> List[int]: + if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]: + # SGMV kernels supports slices + return [1, 2, 3] + if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]: + return [1] + if self in [OpType.BGMV_EXPAND_SLICE]: + return [2, 3] + raise ValueError(f"Unrecognized OpType {self}") + + def mkn(self, batch_size: int, seq_length: int, hidden_size: int, + lora_rank: int) -> Tuple[int, int, int]: + num_tokens = batch_size * seq_length + if self.is_shrink_fn(): + m = num_tokens + k = hidden_size + n = lora_rank + else: + assert self.is_expand_fn() or self.is_expand_slice_fn() + m = num_tokens + k = lora_rank + n = hidden_size + return m, k, n + + def matmul_dtypes( + self, op_dtype: torch.dtype + ) -> Tuple[torch.dtype, torch.dtype, torch.dtype]: + """ + return a type, b type and c type for A x B = C + """ + if self.is_shrink_fn(): + return op_dtype, op_dtype, torch.float32 + else: + assert self.is_expand_fn() or self.is_expand_slice_fn() + return torch.float32, op_dtype, op_dtype + + def matmul_shapes( + self, batch_size: int, seq_length: int, hidden_size: int, + lora_rank: int, num_loras: int, + num_slices: int) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]: + """ + Given num_slices, return the shapes of the A, B, and C matrices + in A x B = C, for the op_type + """ + m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank) + + b_shape = (num_loras, n, k) # col-major + if self == OpType.SGMV_SHRINK: + # SGMV shrink supports num_slices inherently in the kernel + return ((m, k), b_shape, (num_slices, m, n)) + if self == OpType.SGMV_EXPAND: + # SGMV expand supports num_slices inherently in the kernel + return ((num_slices, m, k), b_shape, (m, n * num_slices)) + if self == OpType.BGMV_SHRINK: + return ((m, k), b_shape, (m, n)) + if self == OpType.BGMV_EXPAND: + return ((m, k), b_shape, (m, n)) + if self == OpType.BGMV_EXPAND_SLICE: + return ((num_slices, m, k), b_shape, (m, n * num_slices)) + + raise ValueError(f"Unrecognized op_type {self}") + + def bench_fn(self) -> Callable: + + def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]): + for x in kwargs_list: + bgmv_expand_slice(**x) + + if self == OpType.SGMV_SHRINK: + return sgmv_shrink + if self == OpType.SGMV_EXPAND: + return sgmv_expand + if self == OpType.BGMV_SHRINK: + return bgmv_shrink + if self == OpType.BGMV_EXPAND: + return bgmv_expand + if self == OpType.BGMV_EXPAND_SLICE: + return emulate_bgmv_expand_slice + raise 
ValueError(f"Unrecognized optype {self}") + + def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor, + lora_weights: List[torch.Tensor], + **kwargs) -> Callable: + """Each benchmark operation expected the input, lora_weights and outputs + in a slightly different format. Refer to self.matmul_shapes(). + run_ref_group_gemm accounts for those differences in executing a + reference group gemm for correctness testing. + """ + w_dtype = lora_weights[0].dtype + num_slices = len(lora_weights) + if self == OpType.SGMV_SHRINK: + for slice_idx in range(num_slices): + ref_group_gemm(ref_out=output[slice_idx, :], + input=input, + lora_weights=lora_weights[slice_idx], + **kwargs) + if self == OpType.SGMV_EXPAND: + hidden_size = lora_weights[0].shape[1] + for slice_idx in range(num_slices): + slice_offset = slice_idx * hidden_size + ref_group_gemm( + ref_out=output[:, slice_offset:slice_offset + hidden_size], + input=input[slice_idx].clone().to(dtype=w_dtype), + lora_weights=lora_weights[slice_idx], + **kwargs) + if self == OpType.BGMV_SHRINK: + assert num_slices == 1 + ref_group_gemm(ref_out=output, + input=input, + lora_weights=lora_weights[0], + **kwargs) + if self == OpType.BGMV_EXPAND: + assert num_slices == 1 + ref_group_gemm(ref_out=output, + input=input.clone().to(dtype=w_dtype), + lora_weights=lora_weights[0], + **kwargs) + if self == OpType.BGMV_EXPAND_SLICE: + hidden_size = lora_weights[0].shape[1] + for slice_idx in range(num_slices): + slice_offset = slice_idx * hidden_size + ref_group_gemm( + ref_out=output[:, slice_offset:slice_offset + hidden_size], + input=input[slice_idx].clone().to(dtype=w_dtype), + lora_weights=lora_weights[slice_idx], + **kwargs) + raise ValueError(f"Unrecognized optype {self}") + + +@dataclass +class BenchmarkContext: + """ + LoRA benchmark context + """ + batch_size: int + hidden_size: int + num_loras: int + num_active_loras: int + lora_rank: int + sort_by_lora_id: bool + dtype: torch.dtype + seq_length: Optional[int] = None + num_slices: Optional[int] = None # num_slices for slice based ops + + def with_seq_length(self, seq_length: int) -> "BenchmarkContext": + ctx = copy.copy(self) + ctx.seq_length = seq_length + return ctx + + def with_num_slices(self, num_slices: int) -> "BenchmarkContext": + ctx = copy.copy(self) + ctx.num_slices = num_slices + return ctx + + def bench_label(self) -> str: + return f"lora-{self.dtype}" + + def bench_sublabel(self, op_type: OpType) -> str: + m, k, n = op_type.mkn(self.batch_size, self.seq_length, + self.hidden_size, self.lora_rank) + desc = { + 'bs': self.batch_size, + 'sl': self.seq_length, + 'm': m, + 'k': k, + 'n': n, + 'num_loras': self.num_loras, + 'sort_by_lora': self.sort_by_lora_id, + 'num_slices': self.num_slices, + } + return json.dumps(desc) + + +@dataclass +class BenchmarkTensors: + """ + Input/Output tensors used for benchmarks + """ + # matmul tensors + input: torch.Tensor + lora_weights_lst: List[torch.Tensor] + output: torch.Tensor + # metadata tensors + seq_lens: torch.Tensor + seq_start_loc: torch.Tensor + prompt_lora_mapping: torch.Tensor + token_lora_mapping: torch.Tensor + + def io_types(self) -> str: + return (f"{dtype_to_str(self.input.dtype)}x" + f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>" + f"{dtype_to_str(self.output.dtype)}") + + @staticmethod + def make(ctx: BenchmarkContext, + op_type: OpType, + device: str = "cuda") -> "BenchmarkTensors": + + # Make input / output matmul tensors. 
+ a_shape, b_shape, c_shape = op_type.matmul_shapes( + ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank, + ctx.num_loras, ctx.num_slices) + a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) + input_tensor, lora_weights, output_tensor = \ + make_rand_tensors(a_shape, b_shape, c_shape, a_type, b_type, c_type, + num_slices = ctx.num_slices) + + # Make metadata tensors. + # Keep the metadata tensors in the CPU for further processing if needed. + # The tensors get moved to the GPU before benchmarking. + assert ctx.num_active_loras <= ctx.num_loras + total_tokens = ctx.batch_size * ctx.seq_length + + # Prepare seq lens tensor + seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1, + (ctx.batch_size, )) + # Prepare seq_start_loc tensor + seq_start_loc_tensor = torch.cumsum(torch.tensor( + [0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), + dim=0) + assert total_tokens == seq_len_tensor.sum() + # Prepare prompt lora indices tensor + prompt_lora_indices_tensor = make_prompt_lora_mapping( + ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu") + # Prepare token lora indices tensor + token_lora_indices_tensor = make_token_lora_mapping( + total_tokens, ctx.batch_size, prompt_lora_indices_tensor, + seq_len_tensor, "cpu") + + return BenchmarkTensors(input_tensor, lora_weights, output_tensor, + seq_len_tensor, seq_start_loc_tensor, + prompt_lora_indices_tensor, + token_lora_indices_tensor) + + def sanity_check(self) -> None: + """ + Fails asserts when non-conformality is detected. + """ + num_tokens = self.input.shape[-2] + # check metadata tensors + assert torch.sum(self.seq_lens) == num_tokens + num_seqs = self.seq_lens.shape[0] + assert self.seq_start_loc.shape[0] == num_seqs + assert self.prompt_lora_mapping.shape[0] == num_seqs + assert self.token_lora_mapping.shape[0] == num_tokens + + def to_device(self, device: str): + """ + Transfer tensors to device if the tensors aren't already on the device + """ + + def to_device(tensor: torch.Tensor): + if tensor.device != device: + tensor = tensor.to(device=device) + return tensor + + self.input = to_device(self.input) + self.output = to_device(self.output) + self.seq_lens = to_device(self.seq_lens) + self.seq_start_loc = to_device(self.seq_start_loc) + self.prompt_lora_mapping = to_device(self.prompt_lora_mapping) + self.token_lora_mapping = to_device(self.token_lora_mapping) + for i in range(len(self.lora_weights_lst)): + self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i]) + + def metadata(self) -> Tuple[int, int, int]: + """ + Return num_seqs, num_tokens and max_seq_len + """ + num_seqs = self.seq_lens.shape[0] + num_tokens = self.token_lora_mapping.shape[0] + max_seq_len = torch.max(self.seq_lens).item() + num_slices = len(self.lora_weights_lst) + return num_seqs, num_tokens, max_seq_len, num_slices + + def convert_to_sgmv_benchmark_tensors(self): + """ + For sgmv punica kernels, when consecutive sequences have the + same LoRA ID, we just merge them together. 
+ This happens in punica.py::compute_metadata + """ + + # Collapse seq_lens and seq_start_loc + _, seq_lens = torch.unique_consecutive(self.token_lora_mapping, + return_counts=True) + cum_result = torch.cumsum(seq_lens, dim=0) + seq_start_loc = torch.zeros_like(seq_lens) + seq_start_loc[1:].copy_(cum_result[:-1]) + + # Collapse prompt mapping + prompt_lora_mapping = torch.unique_consecutive( + self.prompt_lora_mapping) + + assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \ + f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}" + + self.prompt_lora_mapping = prompt_lora_mapping.to( + dtype=self.prompt_lora_mapping.dtype) + self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype) + self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype) + + def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]: + self.convert_to_sgmv_benchmark_tensors() + self.sanity_check() + self.to_device(self.input.device) + + num_seqs, num_tokens, max_seq_len, num_slices = self.metadata() + + # Sanity check matrix shapes. + i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape [num_tokens, hidden_size] + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [num_loras, lora_rank, hidden_size] + assert len(lw_shape) == 3 + assert lw_shape[2] == hidden_size + lora_rank = lw_shape[1] + # Expected output shape [num_slices, num_tokens, lora_rank] + assert len(o_shape) == 3 + assert o_shape == (num_slices, num_tokens, lora_rank) + + return { + 'inputs': self.input, + 'lora_a_weights': self.lora_weights_lst, + 'output_tensor': self.output, + 'b_seq_start_loc': self.seq_start_loc, + 'seq_len_tensor': self.seq_lens, + 'lora_indices_tensor': self.prompt_lora_mapping, + 'batches': num_seqs, + 'max_seq_length': max_seq_len, + 'token_nums': num_tokens, + 'scaling': 1.0, + } + + def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]: + + self.convert_to_sgmv_benchmark_tensors() + self.sanity_check() + self.to_device(self.input.device) + + num_seqs, num_tokens, max_seq_len, num_slices = self.metadata() + + # Sanity check matrix shapes. 
+ i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape : [num_slices, num_tokens, lora_rank] + assert len(i_shape) == 3 + assert i_shape[0] == num_slices + assert i_shape[1] == num_tokens + lora_rank = i_shape[2] + # Expected lora weight shape : [num_lora, hidden_size, lora_rank] + assert len(lw_shape) == 3 + assert lw_shape[2] == lora_rank + hidden_size = lw_shape[1] + # Expected output shape : [num_tokens, hidden_size * num_slices] + assert len(o_shape) == 2 + assert o_shape == (num_tokens, hidden_size * num_slices) + + return { + 'inputs': self.input, + 'lora_b_weights': self.lora_weights_lst, + 'output_tensor': self.output, + 'b_seq_start_loc': self.seq_start_loc, + 'seq_len_tensor': self.seq_lens, + 'lora_indices_tensor': self.prompt_lora_mapping, + 'batches': num_seqs, + 'max_seq_length': max_seq_len, + 'token_nums': num_tokens, + 'offset_start': 0, + 'add_inputs': add_inputs, + } + + def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]: + assert len(self.lora_weights_lst) == 1 + self.to_device(self.input.device) + + _, num_tokens, _, _ = self.metadata() + # Sanity check shapes + i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape [num_tokens, hidden_size] + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [num_loras, lora_rank, hidden_size] + assert len(lw_shape) == 3 + assert lw_shape[2] == hidden_size + lora_rank = lw_shape[1] + # Expected output shape [num_tokens, lora_rank] + assert len(o_shape) == 2 + assert o_shape == (num_tokens, lora_rank) + + return { + 'inputs': self.input, + 'lora_a_weights': self.lora_weights_lst[0], + 'output_tensor': self.output, + 'lora_indices_tensor': self.token_lora_mapping, + 'scaling': 1.0 + } + + def as_bgmv_expand_kwargs(self, add_inputs: bool): + assert len(self.lora_weights_lst) == 1 + self.to_device(self.input.device) + + _, num_tokens, _, _ = self.metadata() + # Sanity check shapes + i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape [num_tokens, lora_rank] + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + lora_rank = i_shape[1] + # Expected lora weight shape [num_loras, hidden_size, lora_rank] + assert len(lw_shape) == 3 + assert lw_shape[2] == lora_rank + hidden_size = lw_shape[1] + # Expected output shape [num_tokens, hidden_size] + assert len(o_shape) == 2 + assert o_shape == (num_tokens, hidden_size) + + return { + 'inputs': self.input, + 'lora_b_weights': self.lora_weights_lst[0], + 'output_tensor': self.output, + 'lora_indices_tensor': self.token_lora_mapping, + 'add_inputs': add_inputs + } + + def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]: + + _, num_tokens, _, num_slices = self.metadata() + # Sanity check shapes + i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape [num_slices, num_tokens, lora_rank] + assert len(i_shape) == 3 + assert i_shape[0] == num_slices + assert i_shape[1] == num_tokens + lora_rank = i_shape[2] + # Expected lora weight shape [num_loras, hidden_size, lora_rank] + assert len(lw_shape) == 3 + assert lw_shape[2] == lora_rank + hidden_size = lw_shape[1] + # Expected output shape [num_tokens, hidden_size * num_slices] + assert len(o_shape) == 2 + assert o_shape == (num_tokens, hidden_size * num_slices) + + 
self.to_device(self.input.device) + + kwargs_list = [] + for i in range(num_slices): + kwargs_list.append({ + 'inputs': self.input[i], + 'lora_b_weights': self.lora_weights_lst[i], + 'output_tensor': self.output, + 'lora_indices_tensor': self.token_lora_mapping, + 'slice_offset': i * hidden_size, + 'slice_size': hidden_size, + 'add_inputs': add_inputs, + }) + return {'kwargs_list': kwargs_list} + + def bench_fn_kwargs(self, + op_type: OpType, + add_inputs: Optional[bool] = None) -> Dict[str, Any]: + if op_type.is_shrink_fn(): + assert add_inputs is None + else: + assert add_inputs is not None + + if op_type == OpType.SGMV_SHRINK: + return self.as_sgmv_shrink_kwargs() + if op_type == OpType.SGMV_EXPAND: + return self.as_sgmv_expand_kwargs(add_inputs) + if op_type == OpType.BGMV_SHRINK: + return self.as_bgmv_shrink_kwargs() + if op_type == OpType.BGMV_EXPAND: + return self.as_bgmv_expand_kwargs(add_inputs) + if op_type == OpType.BGMV_EXPAND_SLICE: + return self.as_bgmv_expand_slice_kwargs(add_inputs) + raise ValueError(f"Unrecognized optype {self}") + + def test_correctness(self, op_type: OpType, + expand_fn_add_inputs: Optional[bool]) -> bool: + """ + Test correctness of op_type implementation against a grouped gemm + reference implementation. + """ + seq_lens_cpu = self.seq_lens.to(device="cpu") + prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu") + ref_output = self.output.clone() + + self.output.zero_() + op_type.bench_fn()( + **self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) + + op_type.run_ref_group_gemm( + ref_output, + self.input, + self.lora_weights_lst, + seq_lens_cpu=seq_lens_cpu, + prompt_lora_mapping_cpu=prompt_lora_mapping_cpu, + scaling=1.0, + add_inputs=expand_fn_add_inputs) + + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[self.output.dtype] + + return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol) + + +def bench_optype(ctx: BenchmarkContext, + arg_pool_size: int, + op_type: OpType, + cuda_graph_nops: Optional[int] = None, + expand_fn_add_inputs: Optional[bool] = None, + test_correctness: bool = False) -> TMeasurement: + + assert arg_pool_size >= 1 + if op_type.is_shrink_fn(): + assert expand_fn_add_inputs is None + else: + assert expand_fn_add_inputs is not None + + # BenchmarkContext -> BenchmarkTensors + bench_tensors : List[BenchmarkTensors] = \ + [BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)] + for bt in bench_tensors: + bt.sanity_check() + + # Test correctness of our implementation. + if test_correctness: + assert all([ + bt.test_correctness(op_type, expand_fn_add_inputs) + for bt in bench_tensors + ]) + + # BenchmarkTensors -> Dict (kwargs) + kwargs_list = [ + bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) + for bt in bench_tensors + ] + + # Clear LoRA optimization hash-maps. 
+ _LORA_A_PTR_DICT.clear() + _LORA_B_PTR_DICT.clear() + # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup + for kwargs in kwargs_list: + op_type.bench_fn()(**kwargs) + torch.cuda.synchronize() + + # Merge into a single kwargs and qualify arguments as ArgPool + kwargs = {k: ArgPool([]) for k in kwargs_list[0]} + for _kwargs in kwargs_list: + for k, v in _kwargs.items(): + kwargs[k].values.append(v) + + describe_args = (f"add_inputs={expand_fn_add_inputs}" + if expand_fn_add_inputs is not None else "") + description = ( + f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})") + + cuda_graph_params = None + if cuda_graph_nops: + cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) + timer = None + with Bench(cuda_graph_params, + ctx.bench_label(), ctx.bench_sublabel(op_type), description, + op_type.bench_fn(), **kwargs) as bench: + timer = bench.run() + return timer + + +def bench_torch_mm(ctx: BenchmarkContext, + arg_pool_size: int, + op_type: OpType, + cuda_graph_nops: Optional[int] = None) -> TMeasurement: + """ + Benchmark basic torch.mm as a roofline. + + When all the input tokens have the same LoRA ID, the LoRA kernels are just + a matmul. This torch.mm benchmark serves as a roofline for that case. + + input op_type is used in determining the m, k, n dimensions for the matmul. + """ + + batch_size, hidden_size, lora_rank, seq_length, dtype = (ctx.batch_size, + ctx.hidden_size, + ctx.lora_rank, + ctx.seq_length, + ctx.dtype) + + m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank) + # For a fairer comparison. + n = n * ctx.num_slices + + # Get matmul input and output tensors for A x B = C + As, Bs, Cs = [], [], [] + for _ in range(arg_pool_size): + As.append(torch.rand((m, k), dtype=dtype).to("cuda")) + Bs.append(torch.rand((n, k), dtype=dtype).to("cuda").t()) + Cs.append(torch.rand((m, n), dtype=dtype).to("cuda")) + + # Make torch.mm kwargs + mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)} + + description = ( + f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}" + f"x{dtype_to_str(dtype)}" + f"=>{dtype_to_str(dtype)})") + cuda_graph_params = None + if cuda_graph_nops: + cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) + with Bench(cuda_graph_params, ctx.bench_label(), + ctx.bench_sublabel(op_type), description, torch.mm, + **mm_kwargs) as bench: + return bench.run() + + +# runner +def use_cuda_graph_recommendation() -> str: + return """ + Triton kernels have a significant launch overhead with + launched directly via python. This overhead is more noticeable + for small the problem sizes. For these cases, it is recommended + to use the script with `--cuda-graph-nops N` to benchmark N + consecutive invocations of the benchmarking operations from + inside a CUDA Graph. Note that the returned measurement is for N + invocations of the operation. + """ + + +def print_timers(timers: List[TMeasurement], + args: Optional[argparse.Namespace] = None): + compare = TBenchmark.Compare(timers) + compare.print() + + if args and args.cuda_graph_nops: + print( + f"Note : The timings reported above is for {args.cuda_graph_nops} " + "consecutive invocations of the benchmarking functions. " + f"Please divide by {args.cuda_graph_nops} for single invocation " + "timings.") + + print("Note on Comparison with torch.mm : The torch.mm numbers are " + "benchmark numbers of a simple matmul emulating the single lora " + "case. It is provided as a roofline for comparing our LoRA Kernel " + "implementations. 
It is expected that the LoRA kernels will be " + "slower than torch.mm in cases where num_loras is big. But for " + "small num_loras the goal should be to match the torch.mm numbers.") + + +def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]): + + if args.cuda_graph_nops is not None: + assert args.cuda_graph_nops > 0 + print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA " + "Graph") + else: + print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}") + + timers = [] + for bench_ctx in bench_ctxs: + for seq_len in args.seq_lengths: + bench_ops: List[OpType] = [] + if seq_len == 1: + # bench all decode ops + bench_ops = [op for op in args.op_types if op.is_decode_op()] + else: + # bench all prefill ops + bench_ops = [op for op in args.op_types if op.is_prefill_op()] + + seq_len_timers = [] + for bench_op in bench_ops: + for num_slices in bench_op.num_slices(): + _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices( + num_slices) + # Benchmark torch.mm as a roofline + seq_len_timers.append( + bench_torch_mm(_ctx, args.arg_pool_size, bench_op, + args.cuda_graph_nops)) + + # Benchmark bench_op + expand_fn_add_inputs = [ + None + ] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + for add_input_arg in expand_fn_add_inputs: + seq_len_timers.append( + bench_optype(_ctx, args.arg_pool_size, bench_op, + args.cuda_graph_nops, add_input_arg, + args.test_correctness)) + + print_timers(seq_len_timers) + timers.extend(seq_len_timers) + + # Result stdout dump + print("== All Results ====") + print_timers(timers, args) + + if args.output_directory: + # Result file dump + od = Path(args.output_directory) + if not od.exists(): + od.mkdir() + + timestamp = int(time.time()) + pkl_file = od / f"lora_bench-{timestamp}.pkl" + print(f"Writing benchmarks to {pkl_file}") + with open(pkl_file, "wb") as f: + pickle.dump(timers, f) + + +def as_benchmark_contexts(hidden_sizes: List[int], lora_ranks: List[int], + args: argparse.Namespace) -> List[BenchmarkContext]: + + ctxs: List[BenchmarkContext] = [] + for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa + args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, + args.sort_by_lora_id): + ctxs.append( + BenchmarkContext( + batch_size=batch_size, + hidden_size=hidden_size, + lora_rank=lora_rank, + num_loras=num_loras, + num_active_loras=args.num_active_loras + if args.num_active_loras else num_loras, + # To be filled based on the OpType to benchmark + seq_length=None, + sort_by_lora_id=sort_by_lora_id, + dtype=args.dtype, + # To be filled based on the OpType to benchmark + num_slices=None)) + + return ctxs + + +def run_list_bench(args: argparse.Namespace): + print(args) + + print("List bench :\n" + f" Hidden Sizes {args.hidden_sizes}" + f" LoRA Ranks {args.lora_ranks}") + + # Get all benchmarking contexts + bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( + hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args) + + run(args, bench_contexts) + + +def run_range_bench(args: argparse.Namespace): + print(args) + + hidden_sizes = list( + range(args.hidden_sizes_start, args.hidden_sizes_end + 1, + args.hidden_sizes_increment)) + lora_ranks = list( + range(args.lora_ranks_start, args.lora_ranks_end + 1, + args.lora_ranks_increment)) + + print("Range bench :\n" + f" Hidden Sizes {hidden_sizes}" + f" LoRA Ranks {lora_ranks}") + + # Get all benchmarking contexts + bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( + 
hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args) + + run(args, bench_contexts) + + +def run_model_bench(args: argparse.Namespace): + print(args) + + def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]: + hidden_sizes = set() + for KN, tp_split_dim in WEIGHT_SHAPES[model]: + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + hidden_sizes.add(KN[1]) + return hidden_sizes + + # Get all hidden sizes + hidden_sizes: set[int] = set() + for model_name, tp_size in product(args.models, args.tp_sizes): + hidden_sizes = hidden_sizes.union( + hidden_sizes_from_model(model_name, tp_size)) + + print("Model bench :\n" + f" Hidden Sizes {hidden_sizes}" + f" LoRA Ranks {args.lora_ranks}") + + # Get all benchmarking contexts + bench_contexts: List[BenchmarkContext] = as_benchmark_contexts( + hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args) + + run(args, bench_contexts) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "torch.float16": + return torch.float16 + if dt == "torch.bfloat16": + return torch.bfloat16 + raise ValueError("unsupported dtype") + + def get_bool(s: str) -> bool: + return s.lower() in ['true', '1'] + + def add_common_command_args(p: argparse.ArgumentParser): + p.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['torch.float16', 'torch.bfloat16']") + + p.add_argument( + "--arg-pool-size", + type=int, + default=32, + help="Run profiles with a pool of input/output/meta tensors instead" + "of simply reusing the same tensors for all runs. A bigger arg-pool" + "mitigates hardware caching effects during benchmarking.") + + p.add_argument( + "--cuda-graph-nops", + type=int, + help=("when set profiling is done using cudagraph, " + "with the given number of operations in a graph." + "Note that the measurement returned is the time " + "taken for N consecutive executions of the benchmarking " + "functions, where N is the value of this argument.")) + p.add_argument("--num-loras", + nargs="+", + type=int, + default=DEFAULT_NUM_LORAS) + p.add_argument("--num-active-loras", + type=int, + default=None, + help="Active LoRAs. 
When None, all LoRAs are active") + p.add_argument("--sort-by-lora-id", + nargs="+", + type=get_bool, + default=DEFAULT_SORT_BY_LORA_IDS) + p.add_argument("--op-types", + nargs="+", + type=OpType.from_str, + default=list(OpType)) + p.add_argument('--seq-lengths', + nargs="+", + type=int, + default=DEFAULT_SEQ_LENGTHS) + p.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + p.add_argument("--expand-fn-add-inputs", + nargs="+", + type=get_bool, + default=DEFAULT_EXPAND_FN_ADD_INPUTS) + p.add_argument( + '-o', + '--output-directory', + type=str, + help=("Output directory to store a the list of benchmarking" + "TMeasurement objects as a pickle file")) + + p.add_argument( + "--test-correctness", + action='store_true', + help=("When enabled, the benchmarking functions are tested" + "for correctness before the actual benchmarking")) + + parser = FlexibleArgumentParser( + description=f""" +Benchmark LoRA kernels: + {use_cuda_graph_recommendation()} + + list_bench example: + python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 + + model_bench example: + python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 + + range_bench example: + python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + subparsers = parser.add_subparsers(dest="cmd", required=True) + + list_parser = subparsers.add_parser("list_bench") + list_parser.add_argument("--hidden-sizes", + nargs="+", + type=int, + default=DEFAULT_HIDDEN_SIZES) + list_parser.add_argument("--lora-ranks", + nargs="+", + type=int, + default=DEFAULT_LORA_RANKS) + add_common_command_args(list_parser) + list_parser.set_defaults(func=run_list_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--hidden-sizes-start", type=int, required=True) + range_parser.add_argument("--hidden-sizes-end", type=int, required=True) + range_parser.add_argument("--hidden-sizes-increment", + type=int, + required=True) + range_parser.add_argument("--lora-ranks-start", type=int, required=True) + range_parser.add_argument("--lora-ranks-end", type=int, required=True) + range_parser.add_argument("--lora-ranks-increment", + type=int, + required=True) + add_common_command_args(range_parser) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--lora-ranks", + nargs="+", + type=int, + 
default=DEFAULT_LORA_RANKS) + add_common_command_args(model_parser) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 8f538c21f7f7e..068830f02fb5e 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -1,6 +1,7 @@ import argparse import time from datetime import datetime +from itertools import product from typing import Any, Dict, List, Tuple, TypedDict import ray @@ -13,6 +14,9 @@ from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser +FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm( +) else torch.float8_e4m3fn + class BenchmarkConfig(TypedDict): BLOCK_SIZE_M: int @@ -80,8 +84,8 @@ def benchmark_config( a1_scale = torch.randn(1, dtype=torch.float32) a2_scale = torch.randn(1, dtype=torch.float32) - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) + w1 = w1.to(FP8_DTYPE) + w2 = w2.to(FP8_DTYPE) input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) @@ -141,28 +145,172 @@ def run(): return avg -def get_configs_compute_bound() -> List[Dict[str, int]]: - # Reduced search space for faster tuning. - # TODO(woosuk): Increase the search space and use a performance model to - # prune the search space. +def get_rocm_tuning_space(use_fp16): + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + if not use_fp16: + block_k_range.remove(16) # BLOCK_K=16 not supported for fp8 + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + num_stage_range = [2] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] if use_fp16 else [] + kpack_range = [1, 2] if use_fp16 else [] + + param_ranges = { + "BLOCK_SIZE_M": block_mn_range, + "BLOCK_SIZE_N": block_mn_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + "waves_per_eu": waves_per_eu_range, + } + if use_fp16: + param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range + param_ranges["kpack"] = kpack_range + + return param_ranges + + +def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]: configs: List[BenchmarkConfig] = [] - for num_stages in [2, 3, 4, 5]: - for block_m in [16, 32, 64, 128, 256]: - for block_k in [64, 128, 256]: - for block_n in [32, 64, 128, 256]: - for num_warps in [4, 8]: - for group_size in [1, 16, 32, 64]: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_size, - "num_warps": num_warps, - "num_stages": num_stages, - }) + + if current_platform.is_rocm(): + param_ranges = get_rocm_tuning_space(use_fp16) + else: + # Reduced search space for faster tuning. + # TODO(woosuk): Increase the search space and use a performance model to + # prune the search space. 
+ block_m_range = [16, 32, 64, 128, 256] + block_n_range = [32, 64, 128, 256] + block_k_range = [64, 128, 256] + num_warps_range = [4, 8] + group_m_range = [1, 16, 32, 64] + num_stage_range = [2, 3, 4, 5] + + param_ranges = { + "BLOCK_SIZE_M": block_m_range, + "BLOCK_SIZE_N": block_n_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + } + + keys, values = zip(*param_ranges.items()) + for config_values in product(*values): + config = dict(zip(keys, config_values)) + configs.append(config) return configs +def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, + search_space, is_fp16): + N1, K1 = shard_intermediate_size, hidden_size + N2, K2 = hidden_size, shard_intermediate_size // 2 + pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space, + is_fp16) + pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space, + is_fp16) + search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) + return search_space + + +# The following code is inspired by ROCm/Triton GEMM tuning script: +# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89 +def prune_rocm_configs(M, N, K, configs, is_fp16=True): + pruned_configs = [] + elemBytes_a = 2 if is_fp16 else 1 + elemBytes_b = 2 if is_fp16 else 1 + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + + if is_fp16: + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elements per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = config.get("SPLIT_K", 1) + GROUP_M = config.get("GROUP_SIZE_M") + if is_fp16: + if (matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N): + continue + if (matrix_instr_nonkdim >= M + and matrix_instr_nonkdim != BLOCK_SIZE_M): + continue + if (matrix_instr_nonkdim >= N + and matrix_instr_nonkdim != BLOCK_SIZE_N): + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def need_split_k(SIZE_M, SIZE_N, 
SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def merge_unique_dicts(list1, list2): + result = [] + combined_list = list1.copy() + combined_list.extend(list2) + for dictionary in combined_list: + if dictionary not in result: + result.append(dictionary) + return result + + @ray.remote(num_gpus=1) class BenchmarkWorker: @@ -170,6 +318,10 @@ def __init__(self, seed: int) -> None: torch.set_default_device("cuda") current_platform.seed_everything(seed) self.seed = seed + # Get the device ID to allocate tensors and kernels + # on the respective GPU. This is required for Ray to work + # correctly with multi-GPU tuning on the ROCm platform. + self.device_id = int(ray.get_gpu_ids()[0]) def benchmark( self, @@ -191,9 +343,13 @@ def benchmark( op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, dtype_str) if op_config is None: - config = get_default_config(num_tokens, num_experts, - shard_intermediate_size, hidden_size, - topk, dtype_str) + config = get_default_config(num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype_str, + is_marlin=False) else: config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] @@ -217,25 +373,33 @@ def tune( ) -> Dict[str, int]: best_config = None best_time = float("inf") - for config in tqdm(search_space): - try: - kernel_time = benchmark_config(config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=10) - except triton.runtime.autotuner.OutOfResources: - # Some configurations may be invalid and fail to compile. - continue - - if kernel_time < best_time: - best_time = kernel_time - best_config = config + if current_platform.is_rocm(): + is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) + search_space = prune_rocm_search_space(num_tokens, + shard_intermediate_size, + hidden_size, search_space, + is_fp16) + + with torch.cuda.device(self.device_id): + for config in tqdm(search_space): + try: + kernel_time = benchmark_config(config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=20) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. 
+ continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config now = datetime.now() print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") assert best_config is not None @@ -244,12 +408,27 @@ def tune( def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: return { - "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], - "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], - "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], - "GROUP_SIZE_M": config["GROUP_SIZE_M"], - "num_warps": config["num_warps"], - "num_stages": config["num_stages"], + "BLOCK_SIZE_M": + config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": + config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": + config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": + config["GROUP_SIZE_M"], + "num_warps": + config["num_warps"], + "num_stages": + config["num_stages"], + **({ + "waves_per_eu": config["waves_per_eu"] + } if "waves_per_eu" in config else {}), + **({ + "matrix_instr_nonkdim": config["matrix_instr_nonkdim"] + } if "matrix_instr_nonkdim" in config else {}), + **({ + "kpack": config["kpack"] + } if "kpack" in config else {}), } @@ -275,7 +454,8 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int, def main(args: argparse.Namespace): print(args) - config = AutoConfig.from_pretrained(args.model) + config = AutoConfig.from_pretrained( + args.model, trust_remote_code=args.trust_remote_code) if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k @@ -286,6 +466,11 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] == "DeepseekV3ForCausalLM": + E = config.n_routed_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Default: Mixtral. 
E = config.num_local_experts @@ -294,7 +479,7 @@ def main(args: argparse.Namespace): shard_intermediate_size = 2 * intermediate_size // args.tp_size hidden_size = config.hidden_size - dtype = config.torch_dtype + dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" @@ -322,7 +507,8 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: return ray.get(outputs) if args.tune: - search_space = get_configs_compute_bound() + is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) + search_space = get_configs_compute_bound(is_fp16) print(f"Start tuning over {len(search_space)} configurations...") start = time.time() @@ -354,7 +540,11 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: parser.add_argument("--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1") - parser.add_argument("--tp-size", "-tp", type=int, default=2) + parser.add_argument("--tp-size", + "-tp", + "--tensor-parallel-size", + type=int, + default=2) parser.add_argument("--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], @@ -362,6 +552,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--tune", action="store_true") + parser.add_argument("--trust-remote-code", action="store_true") args = parser.parse_args() main(args) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 14eef00b855ac..219013a38134b 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -98,7 +98,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, + dtype=torch.float32, + device=device) for _ in range(num_iters): if version == "v1": diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py new file mode 100644 index 0000000000000..fee877b6f76fa --- /dev/null +++ b/benchmarks/kernels/utils.py @@ -0,0 +1,210 @@ +import dataclasses +from typing import Any, Callable, Iterable, Optional + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement + + +@dataclasses.dataclass +class CudaGraphBenchParams: + num_ops_in_cuda_graph: int + + +@dataclasses.dataclass +class ArgPool: + """ + When some argument of the benchmarking function is annotated with this type, + the benchmarking class (BenchMM) will collapse the argument to a pick a + single value from the given list of values, during function invocation. + For every invocation during a benchmarking run, it will choose a + different value from the list. 
+ """ + values: Iterable[Any] + + def __getitem__(self, index): + return self.values[index] + + +class Bench: + + class ArgsIterator: + + def __init__(self, args_list, kwargs_list): + assert len(args_list) == len(kwargs_list) + self.args_list = args_list + self.kwargs_list = kwargs_list + self.n = len(self.args_list) + self.idx = 0 + + def __next__(self): + while True: + yield (self.args_list[self.idx], self.kwargs_list[self.idx]) + self.idx += 1 + self.idx = self.idx % self.n + + def reset(self): + self.idx = 0 + + @property + def n_args(self): + return self.n + + def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], + label: str, sub_label: str, description: str, fn: Callable, + *args, **kwargs): + + self.cuda_graph_params = cuda_graph_params + self.use_cuda_graph = self.cuda_graph_params is not None + self.label = label + self.sub_label = sub_label + self.description = description + self.fn = fn + + # Process args + self._args = args + self._kwargs = kwargs + self.args_list, self.kwargs_list = self.collapse_argpool( + *args, **kwargs) + self.args_iterator = self.ArgsIterator(self.args_list, + self.kwargs_list) + + # Cudagraph runner + self.g = None + if self.use_cuda_graph: + self.g = self.get_cuda_graph_runner() + + # benchmark run params + self.min_run_time = 1 + + def collapse_argpool(self, *args, **kwargs): + argpool_args = [arg for arg in args if isinstance(arg, ArgPool)] + [ + arg for arg in kwargs.values() if isinstance(arg, ArgPool) + ] + if len(argpool_args) == 0: + return [args], [kwargs] + + # Make sure all argpools are of the same size + argpool_size = len(argpool_args[0].values) + assert all([argpool_size == len(arg.values) for arg in argpool_args]) + + # create copies of the args + args_list = [] + kwargs_list = [] + for _ in range(argpool_size): + args_list.append(args) + kwargs_list.append(kwargs.copy()) + + for i in range(argpool_size): + # collapse args; Just pick the ith value + args_list[i] = tuple([ + arg[i] if isinstance(arg, ArgPool) else arg + for arg in args_list[i] + ]) + + # collapse kwargs + kwargs_i = kwargs_list[i] + arg_pool_keys = [ + k for k, v in kwargs_i.items() if isinstance(v, ArgPool) + ] + for k in arg_pool_keys: + # again just pick the ith value + kwargs_i[k] = kwargs_i[k][i] + kwargs_list[i] = kwargs_i + + return args_list, kwargs_list + + def get_cuda_graph_runner(self): + assert self.use_cuda_graph + assert self.args_iterator is not None + + num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph + + # warmup + args_it = self.args_iterator.__next__() + for _ in range(2): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(num_graph_ops): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + return g + + def run_cudagrah(self) -> TMeasurement: + assert self.use_cuda_graph + globals = {'g': self.g} + + return TBenchmark.Timer( + stmt="g.replay()", + globals=globals, + label=( + f"{self.label}" + f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops" + ), + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run_eager(self) -> TMeasurement: + setup = None + stmt = None + globals = None + + has_arg_pool = self.args_iterator.n_args > 1 + if has_arg_pool: + setup = ''' + args_iterator.reset() + args_it = args_iterator.__next__() + 
''' + stmt = ''' + args, kwargs = next(args_it) + fn(*args, **kwargs) + ''' + globals = {'fn': self.fn, 'args_iterator': self.args_iterator} + else: + # no arg pool. Just use the args and kwargs directly + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + args, kwargs = next(args_it) + + setup = "" + stmt = ''' + fn(*args, **kwargs) + ''' + globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} + + return TBenchmark.Timer( + stmt=stmt, + setup=setup, + globals=globals, + label=self.label, + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run(self) -> TMeasurement: + timer = None + if self.use_cuda_graph: # noqa SIM108 + timer = self.run_cudagrah() + else: + timer = self.run_eager() + if not timer.meets_confidence() or timer.has_warnings: + print("Doesn't meet confidence - re-running bench ...") + return self.run() + return timer + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type: + print(f"exc type {exc_type}") + print(f"exc value {exc_value}") + print(f"exc traceback {traceback}") diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 68f7ca1af05ad..714abca2a5ff7 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -4,6 +4,11 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(MACOSX_FOUND TRUE) +endif() + + # # Define environment variables for special configurations # @@ -13,6 +18,9 @@ endif() include_directories("${CMAKE_SOURCE_DIR}/csrc") + +set (ENABLE_NUMA TRUE) + # # Check the compile flags # @@ -22,18 +30,28 @@ if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") "-mf16c" ) endif() -list(APPEND CXX_COMPILE_FLAGS - "-fopenmp" - "-DVLLM_CPU_EXTENSION") -execute_process(COMMAND cat /proc/cpuinfo - RESULT_VARIABLE CPUINFO_RET - OUTPUT_VARIABLE CPUINFO) +if(MACOSX_FOUND) + list(APPEND CXX_COMPILE_FLAGS + "-Xpreprocessor" + "-fopenmp" + "-DVLLM_CPU_EXTENSION") +else() + list(APPEND CXX_COMPILE_FLAGS + "-fopenmp" + "-DVLLM_CPU_EXTENSION") +endif() -if (NOT CPUINFO_RET EQUAL 0) - message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") +if (NOT MACOSX_FOUND) + execute_process(COMMAND cat /proc/cpuinfo + RESULT_VARIABLE CPUINFO_RET + OUTPUT_VARIABLE CPUINFO) + if (NOT CPUINFO_RET EQUAL 0) + message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") + endif() endif() + function (find_isa CPUINFO TARGET OUT) string(FIND ${CPUINFO} ${TARGET} ISA_FOUND) if(NOT ISA_FOUND EQUAL -1) @@ -54,12 +72,17 @@ endfunction() is_avx512_disabled(AVX512_DISABLED) -find_isa(${CPUINFO} "avx2" AVX2_FOUND) -find_isa(${CPUINFO} "avx512f" AVX512_FOUND) -find_isa(${CPUINFO} "POWER10" POWER10_FOUND) -find_isa(${CPUINFO} "POWER9" POWER9_FOUND) -find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support -find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support +if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + set(APPLE_SILICON_FOUND TRUE) +else() + find_isa(${CPUINFO} "avx2" AVX2_FOUND) + find_isa(${CPUINFO} "avx512f" AVX512_FOUND) + find_isa(${CPUINFO} "POWER10" POWER10_FOUND) + find_isa(${CPUINFO} "POWER9" POWER9_FOUND) + find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support + find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support +endif() + if (AVX512_FOUND AND NOT AVX512_DISABLED) list(APPEND CXX_COMPILE_FLAGS @@ -103,6 +126,9 @@ 
elseif (ASIMD_FOUND) set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16") endif() list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS}) +elseif(APPLE_SILICON_FOUND) + message(STATUS "Apple Silicon Detected") + set(ENABLE_NUMA OFF) else() message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.") endif() @@ -139,7 +165,12 @@ endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") -list(APPEND LIBS numa) +if(ENABLE_NUMA) + list(APPEND LIBS numa) +else() + message(STATUS "NUMA is disabled") + add_compile_definitions(-DVLLM_NUMA_DISABLED) +endif() # # _C extension diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 40430dae10c5b..1c1c539819d05 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -58,8 +58,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS) # set(SRCS ${ORIG_SRCS}) set(CXX_SRCS ${ORIG_SRCS}) - list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$") - list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$") + list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)|(hip)$") + list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)|(hip)$") # # Generate ROCm/HIP source file names from CUDA file names. @@ -259,7 +259,7 @@ endmacro() # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. # We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is # in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add -# 9.0a to the result. +# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS). # The result is stored in `OUT_CUDA_ARCHS`. # # Example: @@ -270,34 +270,47 @@ endmacro() # function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) + set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS}) # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS set(_CUDA_ARCHS) if ("9.0a" IN_LIST SRC_CUDA_ARCHS) list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") - if ("9.0" IN_LIST TGT_CUDA_ARCHS) + if ("9.0" IN_LIST TGT_CUDA_ARCHS_) + list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0") set(_CUDA_ARCHS "9.0a") endif() endif() list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) - # for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is - # less or eqault to ARCH - foreach(_ARCH ${CUDA_ARCHS}) - set(_TMP_ARCH) - foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) - if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) - set(_TMP_ARCH ${_SRC_ARCH}) - else() - break() + # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that + # is less or equal to ARCH (but has the same major version since SASS binary + # compatibility is only forward compatible within the same major version). 
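The selection rule described above, rendered as a small Python sketch for readability (it assumes plain "major.minor" arch strings and ignores the 9.0a special case handled earlier in the function):

# For each target arch, keep the highest source arch that is <= the target
# and shares its major version; SASS compatibility is forward-only within
# a major version.
def loose_intersection(src_archs, tgt_archs):
    def key(v):
        major, minor = v.split(".")
        return (int(major), int(minor))

    out = []
    for tgt in tgt_archs:
        candidates = [s for s in src_archs
                      if s.split(".")[0] == tgt.split(".")[0] and key(s) <= key(tgt)]
        if candidates:
            out.append(max(candidates, key=key))
    return sorted(set(out), key=key)

# loose_intersection(["7.5", "8.0", "8.6"], ["8.9", "9.0"]) returns ["8.6"]:
# 8.6 is the highest 8.x arch not above 8.9, and there is no 9.x source arch.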
+ foreach(_ARCH ${TGT_CUDA_ARCHS_}) + set(_TMP_ARCH) + # Extract the major version of the target arch + string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") + foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) + # Extract the major version of the source arch + string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}") + # Check major-version match AND version-less-or-equal + if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) + if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) + set(_TMP_ARCH "${_SRC_ARCH}") + endif() + else() + # If we hit a version greater than the target, we can break + break() + endif() + endforeach() + + # If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS + if (_TMP_ARCH) + list(APPEND _CUDA_ARCHS "${_TMP_ARCH}") endif() endforeach() - if (_TMP_ARCH) - list(APPEND _CUDA_ARCHS ${_TMP_ARCH}) - endif() - endforeach() list(REMOVE_DUPLICATES _CUDA_ARCHS) set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 839dc36ba4e29..88275dbdd83a1 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -9,8 +9,16 @@ namespace vllm { +template +__device__ __forceinline__ scalar_t compute(const scalar_t& x, + const scalar_t& y) { + return act_first ? ACT_FN(x) * y : x * ACT_FN(y); +} // Activation and gating kernel template. -template + +template __global__ void act_and_mul_kernel( scalar_t* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2, d] @@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel( for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); - out[token_idx * d + idx] = ACT_FN(x) * y; + out[token_idx * d + idx] = compute(x, y); } } @@ -55,7 +63,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { } // namespace vllm // Launch activation and gating kernel. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ +// Use ACT_FIRST (bool) indicating whether to apply the activation function +// first. +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \ int d = input.size(-1) / 2; \ int64_t num_tokens = input.numel() / input.size(-1); \ dim3 grid(num_tokens); \ @@ -64,7 +74,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ VLLM_DISPATCH_FLOATING_TYPES( \ input.scalar_type(), "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel> \ + vllm::act_and_mul_kernel, ACT_FIRST> \ <<>>(out.data_ptr(), \ input.data_ptr(), d); \ }); @@ -72,19 +82,27 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { void silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true); +} + +void mul_and_silu(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + // The difference between mul_and_silu and silu_and_mul is that mul_and_silu + // applies the silu to the latter half of the input. 
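In other words, with the input split into halves x and y along the last dimension, silu_and_mul computes silu(x) * y while mul_and_silu computes x * silu(y). A short PyTorch sketch of the reference semantics (illustrative only; the CUDA kernel fuses this into a single pass):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(inp: torch.Tensor) -> torch.Tensor:
    x, y = inp.chunk(2, dim=-1)   # act_first = true: activate the first half
    return F.silu(x) * y

def mul_and_silu_ref(inp: torch.Tensor) -> torch.Tensor:
    x, y = inp.chunk(2, dim=-1)   # act_first = false: activate the second half
    return x * F.silu(y)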
+ LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false); } void gelu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true); } void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true); } namespace vllm { diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 563e1438f0b01..eb216dc8baf10 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -105,7 +105,7 @@ __device__ void paged_attention_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, + const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; @@ -285,7 +285,7 @@ __device__ void paged_attention_kernel( Quant_vec k_vec_quant = *reinterpret_cast( k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = fp8::scaled_convert( - k_vec_quant, k_scale); + k_vec_quant, *k_scale); } } @@ -415,7 +415,7 @@ __device__ void paged_attention_kernel( *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8::scaled_convert(v_quant_vec, - v_scale); + *v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the @@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, + const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -80,6 +80,8 @@ void paged_attention_v1_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_seq_len = @@ -176,9 +178,10 @@ void paged_attention_v1( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, - const 
std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index c457bdb89008e..9935359e02fb1 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -37,7 +37,7 @@ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -84,6 +84,8 @@ void paged_attention_v2_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); @@ -187,9 +189,10 @@ void paged_attention_v2( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/cache.h b/csrc/cache.h index 11c4c5001daaa..55ed30bd8ce48 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -18,15 +18,20 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale); + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale); void reshape_and_cache_flash(torch::Tensor& key, 
torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype, - const double k_scale, const double v_scale); + torch::Tensor& k_scale, torch::Tensor& v_scale); + +void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, + torch::Tensor& kv_cache, torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + torch::Tensor& scale); // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 8a95279f9a25a..23a46b6ed8ad8 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel( // block_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x, const float k_scale, - const float v_scale) { + const int head_size, const int block_size, const int x, + const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel( value_cache[tgt_value_idx] = tgt_value; } else { key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, k_scale); + fp8::scaled_convert(tgt_key, *k_scale); value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, v_scale); + fp8::scaled_convert(tgt_value, *v_scale); } } } @@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t* __restrict__ slot_mapping, // [num_tokens] const int block_stride, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, - const float k_scale, const float v_scale) { + const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -239,12 +239,57 @@ __global__ void reshape_and_cache_flash_kernel( value_cache[tgt_key_value_idx] = tgt_value; } else { key_cache[tgt_key_value_idx] = - fp8::scaled_convert(tgt_key, k_scale); + fp8::scaled_convert(tgt_key, *k_scale); value_cache[tgt_key_value_idx] = - fp8::scaled_convert(tgt_value, v_scale); + fp8::scaled_convert(tgt_value, *v_scale); } } } + +template +__global__ void concat_and_cache_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, + int src_stride, int dst_stride, int size, int offset) { + for (int i = threadIdx.x; i < size; i += blockDim.x) { + const int64_t src_idx = token_idx * src_stride + i; + const int64_t dst_idx = block_idx * block_stride + + block_offset * (kv_lora_rank + pe_dim) + 
i + + offset; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = + fp8::scaled_convert(src[src_idx], *scale); + } + } + }; + + copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); + copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); +} + } // namespace vllm // KV_T is the stored data type of kv-cache. @@ -258,7 +303,9 @@ __global__ void reshape_and_cache_flash_kernel( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), key_stride, value_stride, \ - num_heads, head_size, block_size, x, k_scale, v_scale); + num_heads, head_size, block_size, x, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -268,8 +315,8 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -299,7 +346,9 @@ void reshape_and_cache( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), block_stride, key_stride, \ - value_stride, num_heads, head_size, block_size, k_scale, v_scale); + value_stride, num_heads, head_size, block_size, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache_flash( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -308,8 +357,8 @@ void reshape_and_cache_flash( torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { // NOTE(woosuk): In vLLM V1, key.size(0) can be different from // slot_mapping.size(0) because of padding for CUDA graphs. // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because @@ -339,6 +388,56 @@ void reshape_and_cache_flash( CALL_RESHAPE_AND_CACHE_FLASH); } +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::concat_and_cache_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, kv_c_stride, \ + k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + +void concat_and_cache_mla( + torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] + torch::Tensor& k_pe, // [num_tokens, pe_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // pe_dim)] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, torch::Tensor& scale) { + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. 
+ // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. + // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. + int num_tokens = slot_mapping.size(0); + int kv_lora_rank = kv_c.size(1); + int pe_dim = k_pe.size(1); + int block_size = kv_cache.size(1); + + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + + int kv_c_stride = kv_c.stride(0); + int k_pe_stride = k_pe.stride(0); + int block_stride = kv_cache.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_MLA); +} + namespace vllm { template diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp index ba9f40a230c8e..ddfaca27147b4 100644 --- a/csrc/core/math.hpp +++ b/csrc/core/math.hpp @@ -1,7 +1,14 @@ +#pragma once + #include #include -inline uint32_t next_pow_2(uint32_t const num) { +inline constexpr uint32_t next_pow_2(uint32_t const num) { if (num <= 1) return num; return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +template +inline constexpr std::enable_if_t, T> ceil_div(T a, T b) { + return (a + b - 1) / b; } \ No newline at end of file diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index 408e736d5bc0f..c2ae554c9f8e8 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -32,7 +32,7 @@ class ScalarType { signed_(signed_), bias(bias), finite_values_only(finite_values_only), - nan_repr(nan_repr){}; + nan_repr(nan_repr) {}; static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { return ScalarType(0, size_bits - 1, true, bias); diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index e21832ba7582f..b9764056e8a2d 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -386,7 +386,7 @@ void paged_attention_v1_impl_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes) { + const std::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -459,12 +459,12 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), 
"paged_attention_v1_impl", @@ -702,7 +702,7 @@ void paged_attention_v2_impl_launcher( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes) { + int max_seq_len, const std::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -781,12 +781,12 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 31d454328b2c1..e3809acad7453 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -107,10 +107,8 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, double k_scale, - double v_scale) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); - + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index 28db0479748bf..a71815106133a 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -2,13 +2,13 @@ #define CPU_TYPES_HPP #if defined(__x86_64__) - //x86 implementation + // x86 implementation #include "cpu_types_x86.hpp" #elif defined(__POWER9_VECTOR__) - //ppc implementation + // ppc implementation #include "cpu_types_vsx.hpp" #elif defined(__aarch64__) - //arm implementation + // arm implementation #include "cpu_types_arm.hpp" #else #warning "unsupported vLLM cpu implementation" diff --git a/csrc/cpu/cpu_types_arm.hpp b/csrc/cpu/cpu_types_arm.hpp index 73e0f8cb2e0fb..990e99f2fc069 100644 --- a/csrc/cpu/cpu_types_arm.hpp +++ b/csrc/cpu/cpu_types_arm.hpp @@ -1,48 +1,50 @@ #include -#include +#include #include namespace vec_op { #ifdef ARM_BF16_SUPPORT - #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) #else - #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) 
\ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) #endif -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) #ifndef CPU_OP_GUARD -#define CPU_KERNEL_GUARD_IN(NAME) -#define CPU_KERNEL_GUARD_OUT(NAME) + #define CPU_KERNEL_GUARD_IN(NAME) + #define CPU_KERNEL_GUARD_OUT(NAME) #else -#define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; + #define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; + #define CPU_KERNEL_GUARD_OUT(NAME) \ + std::cout << #NAME << " exit." << std::endl; #endif #define FORCE_INLINE __attribute__((always_inline)) inline namespace { - template - constexpr void unroll_loop_item(std::integer_sequence, F &&f) { - (f(std::integral_constant{}), ...); - }; -}; +template +constexpr void unroll_loop_item(std::integer_sequence, F&& f) { + (f(std::integral_constant{}), ...); +}; +}; // namespace template >> -constexpr void unroll_loop(F &&f) { +constexpr void unroll_loop(F&& f) { unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); } -template struct Vec { +template +struct Vec { constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }; }; @@ -54,70 +56,124 @@ struct FP16Vec8 : public Vec { float16x8_t reg; - explicit FP16Vec8(const void *ptr) - : reg(vld1q_f16(static_cast(ptr))) {}; + explicit FP16Vec8(const void* ptr) + : reg(vld1q_f16(static_cast(ptr))) {}; - explicit FP16Vec8(const FP32Vec8 &); + explicit FP16Vec8(const FP32Vec8&); - void save(void *ptr) const { - vst1q_f16(static_cast<__fp16 *>(ptr), reg); - } + void save(void* ptr) const { vst1q_f16(static_cast<__fp16*>(ptr), reg); } }; struct FP16Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - - float16x8x2_t reg; - - explicit FP16Vec16(const void *ptr) { - reg.val[0] = vld1q_f16(reinterpret_cast(ptr)); - reg.val[1] = vld1q_f16(reinterpret_cast(ptr) + 8); - } - - explicit FP16Vec16(const FP32Vec16& vec); - - void save(void *ptr) const { - vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); - vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]); + constexpr static int VEC_ELEM_NUM = 16; + + float16x8x2_t reg; + + explicit FP16Vec16(const void* ptr) { + reg.val[0] = vld1q_f16(reinterpret_cast(ptr)); + reg.val[1] = vld1q_f16(reinterpret_cast(ptr) + 8); + } + + explicit FP16Vec16(const FP32Vec16& vec); + + void save(void* ptr) const { + vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); + vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]); + } + + void save(void* ptr, const int elem_num) const { + int full_blocks = elem_num / 8; + int remainder = elem_num % 8; + + if (full_blocks > 0) { + vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); + if (full_blocks > 1) { + vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]); + } } - - void save(void *ptr, const int elem_num) const { - int full_blocks = elem_num / 8; - int remainder = elem_num % 8; - - if (full_blocks > 0) { - vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); - if (full_blocks > 1) { - vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]); - } - } - - if (remainder > 0) { - float16x8_t temp = reg.val[full_blocks]; - for (int i = 0; i < remainder; ++i) { 
- reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] = vgetq_lane_f16(temp, i); - } - } + + // Note: below is the unrolled version of the following code: + // + // for (int i = 0; i < remainder; ++i) { + // reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] = + // vgetq_lane_f16(temp, i); + // } + // + // For macOS build (Clang), the arm/neon intrinsics function + // `vgetq_lane_f16` needs the parameter `i` to be constant at compile + // time. + + if (remainder > 0) { + float16x8_t temp = reg.val[full_blocks]; + __fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr); + switch (remainder) { + case 1: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + break; + case 2: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + break; + case 3: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); + break; + case 4: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); + fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); + break; + case 5: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); + fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); + fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4); + break; + case 6: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); + fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); + fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4); + fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5); + break; + case 7: + fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); + fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); + fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); + fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); + fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4); + fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5); + fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6); + break; + + default: + break; + } } + } }; - #ifdef ARM_BF16_SUPPORT struct BF16Vec8 : public Vec { constexpr static int VEC_ELEM_NUM = 8; bfloat16x8_t reg; - explicit BF16Vec8(const void *ptr) - : reg(*reinterpret_cast(ptr)) {}; + explicit BF16Vec8(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; explicit BF16Vec8(bfloat16x8_t data) : reg(data) {}; - explicit BF16Vec8(const FP32Vec8 &); + explicit BF16Vec8(const FP32Vec8&); - explicit BF16Vec8(float32x4x2_t v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {}; + explicit BF16Vec8(float32x4x2_t v) + : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {}; - void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } }; struct BF16Vec16 : public Vec { @@ -125,19 +181,18 @@ struct BF16Vec16 : public Vec { bfloat16x8x2_t reg; - explicit BF16Vec16(const void *ptr) - : reg(*reinterpret_cast(ptr)) {}; + explicit BF16Vec16(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {}; - explicit BF16Vec16(const FP32Vec16 &); + explicit BF16Vec16(const FP32Vec16&); - explicit 
BF16Vec16(float32x4x4_t v) : reg({ - vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]), - vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3]) - }){}; + explicit BF16Vec16(float32x4x4_t v) + : reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]), + vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {}; - void save(void *ptr) const { *reinterpret_cast(ptr) = reg; }; + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }; }; struct BF16Vec32 : public Vec { @@ -145,19 +200,15 @@ struct BF16Vec32 : public Vec { bfloat16x8x4_t reg; - explicit BF16Vec32(const void *ptr) - : reg(*reinterpret_cast(ptr)) {}; + explicit BF16Vec32(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {}; - explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ - vec8_data.reg, - vec8_data.reg, - vec8_data.reg, - vec8_data.reg - }) {}; + explicit BF16Vec32(const BF16Vec8& vec8_data) + : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}; - void save(void *ptr) const { *reinterpret_cast(ptr) = reg; }; + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }; }; #endif @@ -175,11 +226,11 @@ struct FP32Vec4 : public Vec { explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {}; - explicit FP32Vec4(const float *ptr) : reg(vld1q_f32(ptr)) {}; + explicit FP32Vec4(const float* ptr) : reg(vld1q_f32(ptr)) {}; explicit FP32Vec4(float32x4_t data) : reg(data) {}; - explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}; + explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}; }; struct FP32Vec8 : public Vec { @@ -195,32 +246,37 @@ struct FP32Vec8 : public Vec { explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}; - explicit FP32Vec8(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {}; + explicit FP32Vec8(const float* ptr) + : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {}; explicit FP32Vec8(float32x4x2_t data) : reg(data) {}; - explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}; + explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}; - explicit FP32Vec8(const FP16Vec8 &v) { - reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg)); - reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg)); - }; + explicit FP32Vec8(const FP16Vec8& v) { + reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg)); + reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg)); + }; - explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {}; + explicit FP32Vec8(float16x8_t v) + : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {}; - #ifdef ARM_BF16_SUPPORT +#ifdef ARM_BF16_SUPPORT - explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {}; + explicit FP32Vec8(bfloat16x8_t v) + : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {}; - explicit FP32Vec8(const BF16Vec8 &v) : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {}; + explicit FP32Vec8(const BF16Vec8& v) + : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {}; - #endif +#endif float reduce_sum() const { AliasReg ar; ar.reg = reg; float answer = 0; - unroll_loop([&answer, &ar](int i) { answer += ar.values[i]; }); + unroll_loop( + [&answer, &ar](int i) { answer += ar.values[i]; }); return answer; } @@ -267,10 +323,14 @@ struct FP32Vec8 : public Vec { AliasReg ar; ar.reg = reg; - float32x2_t er_vec0 = {static_cast(erf(ar.values[0])), static_cast(erf(ar.values[1]))}; - float32x2_t er_vec1 = {static_cast(erf(ar.values[2])), 
static_cast(erf(ar.values[3]))}; - float32x2_t er_vec2 = {static_cast(erf(ar.values[4])), static_cast(erf(ar.values[5]))}; - float32x2_t er_vec3 = {static_cast(erf(ar.values[6])), static_cast(erf(ar.values[7]))}; + float32x2_t er_vec0 = {static_cast(erf(ar.values[0])), + static_cast(erf(ar.values[1]))}; + float32x2_t er_vec1 = {static_cast(erf(ar.values[2])), + static_cast(erf(ar.values[3]))}; + float32x2_t er_vec2 = {static_cast(erf(ar.values[4])), + static_cast(erf(ar.values[5]))}; + float32x2_t er_vec3 = {static_cast(erf(ar.values[6])), + static_cast(erf(ar.values[7]))}; float32x4_t result0 = vcombine_f32(er_vec0, er_vec1); float32x4_t result1 = vcombine_f32(er_vec2, er_vec3); @@ -280,25 +340,29 @@ struct FP32Vec8 : public Vec { result.val[1] = result1; return FP32Vec8(result); - } + } - FP32Vec8 operator*(const FP32Vec8 &b) const { - return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), vmulq_f32(reg.val[1], b.reg.val[1])})); + FP32Vec8 operator*(const FP32Vec8& b) const { + return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), + vmulq_f32(reg.val[1], b.reg.val[1])})); } - FP32Vec8 operator+(const FP32Vec8 &b) const { - return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), vaddq_f32(reg.val[1], b.reg.val[1])})); + FP32Vec8 operator+(const FP32Vec8& b) const { + return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), + vaddq_f32(reg.val[1], b.reg.val[1])})); } - FP32Vec8 operator-(const FP32Vec8 &b) const { - return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), vsubq_f32(reg.val[1], b.reg.val[1])})); + FP32Vec8 operator-(const FP32Vec8& b) const { + return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), + vsubq_f32(reg.val[1], b.reg.val[1])})); } - FP32Vec8 operator/(const FP32Vec8 &b) const { - return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), vdivq_f32(reg.val[1], b.reg.val[1])})); + FP32Vec8 operator/(const FP32Vec8& b) const { + return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), + vdivq_f32(reg.val[1], b.reg.val[1])})); } - void save(float *ptr) const { + void save(float* ptr) const { vst1q_f32(ptr, reg.val[0]); vst1q_f32(ptr + 4, reg.val[1]); } @@ -313,103 +377,100 @@ struct FP32Vec16 : public Vec { float32x4x4_t reg; - explicit FP32Vec16(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {} + explicit FP32Vec16(float v) + : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {} - explicit FP32Vec16() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {} + explicit FP32Vec16() + : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0), + vmovq_n_f32(0.0)}) {} - explicit FP32Vec16(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), vld1q_f32(ptr + 12)}) {} + explicit FP32Vec16(const float* ptr) + : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), + vld1q_f32(ptr + 12)}) {} explicit FP32Vec16(float32x4x4_t data) : reg(data) {} - explicit FP32Vec16(const FP32Vec8 &data) { - reg.val[0] = data.reg.val[0]; - reg.val[1] = data.reg.val[1]; - reg.val[2] = data.reg.val[0]; - reg.val[3] = data.reg.val[1]; + explicit FP32Vec16(const FP32Vec8& data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[0]; + reg.val[3] = data.reg.val[1]; } - explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} + explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {} - explicit FP32Vec16(const FP16Vec8 &v) : 
FP32Vec16(FP32Vec8(v.reg)) {} + explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v.reg)) {} - #ifdef ARM_BF16_SUPPORT - explicit FP32Vec16(bfloat16x8x2_t v) : reg({ - vcvtq_low_f32_bf16(v.val[0]), - vcvtq_high_f32_bf16(v.val[0]), - vcvtq_low_f32_bf16(v.val[1]), - vcvtq_high_f32_bf16(v.val[1]) - }) {}; - #endif +#ifdef ARM_BF16_SUPPORT + explicit FP32Vec16(bfloat16x8x2_t v) + : reg({vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]), + vcvtq_low_f32_bf16(v.val[1]), vcvtq_high_f32_bf16(v.val[1])}) {}; +#endif - explicit FP32Vec16(const FP32Vec4 &data) { + explicit FP32Vec16(const FP32Vec4& data) { reg.val[0] = data.reg; reg.val[1] = data.reg; reg.val[2] = data.reg; reg.val[3] = data.reg; }; - #ifdef ARM_BF16_SUPPORT - explicit FP32Vec16(const BF16Vec16 &v) : reg({ - vcvtq_low_f32_bf16(v.reg.val[0]), - vcvtq_high_f32_bf16(v.reg.val[0]), - vcvtq_low_f32_bf16(v.reg.val[1]), - vcvtq_high_f32_bf16(v.reg.val[1]) - }) {}; - - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}; - #endif - - explicit FP32Vec16(const FP16Vec16 &v) { - reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0])); - reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0])); - reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1])); - reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1])); +#ifdef ARM_BF16_SUPPORT + explicit FP32Vec16(const BF16Vec16& v) + : reg({vcvtq_low_f32_bf16(v.reg.val[0]), + vcvtq_high_f32_bf16(v.reg.val[0]), + vcvtq_low_f32_bf16(v.reg.val[1]), + vcvtq_high_f32_bf16(v.reg.val[1])}) {}; + + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; +#endif + + explicit FP32Vec16(const FP16Vec16& v) { + reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0])); + reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0])); + reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1])); + reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1])); }; - FP32Vec16 operator+(const FP32Vec16 &b) const { - return FP32Vec16(float32x4x4_t({ - vaddq_f32(reg.val[0], b.reg.val[0]), - vaddq_f32(reg.val[1], b.reg.val[1]), - vaddq_f32(reg.val[2], b.reg.val[2]), - vaddq_f32(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator+(const FP32Vec16& b) const { + return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]), + vaddq_f32(reg.val[1], b.reg.val[1]), + vaddq_f32(reg.val[2], b.reg.val[2]), + vaddq_f32(reg.val[3], b.reg.val[3])})); }; - FP32Vec16 operator*(const FP32Vec16 &b) const { - return FP32Vec16(float32x4x4_t({ - vmulq_f32(reg.val[0], b.reg.val[0]), - vmulq_f32(reg.val[1], b.reg.val[1]), - vmulq_f32(reg.val[2], b.reg.val[2]), - vmulq_f32(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator*(const FP32Vec16& b) const { + return FP32Vec16(float32x4x4_t({vmulq_f32(reg.val[0], b.reg.val[0]), + vmulq_f32(reg.val[1], b.reg.val[1]), + vmulq_f32(reg.val[2], b.reg.val[2]), + vmulq_f32(reg.val[3], b.reg.val[3])})); }; - FP32Vec16 operator-(const FP32Vec16 &b) const { - return FP32Vec16(float32x4x4_t({ - vsubq_f32(reg.val[0], b.reg.val[0]), - vsubq_f32(reg.val[1], b.reg.val[1]), - vsubq_f32(reg.val[2], b.reg.val[2]), - vsubq_f32(reg.val[3], b.reg.val[3]) - })); + FP32Vec16 operator-(const FP32Vec16& b) const { + return FP32Vec16(float32x4x4_t({vsubq_f32(reg.val[0], b.reg.val[0]), + vsubq_f32(reg.val[1], b.reg.val[1]), + vsubq_f32(reg.val[2], b.reg.val[2]), + vsubq_f32(reg.val[3], b.reg.val[3])})); }; - FP32Vec16 operator/(const FP32Vec16 &b) const { - return FP32Vec16(float32x4x4_t({ - vdivq_f32(reg.val[0], b.reg.val[0]), - vdivq_f32(reg.val[1], b.reg.val[1]), - vdivq_f32(reg.val[2], b.reg.val[2]), - 
vdivq_f32(reg.val[3], b.reg.val[3]) - })); + FP32Vec16 operator/(const FP32Vec16& b) const { + return FP32Vec16(float32x4x4_t({vdivq_f32(reg.val[0], b.reg.val[0]), + vdivq_f32(reg.val[1], b.reg.val[1]), + vdivq_f32(reg.val[2], b.reg.val[2]), + vdivq_f32(reg.val[3], b.reg.val[3])})); }; float reduce_sum() const { AliasReg ar; ar.reg = reg; float answer = 0; - unroll_loop([&answer, &ar](int i) { answer += ar.values[i]; }); + unroll_loop( + [&answer, &ar](int i) { answer += ar.values[i]; }); return answer; }; - template float reduce_sub_sum(int idx) { + template + float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); AliasReg ar; @@ -422,7 +483,7 @@ struct FP32Vec16 : public Vec { return answer; }; - void save(float *ptr) const { + void save(float* ptr) const { vst1q_f32(ptr, reg.val[0]); vst1q_f32(ptr + 4, reg.val[1]); vst1q_f32(ptr + 8, reg.val[2]); @@ -430,43 +491,59 @@ struct FP32Vec16 : public Vec { }; }; -template struct VecType { using vec_type = void; }; +template +struct VecType { + using vec_type = void; +}; -template using vec_t = typename VecType::vec_type; +template +using vec_t = typename VecType::vec_type; -template <> struct VecType { using vec_type = FP32Vec8; }; +template <> +struct VecType { + using vec_type = FP32Vec8; +}; -template <> struct VecType { using vec_type = FP16Vec8; }; +template <> +struct VecType { + using vec_type = FP16Vec8; +}; #ifdef ARM_BF16_SUPPORT -template <> struct VecType { using vec_type = BF16Vec8; }; +template <> +struct VecType { + using vec_type = BF16Vec8; +}; #endif -template void storeFP32(float v, T *ptr) { *ptr = v; } +template +void storeFP32(float v, T* ptr) { + *ptr = v; +} -template <> inline void storeFP32(float v, c10::Half *ptr) { - *reinterpret_cast<__fp16 *>(ptr) = v; +template <> +inline void storeFP32(float v, c10::Half* ptr) { + *reinterpret_cast<__fp16*>(ptr) = v; } -inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) { - float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]); - float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]); - float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]); - float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]); +inline FP16Vec16::FP16Vec16(const FP32Vec16& v) { + float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]); + float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]); + float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]); + float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]); - reg.val[0] = vcombine_f16(low_0, high_0); - reg.val[1] = vcombine_f16(low_1, high_1); + reg.val[0] = vcombine_f16(low_0, high_0); + reg.val[1] = vcombine_f16(low_1, high_1); }; -inline FP16Vec8 :: FP16Vec8(const FP32Vec8 &v) { - float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]); - float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]); +inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) { + float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]); + float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]); - reg = vcombine_f16(lower_half, upper_half); + reg = vcombine_f16(lower_half, upper_half); }; -inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { - +inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]); acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]); acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]); @@ -474,8 +551,7 @@ inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { }; #ifdef ARM_BF16_SUPPORT -inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { - +inline void fma(FP32Vec16& acc, BF16Vec32& a, 
BF16Vec32& b) { float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0])); float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0])); float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1])); @@ -494,22 +570,22 @@ inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { #endif #ifdef ARM_BF16_SUPPORT -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {}; - -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg({ - vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]), - vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]), v.reg.val[3]) - }){}; +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) + : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) { + }; + +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) + : reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]), + vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]), + v.reg.val[3])}) {}; #endif -inline void prefetch(const void *addr) { - __builtin_prefetch(addr, 0, 1); -}; +inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); }; #ifdef ARM_BF16_SUPPORT template <> -inline void storeFP32(float v, c10::BFloat16 *ptr) { - *reinterpret_cast<__bf16 *>(ptr) = vcvth_bf16_f32(v); +inline void storeFP32(float v, c10::BFloat16* ptr) { + *reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v); }; #endif -}; \ No newline at end of file +}; // namespace vec_op \ No newline at end of file diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp index b50bdadc5713d..a8e1be37eb418 100644 --- a/csrc/cpu/cpu_types_vsx.hpp +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -9,38 +9,40 @@ namespace vec_op { // FIXME: FP16 is not fully supported in Torch-CPU -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) #ifndef CPU_OP_GUARD -#define CPU_KERNEL_GUARD_IN(NAME) -#define CPU_KERNEL_GUARD_OUT(NAME) + #define CPU_KERNEL_GUARD_IN(NAME) + #define CPU_KERNEL_GUARD_OUT(NAME) #else -#define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; + #define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; + #define CPU_KERNEL_GUARD_OUT(NAME) \ + std::cout << #NAME << " exit." 
<< std::endl; #endif #define FORCE_INLINE __attribute__((always_inline)) inline namespace { template -constexpr void unroll_loop_item(std::integer_sequence, F &&f) { +constexpr void unroll_loop_item(std::integer_sequence, F&& f) { (f(std::integral_constant{}), ...); } -}; // namespace +}; // namespace template >> -constexpr void unroll_loop(F &&f) { +constexpr void unroll_loop(F&& f) { unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); } -template struct Vec { +template +struct Vec { constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } }; @@ -68,12 +70,14 @@ struct BF16Vec8 : public Vec { __vector signed short reg; - explicit BF16Vec8(const void *ptr) - : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {} + explicit BF16Vec8(const void* ptr) + : reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {} - explicit BF16Vec8(const FP32Vec8 &); + explicit BF16Vec8(const FP32Vec8&); - void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; } + void save(void* ptr) const { + *reinterpret_cast<__vector signed short*>(ptr) = reg; + } }; struct BF16Vec16 : public Vec { @@ -81,18 +85,18 @@ struct BF16Vec16 : public Vec { ss16x8x2_t reg; - explicit BF16Vec16(const void *ptr) { + explicit BF16Vec16(const void* ptr) { // Load 256 bits in two parts - reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr); - reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr); + reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr); + reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr); } - explicit BF16Vec16(const FP32Vec16 &); + explicit BF16Vec16(const FP32Vec16&); - void save(void *ptr) const { + void save(void* ptr) const { // Save 256 bits in two parts - vec_xst(reg.val[0], 0, (signed short *)ptr); - vec_xst(reg.val[1], 16, (signed short *)ptr); + vec_xst(reg.val[0], 0, (signed short*)ptr); + vec_xst(reg.val[1], 16, (signed short*)ptr); } }; @@ -102,19 +106,15 @@ struct BF16Vec32 : public Vec { constexpr static int VEC_ELEM_NUM = 32; ss16x8x4_t reg; - explicit BF16Vec32(const void *ptr) - : reg(*reinterpret_cast(ptr)) {} + explicit BF16Vec32(const void* ptr) + : reg(*reinterpret_cast(ptr)) {} explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} - explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ - vec8_data.reg, - vec8_data.reg, - vec8_data.reg, - vec8_data.reg - }) {} + explicit BF16Vec32(const BF16Vec8& vec8_data) + : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {} - void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } }; struct FP32Vec4 : public Vec { @@ -130,11 +130,11 @@ struct FP32Vec4 : public Vec { explicit FP32Vec4() : reg(vec_splats(0.0f)) {} - explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {} + explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {} explicit FP32Vec4(__vector float data) : reg(data) {} - explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} + explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {} }; struct FP32Vec8 : public Vec { @@ -156,19 +156,19 @@ struct FP32Vec8 : public Vec { reg.val[1] = vec_splats(0.0f); } - explicit FP32Vec8(const float *ptr) { + explicit FP32Vec8(const float* ptr) { reg.val[0] = vec_xl(0, ptr); reg.val[1] = vec_xl(16, ptr); } explicit FP32Vec8(f32x4x2_t data) : reg(data) {} - explicit FP32Vec8(const FP32Vec8 &data) { + explicit FP32Vec8(const FP32Vec8& data) { reg.val[0] = data.reg.val[0]; reg.val[1] = 
data.reg.val[1]; } - explicit FP32Vec8(const BF16Vec8 &v) { + explicit FP32Vec8(const BF16Vec8& v) { reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); reg.val[1] = (__vector float)vec_mergel(zero, v.reg); } @@ -177,7 +177,8 @@ struct FP32Vec8 : public Vec { AliasReg ar; ar.reg = reg; float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); return result; } @@ -230,23 +231,27 @@ struct FP32Vec8 : public Vec { return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); } - FP32Vec8 operator*(const FP32Vec8 &b) const { - return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator*(const FP32Vec8& b) const { + return FP32Vec8( + {vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator+(const FP32Vec8 &b) const { - return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator+(const FP32Vec8& b) const { + return FP32Vec8( + {vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator-(const FP32Vec8 &b) const { - return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator-(const FP32Vec8& b) const { + return FP32Vec8( + {vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); } - FP32Vec8 operator/(const FP32Vec8 &b) const { - return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); + FP32Vec8 operator/(const FP32Vec8& b) const { + return FP32Vec8( + {vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); } - void save(float *ptr) const { + void save(float* ptr) const { vec_xst(reg.val[0], 0, ptr); vec_xst(reg.val[1], 16, ptr); } @@ -275,7 +280,7 @@ struct FP32Vec16 : public Vec { reg.val[3] = vec_splats(0.0f); } - explicit FP32Vec16(const float *ptr) { + explicit FP32Vec16(const float* ptr) { reg.val[0] = vec_xl(0, ptr); reg.val[1] = vec_xl(16, ptr); reg.val[2] = vec_xl(32, ptr); @@ -284,78 +289,76 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(f32x4x4_t data) : reg(data) {} - explicit FP32Vec16(const FP32Vec16 &data) { + explicit FP32Vec16(const FP32Vec16& data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; reg.val[2] = data.reg.val[2]; reg.val[3] = data.reg.val[3]; } - explicit FP32Vec16(const FP32Vec4 &data) { + explicit FP32Vec16(const FP32Vec4& data) { reg.val[0] = data.reg; reg.val[1] = data.reg; reg.val[2] = data.reg; reg.val[3] = data.reg; } - explicit FP32Vec16(const FP32Vec8 &data) { + explicit FP32Vec16(const FP32Vec8& data) { reg.val[0] = data.reg.val[0]; reg.val[1] = data.reg.val[1]; reg.val[2] = data.reg.val[0]; reg.val[3] = data.reg.val[1]; } - explicit FP32Vec16(const BF16Vec16 &v) { + explicit FP32Vec16(const BF16Vec16& v) { reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]); reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); } - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} - FP32Vec16 operator*(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_mul(reg.val[0], b.reg.val[0]), - vec_mul(reg.val[1], b.reg.val[1]), - vec_mul(reg.val[2], b.reg.val[2]), - vec_mul(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator*(const FP32Vec16& b) const { + return 
FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), + vec_mul(reg.val[1], b.reg.val[1]), + vec_mul(reg.val[2], b.reg.val[2]), + vec_mul(reg.val[3], b.reg.val[3])})); } - FP32Vec16 operator+(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_add(reg.val[0], b.reg.val[0]), - vec_add(reg.val[1], b.reg.val[1]), - vec_add(reg.val[2], b.reg.val[2]), - vec_add(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator+(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]), + vec_add(reg.val[1], b.reg.val[1]), + vec_add(reg.val[2], b.reg.val[2]), + vec_add(reg.val[3], b.reg.val[3])})); } - FP32Vec16 operator-(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_sub(reg.val[0], b.reg.val[0]), - vec_sub(reg.val[1], b.reg.val[1]), - vec_sub(reg.val[2], b.reg.val[2]), - vec_sub(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator-(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]), + vec_sub(reg.val[1], b.reg.val[1]), + vec_sub(reg.val[2], b.reg.val[2]), + vec_sub(reg.val[3], b.reg.val[3])})); } - FP32Vec16 operator/(const FP32Vec16 &b) const { - return FP32Vec16(f32x4x4_t({ - vec_div(reg.val[0], b.reg.val[0]), - vec_div(reg.val[1], b.reg.val[1]), - vec_div(reg.val[2], b.reg.val[2]), - vec_div(reg.val[3], b.reg.val[3])})); + FP32Vec16 operator/(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]), + vec_div(reg.val[1], b.reg.val[1]), + vec_div(reg.val[2], b.reg.val[2]), + vec_div(reg.val[3], b.reg.val[3])})); } float reduce_sum() const { AliasReg ar; ar.reg = reg; float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); return result; } - template float reduce_sub_sum(int idx) { + template + float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); AliasReg ar; @@ -368,7 +371,7 @@ struct FP32Vec16 : public Vec { return result; } - void save(float *ptr) const { + void save(float* ptr) const { vec_xst(reg.val[0], 0, ptr); vec_xst(reg.val[1], 16, ptr); vec_xst(reg.val[2], 32, ptr); @@ -376,43 +379,62 @@ struct FP32Vec16 : public Vec { } }; -template struct VecType { using vec_type = void; }; +template +struct VecType { + using vec_type = void; +}; -template using vec_t = typename VecType::vec_type; +template +using vec_t = typename VecType::vec_type; -template <> struct VecType { using vec_type = FP32Vec8; }; +template <> +struct VecType { + using vec_type = FP32Vec8; +}; -template <> struct VecType { using vec_type = BF16Vec8; }; +template <> +struct VecType { + using vec_type = BF16Vec8; +}; -template void storeFP32(float v, T *ptr) { *ptr = v; } +template +void storeFP32(float v, T* ptr) { + *ptr = v; +} -inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { +inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { acc = acc + a * b; } -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = - reinterpret_cast(&v); +template <> +inline void storeFP32(float v, c10::BFloat16* ptr) { + c10::BFloat16 __attribute__((__may_alias__))* v_ptr = + reinterpret_cast(&v); *ptr = *(v_ptr + 1); } #ifndef __VEC_CLASS_FP_NAN -#define __VEC_CLASS_FP_NAN (1 << 6) + #define __VEC_CLASS_FP_NAN (1 << 6) #endif -const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 }; +const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13, + 16, 17, 
20, 21, 24, 25, 28, 29}; #ifndef _ARCH_PWR10 -const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff }; -const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 }; -const static __vector unsigned int sh16 = { 16, 16, 16, 16 }; -const static __vector unsigned int one = { 1, 1, 1, 1 }; +const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff, + 0x00007fff}; +const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000, + 0x7fc00000}; +const static __vector unsigned int sh16 = {16, 16, 16, 16}; +const static __vector unsigned int one = {1, 1, 1, 1}; #endif -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { #ifdef _ARCH_PWR10 __vector signed short ret[2]; - ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); - ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[1]); reg = vec_perm(ret[0], ret[1], omask); #elif defined(_ARCH_PWR9) __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); @@ -425,8 +447,10 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { __vector unsigned int rnd1 = vec_add(lsb1, bias); inp0 = vec_add(inp0, rnd0); inp1 = vec_add(inp1, rnd1); - __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); - __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + __vector __bool int sel0 = + vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = + vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); inp0 = vec_sel(inp0, nan, sel0); inp1 = vec_sel(inp1, nan, sel1); inp0 = vec_sr(inp0, sh16); @@ -435,13 +459,17 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { #endif } -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { #ifdef _ARCH_PWR10 __vector signed short ret[4]; - ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); - ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); - ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]); - ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]); + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[1]); + ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[2]); + ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16( + (__vector unsigned char)v.reg.val[3]); reg.val[0] = vec_perm(ret[0], ret[1], omask); reg.val[1] = vec_perm(ret[2], ret[3], omask); #elif defined(_ARCH_PWR9) @@ -465,10 +493,14 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { inp1 = vec_add(inp1, rnd1); inp2 = vec_add(inp2, rnd2); inp3 = vec_add(inp3, rnd3); - __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); - __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); - __vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); - __vector __bool int sel3 = 
vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); + __vector __bool int sel0 = + vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = + vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + __vector __bool int sel2 = + vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); + __vector __bool int sel3 = + vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); inp0 = vec_sel(inp0, nan, sel0); inp1 = vec_sel(inp1, nan, sel1); inp2 = vec_sel(inp2, nan, sel2); @@ -482,10 +514,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { #endif } -inline void prefetch(const void *addr) { +inline void prefetch(const void* addr) { __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory"); } -}; // namespace vec_op +}; // namespace vec_op #endif diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 4bb4eb0f491ac..a4ef2be2a58ca 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -11,39 +11,40 @@ static_assert(false, "AVX2 must be supported for the current implementation."); namespace vec_op { -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) #ifndef CPU_OP_GUARD -#define CPU_KERNEL_GUARD_IN(NAME) -#define CPU_KERNEL_GUARD_OUT(NAME) + #define CPU_KERNEL_GUARD_IN(NAME) + #define CPU_KERNEL_GUARD_OUT(NAME) #else -#define CPU_KERNEL_GUARD_IN(NAME) \ - RECORD_FUNCTION(#NAME, c10::ArrayRef({})); -#define CPU_KERNEL_GUARD_OUT(NAME) + #define CPU_KERNEL_GUARD_IN(NAME) \ + RECORD_FUNCTION(#NAME, c10::ArrayRef({})); + #define CPU_KERNEL_GUARD_OUT(NAME) #endif #define FORCE_INLINE __attribute__((always_inline)) inline namespace { template -constexpr void unroll_loop_item(std::integer_sequence, F &&f) { +constexpr void unroll_loop_item(std::integer_sequence, F&& f) { (f(std::integral_constant{}), ...); } -}; // namespace +}; // namespace template >> -constexpr void unroll_loop(F &&f) { +constexpr void unroll_loop(F&& f) { unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); } -template struct Vec { +template +struct Vec { constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } }; @@ -55,12 +56,12 @@ struct FP16Vec8 : public Vec { __m128i reg; - explicit FP16Vec8(const void *ptr) - : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + explicit FP16Vec8(const void* ptr) + : reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {} - explicit FP16Vec8(const FP32Vec8 &); + explicit FP16Vec8(const FP32Vec8&); - void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } + void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; } }; struct FP16Vec16 : public Vec { @@ -68,12 +69,12 @@ struct FP16Vec16 : public Vec { __m256i reg; - explicit FP16Vec16(const void *ptr) - : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + explicit FP16Vec16(const void* ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} - explicit FP16Vec16(const FP32Vec16 &); + explicit FP16Vec16(const FP32Vec16&); - void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } + void save(void* ptr) 
const { *reinterpret_cast<__m256i*>(ptr) = reg; } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -87,12 +88,12 @@ struct BF16Vec8 : public Vec { __m128i reg; - explicit BF16Vec8(const void *ptr) - : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + explicit BF16Vec8(const void* ptr) + : reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {} - explicit BF16Vec8(const FP32Vec8 &); + explicit BF16Vec8(const FP32Vec8&); - void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } + void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; } }; struct BF16Vec16 : public Vec { @@ -100,12 +101,12 @@ struct BF16Vec16 : public Vec { __m256i reg; - explicit BF16Vec16(const void *ptr) - : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + explicit BF16Vec16(const void* ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} - explicit BF16Vec16(const FP32Vec16 &); + explicit BF16Vec16(const FP32Vec16&); - void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } + void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -120,11 +121,11 @@ struct BF16Vec32 : public Vec { __m512i reg; - explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} + explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} explicit BF16Vec32(__m512i data) : reg(data) {} - explicit BF16Vec32(BF16Vec8 &vec8_data) + explicit BF16Vec32(BF16Vec8& vec8_data) : reg((__m512i)_mm512_inserti32x4( _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( (__m128i)vec8_data.reg), @@ -132,7 +133,7 @@ struct BF16Vec32 : public Vec { (__m128i)vec8_data.reg, 2), (__m128i)vec8_data.reg, 3)) {} - void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } + void save(void* ptr) const { *reinterpret_cast<__m512i*>(ptr) = reg; } }; #else struct BF16Vec32 : public Vec { @@ -141,24 +142,24 @@ struct BF16Vec32 : public Vec { __m256i reg_low; __m256i reg_high; - explicit BF16Vec32(const void *ptr) - : reg_low(_mm256_loadu_si256((__m256i const *)ptr)), - reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {} + explicit BF16Vec32(const void* ptr) + : reg_low(_mm256_loadu_si256((__m256i const*)ptr)), + reg_high(_mm256_loadu_si256((__m256i const*)ptr + 1)) {} - explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low), - reg_high(high) {} + explicit BF16Vec32(__m256i low, __m256i high) + : reg_low(low), reg_high(high) {} - explicit BF16Vec32(BF16Vec8 &vec8_data) + explicit BF16Vec32(BF16Vec8& vec8_data) : reg_low((__m256i)_mm256_inserti32x4( - _mm256_castsi128_si256((__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1)), + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)), reg_high((__m256i)_mm256_inserti32x4( - _mm256_castsi128_si256((__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1)) {} + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)) {} - void save(void *ptr) const { - *reinterpret_cast<__m256i *>(ptr) = reg_low; - *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high; + void save(void* ptr) const { + *reinterpret_cast<__m256i*>(ptr) = reg_low; + *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high; } }; #endif @@ -176,11 +177,11 @@ struct FP32Vec4 : public Vec { explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} - explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + explicit FP32Vec4(const float* ptr) : reg(_mm_loadu_ps(ptr)) 
{} explicit FP32Vec4(__m128 data) : reg(data) {} - explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} + explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {} }; struct FP32Vec8 : public Vec { @@ -196,15 +197,15 @@ struct FP32Vec8 : public Vec { explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} - explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + explicit FP32Vec8(const float* ptr) : reg(_mm256_loadu_ps(ptr)) {} explicit FP32Vec8(__m256 data) : reg(data) {} - explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {} - explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {} + explicit FP32Vec8(const FP16Vec8& v) : reg(_mm256_cvtph_ps(v.reg)) {} - explicit FP32Vec8(const BF16Vec8 &v) + explicit FP32Vec8(const BF16Vec8& v) : reg(_mm256_castsi256_ps( _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} @@ -212,7 +213,8 @@ struct FP32Vec8 : public Vec { AliasReg ar; ar.reg = reg; float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + unroll_loop( + [&result, &ar](int i) { result += ar.values[i]; }); return result; } @@ -244,27 +246,27 @@ struct FP32Vec8 : public Vec { erf(ar.values[1]), erf(ar.values[0]))); } - FP32Vec8 operator*(const FP32Vec8 &b) const { + FP32Vec8 operator*(const FP32Vec8& b) const { return FP32Vec8(_mm256_mul_ps(reg, b.reg)); } - FP32Vec8 operator+(const FP32Vec8 &b) const { + FP32Vec8 operator+(const FP32Vec8& b) const { return FP32Vec8(_mm256_add_ps(reg, b.reg)); } - FP32Vec8 operator-(const FP32Vec8 &b) const { + FP32Vec8 operator-(const FP32Vec8& b) const { return FP32Vec8(_mm256_sub_ps(reg, b.reg)); } - FP32Vec8 operator/(const FP32Vec8 &b) const { + FP32Vec8 operator/(const FP32Vec8& b) const { return FP32Vec8(_mm256_div_ps(reg, b.reg)); } - void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } + void save(float* ptr) const { _mm256_storeu_ps(ptr, reg); } }; #ifdef __AVX512F__ -struct INT32Vec16: public Vec { +struct INT32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; union AliasReg { __m512i reg; @@ -272,12 +274,11 @@ struct INT32Vec16: public Vec { }; __m512i reg; - - explicit INT32Vec16(const void* data_ptr) : reg(_mm512_loadu_epi32(data_ptr)) {} - void save(int32_t* ptr) const { - _mm512_storeu_epi32(ptr, reg); - } + explicit INT32Vec16(const void* data_ptr) + : reg(_mm512_loadu_epi32(data_ptr)) {} + + void save(int32_t* ptr) const { _mm512_storeu_epi32(ptr, reg); } void save(int32_t* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -301,11 +302,11 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} - explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} + explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {} explicit FP32Vec16(__m512 data) : reg(data) {} - explicit FP32Vec16(const FP32Vec4 &data) + explicit FP32Vec16(const FP32Vec4& data) : reg((__m512)_mm512_inserti32x4( _mm512_inserti32x4( _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), @@ -313,36 +314,37 @@ struct FP32Vec16 : public Vec { (__m128i)data.reg, 2), (__m128i)data.reg, 3)) {} - explicit FP32Vec16(const FP32Vec8 &data) + explicit FP32Vec16(const FP32Vec8& data) : reg((__m512)_mm512_inserti32x8( _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} - explicit FP32Vec16(const BF16Vec16 &v) + explicit FP32Vec16(const BF16Vec16& v) : reg(_mm512_castsi512_ps( _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} - explicit 
FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {} + explicit FP32Vec16(const FP16Vec16& v) : reg(_mm512_cvtph_ps(v.reg)) {} - explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} - explicit FP32Vec16(const INT32Vec16 &v) - : reg(_mm512_cvt_roundepi32_ps(v.reg, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)) {} + explicit FP32Vec16(const INT32Vec16& v) + : reg(_mm512_cvt_roundepi32_ps( + v.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} - FP32Vec16 operator*(const FP32Vec16 &b) const { + FP32Vec16 operator*(const FP32Vec16& b) const { return FP32Vec16(_mm512_mul_ps(reg, b.reg)); } - FP32Vec16 operator+(const FP32Vec16 &b) const { + FP32Vec16 operator+(const FP32Vec16& b) const { return FP32Vec16(_mm512_add_ps(reg, b.reg)); } - FP32Vec16 operator-(const FP32Vec16 &b) const { + FP32Vec16 operator-(const FP32Vec16& b) const { return FP32Vec16(_mm512_sub_ps(reg, b.reg)); } - FP32Vec16 operator/(const FP32Vec16 &b) const { + FP32Vec16 operator/(const FP32Vec16& b) const { return FP32Vec16(_mm512_div_ps(reg, b.reg)); } @@ -370,9 +372,7 @@ struct FP32Vec16 : public Vec { return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg)); } - FP32Vec16 abs() const { - return FP32Vec16(_mm512_abs_ps(reg)); - } + FP32Vec16 abs() const { return FP32Vec16(_mm512_abs_ps(reg)); } float reduce_sum() const { return _mm512_reduce_add_ps(reg); } @@ -380,14 +380,15 @@ struct FP32Vec16 : public Vec { float reduce_min() const { return _mm512_reduce_min_ps(reg); } - template float reduce_sub_sum(int idx) { + template + float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); return _mm512_mask_reduce_add_ps(mask, reg); } - void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } + void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); } void save(float* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -407,32 +408,30 @@ struct FP32Vec16 : public Vec { __m256 reg_low; __m256 reg_high; - explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)), - reg_high(_mm256_set1_ps(v)) {} + explicit FP32Vec16(float v) + : reg_low(_mm256_set1_ps(v)), reg_high(_mm256_set1_ps(v)) {} - explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)), - reg_high(_mm256_set1_ps(0.0)) {} + explicit FP32Vec16() + : reg_low(_mm256_set1_ps(0.0)), reg_high(_mm256_set1_ps(0.0)) {} - explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)), - reg_high(_mm256_loadu_ps(ptr + 8)) {} + explicit FP32Vec16(const float* ptr) + : reg_low(_mm256_loadu_ps(ptr)), reg_high(_mm256_loadu_ps(ptr + 8)) {} explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} - explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low), - reg_high(data.reg_high) {} + explicit FP32Vec16(const FP32Vec16& data) + : reg_low(data.reg_low), reg_high(data.reg_high) {} - explicit FP32Vec16(const FP32Vec4 &data) + explicit FP32Vec16(const FP32Vec4& data) : reg_low((__m256)_mm256_inserti128_si256( - _mm256_castsi128_si256((__m128i)data.reg), - (__m128i)data.reg, 1)), + _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)), reg_high((__m256)_mm256_inserti128_si256( - _mm256_castsi128_si256((__m128i)data.reg), - (__m128i)data.reg, 1)) {} + 
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)) {} - explicit FP32Vec16(const FP32Vec8 &data) + explicit FP32Vec16(const FP32Vec8& data) : reg_low(data.reg), reg_high(data.reg) {} - explicit FP32Vec16(const FP16Vec16 &v) { + explicit FP32Vec16(const FP16Vec16& v) { __m128i low = _mm256_extractf128_si256(v.reg, 0); __m128i high = _mm256_extractf128_si256(v.reg, 1); @@ -440,9 +439,9 @@ struct FP32Vec16 : public Vec { reg_high = _mm256_cvtph_ps(high); } - explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} - explicit FP32Vec16(const BF16Vec16 &v) { + explicit FP32Vec16(const BF16Vec16& v) { __m128i low = _mm256_extractf128_si256(v.reg, 0); __m128i high = _mm256_extractf128_si256(v.reg, 1); @@ -456,24 +455,24 @@ struct FP32Vec16 : public Vec { reg_high = _mm256_castsi256_ps(v_high_shifted); } - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} - FP32Vec16 operator*(const FP32Vec16 &b) const { + FP32Vec16 operator*(const FP32Vec16& b) const { return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low), _mm256_mul_ps(reg_high, b.reg_high)); } - FP32Vec16 operator+(const FP32Vec16 &b) const { + FP32Vec16 operator+(const FP32Vec16& b) const { return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low), _mm256_add_ps(reg_high, b.reg_high)); } - FP32Vec16 operator-(const FP32Vec16 &b) const { + FP32Vec16 operator-(const FP32Vec16& b) const { return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low), _mm256_sub_ps(reg_high, b.reg_high)); } - FP32Vec16 operator/(const FP32Vec16 &b) const { + FP32Vec16 operator/(const FP32Vec16& b) const { return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low), _mm256_div_ps(reg_high, b.reg_high)); } @@ -484,7 +483,8 @@ struct FP32Vec16 : public Vec { return low.reduce_sum() + high.reduce_sum(); } - template float reduce_sub_sum(int idx) { + template + float reduce_sub_sum(int idx) { float sum = 0.0; static_assert(VEC_ELEM_NUM % group_size == 0); constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); @@ -507,7 +507,7 @@ struct FP32Vec16 : public Vec { return sum; } - void save(float *ptr) const { + void save(float* ptr) const { _mm256_storeu_ps(ptr, reg_low); _mm256_storeu_ps(ptr + 8, reg_high); } @@ -515,7 +515,7 @@ struct FP32Vec16 : public Vec { #endif #ifdef __AVX512F__ -struct INT8Vec16: public Vec { +struct INT8Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; union AliasReg { __m128i reg; @@ -523,14 +523,12 @@ struct INT8Vec16: public Vec { }; __m128i reg; - - explicit INT8Vec16(const FP32Vec16& vec) : reg( - _mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) - ) {} - void save(int8_t* ptr) const { - _mm_storeu_epi8(ptr, reg); - } + explicit INT8Vec16(const FP32Vec16& vec) + : reg(_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32( + vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))) {} + + void save(int8_t* ptr) const { _mm_storeu_epi8(ptr, reg); } void save(int8_t* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -540,71 +538,92 @@ struct INT8Vec16: public Vec { }; #endif -template struct VecType { using vec_type = void; }; +template +struct VecType { + using vec_type = void; +}; -template using vec_t = typename VecType::vec_type; +template +using vec_t = typename VecType::vec_type; -template <> struct VecType { using vec_type = FP32Vec8; }; +template <> +struct VecType { + using vec_type = FP32Vec8; +}; -template 
<> struct VecType { using vec_type = FP16Vec8; }; +template <> +struct VecType { + using vec_type = FP16Vec8; +}; -template <> struct VecType { using vec_type = BF16Vec8; }; +template <> +struct VecType { + using vec_type = BF16Vec8; +}; -template void storeFP32(float v, T *ptr) { *ptr = v; } +template +void storeFP32(float v, T* ptr) { + *ptr = v; +} -inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { +inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { acc = acc + a * b; } -template <> inline void storeFP32(float v, c10::Half *ptr) { - *reinterpret_cast(ptr) = +template <> +inline void storeFP32(float v, c10::Half* ptr) { + *reinterpret_cast(ptr) = _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); } -inline FP16Vec8::FP16Vec8(const FP32Vec8 &v) +inline FP16Vec8::FP16Vec8(const FP32Vec8& v) : reg(_mm256_cvtps_ph(v.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} #ifdef __AVX512F__ -inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) +inline FP16Vec16::FP16Vec16(const FP32Vec16& v) : reg(_mm512_cvtps_ph(v.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} #else -inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) - : reg(_mm256_insertf128_si256(_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {} +inline FP16Vec16::FP16Vec16(const FP32Vec16& v) + : reg(_mm256_insertf128_si256( + _mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), + FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {} #endif #ifdef __AVX512BF16__ -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +template <> +inline void storeFP32(float v, c10::BFloat16* ptr) { + *reinterpret_cast<__bfloat16*>(ptr) = _mm_cvtness_sbh(v); } -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} -inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { +inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) { acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); } #else -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = - reinterpret_cast(&v); +template <> +inline void storeFP32(float v, c10::BFloat16* ptr) { + c10::BFloat16 __attribute__((__may_alias__))* v_ptr = + reinterpret_cast(&v); *ptr = *(v_ptr + 1); } -#ifdef __AVX512F__ -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + #ifdef __AVX512F__ +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) : reg(_mm256_cvtepi32_epi16( _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) : reg(_mm512_cvtepi32_epi16( _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} -#else -namespace{ + #else +namespace { __m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { __m256i ai = _mm256_castps_si256(a); ai = _mm256_srli_epi32(ai, 16); @@ -612,21 +631,21 @@ __m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { ai = _mm256_permute4x64_epi64(ai, 0b00111001); return _mm256_extracti128_si256(ai, 0); } -} +} // namespace -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {} -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { BF16Vec8 low 
= BF16Vec8(FP32Vec8(v.reg_low)); BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high)); reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1); } -#endif // __AVX512F__ -#endif // __AVX512BF16__ + #endif // __AVX512F__ +#endif // __AVX512BF16__ -inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } +inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); } -}; // namespace vec_op +}; // namespace vec_op #endif diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index d9aed657a3113..33b1637832888 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -359,7 +359,7 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major const torch::Tensor& b, // [IC, OC], column-major const torch::Tensor& a_scales, // [1] or [M] const torch::Tensor& b_scales, // [1] or [OC] - const c10::optional& bias // [OC] + const std::optional& bias // [OC] ) { CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) // Checks for conformality @@ -442,8 +442,8 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major const torch::Tensor& a_scales, // [1] or [M] const torch::Tensor& b_scales, // [1] or [OC] const torch::Tensor& azp_adj, // [OC] - const c10::optional& azp, // [1] or [M] - const c10::optional& bias // [OC] + const std::optional& azp, // [1] or [M] + const std::optional& bias // [OC] ) { CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) // Checks for conformality @@ -561,7 +561,7 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] const torch::Tensor& scale, - c10::optional const& azp) { + std::optional const& azp) { CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); @@ -590,7 +590,7 @@ void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] torch::Tensor& scale, // [..., 1] - c10::optional const& azp) { + std::optional const& azp) { CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 03beefbc6de7d..5d1c5f4c83d3e 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -9,14 +9,14 @@ std::string init_cpu_threads_env(const std::string& cpu_ids); void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& a_scales, const torch::Tensor& b_scales, - const c10::optional& bias); + const std::optional& bias); void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& a_scales, const torch::Tensor& b_scales, const torch::Tensor& azp_adj, - const c10::optional& azp, - const c10::optional& bias); + const std::optional& azp, + const std::optional& bias); TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? 
alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -148,7 +148,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index 1138a55df2f05..42a1c1d924bac 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -1,10 +1,22 @@ -#include -#include -#include -#include +#ifndef VLLM_NUMA_DISABLED + #include + #include + #include + #include +#endif #include "cpu_types.hpp" +#ifdef VLLM_NUMA_DISABLED +std::string init_cpu_threads_env(const std::string& cpu_ids) { + return std::string( + "Warning: NUMA is not enabled in this build. `init_cpu_threads_env` has " + "no effect to setup thread affinity."); +} + +#endif + +#ifndef VLLM_NUMA_DISABLED std::string init_cpu_threads_env(const std::string& cpu_ids) { bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); TORCH_CHECK(omp_cpu_mask->size > 0); @@ -57,7 +69,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { omp_lock_t writelock; omp_init_lock(&writelock); -#pragma omp parallel for schedule(static, 1) + #pragma omp parallel for schedule(static, 1) for (size_t i = 0; i < omp_cpu_ids.size(); ++i) { cpu_set_t mask; CPU_ZERO(&mask); @@ -88,3 +100,4 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { return ss.str(); } +#endif \ No newline at end of file diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp new file mode 100644 index 0000000000000..e8555d853b7ac --- /dev/null +++ b/csrc/cumem_allocator.cpp @@ -0,0 +1,310 @@ +// A CUDAPluggableAllocator based on cumem* APIs. +// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle* +// need to be unsigned long long +#include + +extern "C" { + +#define PY_SSIZE_T_CLEAN +#include + +#include +#include +#include + +#define CUDA_CHECK(condition) \ + do { \ + CUresult error = condition; \ + if (error != 0) { \ + char* error_string; \ + cuGetErrorString(error, (const char**)&error_string); \ + std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; \ + } \ + } while (0) + +// Global references to Python callables +// NOTE: this is borrowed reference, so we don't need to DECREF them. +// This brings the limitation that the allocator needs to be singleton. 
+static PyObject* g_python_malloc_callback = nullptr; +static PyObject* g_python_free_callback = nullptr; + +// --------------------------------------------------------------------------- +// Helper functions: + +void ensure_context(unsigned long long device) { + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + // Ensure device context. + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } +} + +void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, + CUmemGenericAllocationHandle* p_memHandle) { + ensure_context(device); + // Define memory allocation properties + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; + + // Allocate memory using cuMemCreate + CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); + CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0)); + + CUmemAccessDesc accessDesc = {}; + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = device; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + + CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1)); + // std::cout << "create_and_map: device=" << device << ", size=" << size << ", + // d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; +} + +void unmap_and_release(unsigned long long device, ssize_t size, + CUdeviceptr d_mem, + CUmemGenericAllocationHandle* p_memHandle) { + // std::cout << "unmap_and_release: device=" << device << ", size=" << size << + // ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; + ensure_context(device); + CUDA_CHECK(cuMemUnmap(d_mem, size)); + CUDA_CHECK(cuMemRelease(*p_memHandle)); +} + +PyObject* create_tuple_from_c_integers(unsigned long long a, + unsigned long long b, + unsigned long long c, + unsigned long long d) { + // Create a new tuple of size 4 + PyObject* tuple = PyTuple_New(4); + if (!tuple) { + return NULL; // Return NULL on failure + } + + // Convert integers to Python objects and set them in the tuple + PyTuple_SetItem( + tuple, 0, + PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong + PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b)); + PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c)); + PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d)); + + // Note: PyTuple_SetItem "steals" a reference to each object, + // so we do not need to Py_DECREF the PyLong objects explicitly. 
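+  // The tuple itself is a new reference owned by the caller; my_malloc()
+  // hands it to the Python malloc callback and then releases it with
+  // Py_DECREF once the call returns.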
+ + return tuple; // Return the created tuple +} + +// --------------------------------------------------------------------------- +// Our exported C functions that call Python: + +// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h +void* my_malloc(ssize_t size, int device, CUstream stream) { + ensure_context(device); + + // first allocation, align the size, and reserve an address, and also allocate + // a CUmemGenericAllocationHandle + + // Define memory allocation properties + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; + + // Check if the allocation is supported + size_t granularity; + CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, + CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + + size_t alignedSize = ((size + granularity - 1) / granularity) * granularity; + + CUdeviceptr d_mem; + CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0)); + + // allocate the CUmemGenericAllocationHandle + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)malloc( + sizeof(CUmemGenericAllocationHandle)); + + if (!g_python_malloc_callback) { + std::cerr << "ERROR: g_python_malloc_callback not set.\n"; + return nullptr; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* arg_tuple = create_tuple_from_c_integers( + (unsigned long long)device, (unsigned long long)alignedSize, + (unsigned long long)d_mem, (unsigned long long)p_memHandle); + + // Call g_python_malloc_callback + PyObject* py_result = + PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL); + Py_DECREF(arg_tuple); + + if (!py_result) { + PyErr_Print(); + PyGILState_Release(gstate); + return nullptr; + } + + PyGILState_Release(gstate); + + // do the final mapping + create_and_map(device, alignedSize, d_mem, p_memHandle); + + return (void*)d_mem; +} + +// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h +void my_free(void* ptr, ssize_t size, int device, CUstream stream) { + // get memory handle from the pointer + if (!g_python_free_callback) { + std::cerr << "ERROR: g_python_free_callback not set.\n"; + return; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* py_ptr = + PyLong_FromUnsignedLongLong(reinterpret_cast(ptr)); + + PyObject* py_result = + PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL); + + if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, + &recv_d_mem, &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return; + } + + PyGILState_Release(gstate); + + // recv_size == size + // recv_device == device + + // Free memory + + CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + unmap_and_release(device, size, d_mem, p_memHandle); + + // free address and the handle + CUDA_CHECK(cuMemAddressFree(d_mem, size)); + free(p_memHandle); +} + +// 
--------------------------------------------------------------------------- +// Python extension boilerplate: + +// Python-exposed function: init_module(python_malloc, python_free) +static PyObject* py_init_module(PyObject* self, PyObject* args) { + PyObject* malloc_callback = nullptr; + PyObject* free_callback = nullptr; + + if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) { + return nullptr; + } + + if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) { + PyErr_SetString(PyExc_TypeError, "Both arguments must be callables"); + return nullptr; + } + + // Save the Python callables + // This module does not handle GC of these objects, so they must be kept alive + // outside of this module. + g_python_malloc_callback = malloc_callback; + g_python_free_callback = free_callback; + + Py_RETURN_NONE; +} + +static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + + unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; +} + +static PyObject* python_create_and_map(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + + create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; +} + +static PyMethodDef module_methods[] = { + {"init_module", (PyCFunction)py_init_module, METH_VARARGS, + "Initialize module with python_malloc and python_free callables."}, + {"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS, + "Create and map memory on the device."}, + {"python_unmap_and_release", (PyCFunction)python_unmap_and_release, + METH_VARARGS, "Unmap and release memory on the device."}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef cumem_allocator_module = { + PyModuleDef_HEAD_INIT, "cumem_allocator", + "cumem-based allocator for CUDAPluggableAllocator", -1, module_methods}; + +PyMODINIT_FUNC PyInit_cumem_allocator(void) { + // Initialize the module + PyObject* module = PyModule_Create(&cumem_allocator_module); + if (!module) { + return NULL; + } + return module; +} +} // extern "C" diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 6be4d4f2b2eb8..b9df4ed160b03 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -38,9 +38,13 @@ struct Signal { alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; }; -struct 
__align__(16) RankData { const void* __restrict__ ptrs[8]; }; +struct __align__(16) RankData { + const void* __restrict__ ptrs[8]; +}; -struct __align__(16) RankSignals { Signal* signals[8]; }; +struct __align__(16) RankSignals { + Signal* signals[8]; +}; // like std::array, but aligned template diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index 85e359aa57113..febc4eccd9561 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -27,9 +27,25 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { int max_shared_mem_per_block_opt_in = 0; cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, - cudaDevAttrMaxSharedMemoryPerBlockOptin, - device); + cudaDevAttrMaxSharedMemoryPerBlockOptin, device); return max_shared_mem_per_block_opt_in; } int32_t get_sm_version_num(); + +/** + * A wrapper for a kernel that is used to guard against compilation on + * architectures that will never use the kernel. The purpose of this is to + * reduce the size of the compiled binary. + * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef + * into code that will be executed on the device where it is defined. + */ +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; \ No newline at end of file diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp index 26f7423fd7455..ef413e6dd75c5 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -68,7 +68,7 @@ struct ScaledEpilogueBase { // This overload handles the case where there might not be a tensor, in which // case a nullptr is passed and a constant (0) is used. template - static auto args_from_tensor(c10::optional const& tensor) { + static auto args_from_tensor(std::optional const& tensor) { static_assert(std::is_same_v>); using Arguments = typename Descriptor::Arguments; auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; @@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp static ArgumentType prepare_args(torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& azp_adj, - c10::optional const& bias) { + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -301,7 +301,7 @@ struct ScaledEpilogueBiasAzpToken torch::Tensor const& b_scales, torch::Tensor const& azp_adj, torch::Tensor const& azp, - c10::optional const& bias) { + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index c723adf126422..c590c66a66652 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -67,7 +67,7 @@ struct ScaledEpilogueBase { // This overload handles the case where there might not be a tensor, in which // case a nullptr is passed and a constant (0) is used. 
template - static auto args_from_tensor(c10::optional const& tensor) { + static auto args_from_tensor(std::optional const& tensor) { using Arguments = typename Descriptor::Arguments; auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; static_assert(std::is_same_v> || @@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp static ArgumentType prepare_args(torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& azp_adj, - c10::optional const& bias) { + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -299,7 +299,7 @@ struct ScaledEpilogueBiasAzpToken torch::Tensor const& b_scales, torch::Tensor const& azp_adj, torch::Tensor const& azp, - c10::optional const& bias) { + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); diff --git a/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp b/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp new file mode 100644 index 0000000000000..ec75c29e54f4d --- /dev/null +++ b/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp @@ -0,0 +1,123 @@ +// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl +// clang-format off +#pragma once + +#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" + +#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS (BlockScaled Builders) +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + int ScaleGranularityM +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum, + cute::enable_if_t< + not detail::is_use_rmem_A()> +> { + using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; + + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert((!IsFP8Input || !IsArrayOfPointersGemm), + "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now."); + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); 
+ static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsCooperative = cute::is_any_of_v>; + using AtomLayoutMNK = cute::conditional_t>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp b/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp new file mode 100644 index 0000000000000..13b90e998625e --- /dev/null +++ b/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp @@ -0,0 +1,183 @@ +// clang-format off +// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/algorithm/clear.hpp" +#include "cute/tensor.hpp" + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////FP8 Accumulation/////////////////////////// +////////////////////////////////////////////////////////////////////////////// +/// This class provides API to promote (add) or scale (multiply_add) the results +/// from the tensor core accumulators to the main accumulators when the number +/// of MMAs reaches the max number of MMA interval specified by user, after that +/// the tensor core accumulators are zeroed. +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +template < + class EngineAccum, + class LayoutAccum> +struct GmmaFP8AccumulationWithScale { + using TensorAccum = cute::Tensor; + using ElementAccumulator = typename EngineAccum::value_type; + + static_assert(is_static::value, "Accumulator Layout should be static"); + static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); + +private: + TensorAccum& accum_; + TensorAccum accum_temp_; + + uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. + uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop + uint32_t mma_count_; // current executed MMAs + uint32_t reset_accum_flag_; // accum needs to be zeroed or not. + + // promote or `add` the partial accumulators to main accumulator (FADD). + CUTLASS_DEVICE + void promote_core() { + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i); + } + } + + // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). 
+ template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_core(const cute::Tensor &scale) { + using TensorScale = cute::Tensor; + + static_assert(is_static::value, "Scale Layout should be static"); + static_assert(is_rmem::value , "Scale tensor must be rmem resident."); + + static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); + + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i) * scale(i); + } + } + +public: + CUTLASS_DEVICE + GmmaFP8AccumulationWithScale( + TensorAccum &accum, + uint32_t accum_promotion_interval, + uint32_t mma_count_per_mainloop_iteration) + : accum_(accum), + accum_promotion_interval_(accum_promotion_interval), + mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), + mma_count_(0), + reset_accum_flag_(0) + { + accum_temp_ = cute::make_fragment_like(accum); + } + + // + // Methods (Common) + // + + CUTLASS_DEVICE + TensorAccum& operator()() { + return accum_temp_; + } + + /// prepare the MMA accumulators when initialization or zeroing is required. + CUTLASS_DEVICE + bool prepare_if_needed() { + return reset_accum_flag_; + } + + // + // Methods (for FADD version) + // + + /// promote (add) the results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_if_needed() { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + promote_core(); + mma_count_ = 0; + } + } + + /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_residue_if_needed() { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + promote_core(); + } + } + + // + // Methods (for FFMA version) + // + + /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_if_needed(const cute::Tensor &scale) { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + scale_core(scale); + mma_count_ = 0; + } + } + + /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_residue_if_needed(const cute::Tensor &scale) { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + scale_core(scale); + } + } +}; + +} // namespace cutlass::gemm::collective diff --git a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp new file mode 100644 index 0000000000000..928a9500cbb08 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -0,0 +1,730 @@ +// clang-format off +// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm80.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + int ScaleGranularityM_, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = 
TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementBlockScale = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + // Two threads per CTA are producers (1 for operand tile and 32 for scales) + static constexpr int NumProducerThreadEvents = 33; + + static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; + + // Block scaling smem layout + using SmemLayoutScaleA = Layout, Int>>; + using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. 
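
The two constants above fix how many A scales travel with each CTA tile: a zero ScaleGranularityM_ degenerates to one scale for the whole tile along M, otherwise ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM scales are kept, and shared memory stages ScaleMsPerTile x Stages A scales next to Stages B scales (one per k-stage). A minimal standalone sketch of that arithmetic, assuming a 128-row tile and a 3-stage pipeline purely for illustration:

#include <cstdio>

constexpr int kTileM = 128;   // illustrative size<0>(TileShape{}), not a fixed kernel value
constexpr int kStages = 3;    // illustrative DispatchPolicy::Stages

constexpr int scale_ms_per_tile(int scale_granularity_m) {
  // Mirrors: ScaleGranularityM = (ScaleGranularityM_ == 0) ? size<0>(TileShape{}) : ScaleGranularityM_;
  //          ScaleMsPerTile    = size<0>(TileShape{}) / ScaleGranularityM;
  const int g = (scale_granularity_m == 0) ? kTileM : scale_granularity_m;
  return kTileM / g;
}

int main() {
  std::printf("granularity 0 -> %d A scale per tile\n", scale_ms_per_tile(0));    // 1
  std::printf("granularity 1 -> %d A scales per tile\n", scale_ms_per_tile(1));   // 128
  // Shared memory mirrors SmemLayoutScaleA / SmemLayoutScaleB: ScaleMsPerTile x Stages
  // floats for the A scales and Stages floats for the B scales (one per k-stage).
  std::printf("smem floats: A = %d, B = %d\n", scale_ms_per_tile(1) * kStages, kStages);
  return 0;
}
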
+ + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; // mxk + cute::array_aligned> smem_B; // nxk + cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k + cute::array_aligned> smem_scale_B; // 1xk + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + ElementBlockScale const* ptr_scale_A; + ElementBlockScale const* ptr_scale_B; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,0), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + TileShape{}, + ClusterShape{})); + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + // Block scaling factors for A and B + ElementBlockScale const* ptr_scale_A; + ElementBlockScale const* ptr_scale_B; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.ptr_scale_A, + args.ptr_scale_B + }; + } + + 
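
Arguments only carries raw pointers and strides; to_underlying_arguments turns the A/B pointers into TMA descriptors and passes the scale pointers through untouched. The number of scales the caller stages behind ptr_scale_A and ptr_scale_B follows the scaleA_shape/scaleB_shape built in load_init further down: (M / ScaleGranularityM) x k_tiles x L for A and n_tiles x k_tiles x L for B. A rough sizing sketch, with assumed problem and tile shapes that are illustrative only:

#include <cstdio>

int main() {
  // Assumed problem and tile shapes, purely for illustration.
  const long M = 4096, N = 4096, K = 7168, L = 1;
  const long TileN = 128, TileK = 128;   // illustrative size<1>/size<2>(TileShape{})
  const long ScaleGranularityM = 1;      // one A scale per row of each K tile

  // Follows scaleA_shape = (M / ScaleGranularityM, k_tiles, L) and
  // scaleB_shape = (n_tiles, k_tiles, L) used by load_init.
  const long k_tiles = K / TileK;
  const long n_tiles = N / TileN;
  const long num_scale_a = (M / ScaleGranularityM) * k_tiles * L;
  const long num_scale_b = n_tiles * k_tiles * L;

  std::printf("ptr_scale_A must hold %ld floats, ptr_scale_B %ld floats\n",
              num_scale_a, num_scale_b);
  return 0;
}
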
template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + constexpr auto scales_m = Int{}; + auto tM = get<2>(gA_mkl.shape()); + auto tN = get<2>(gB_nkl.shape()); + auto tK = get<3>(gA_mkl.shape()); + + // Make the tiled views of scale tensors + auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) + auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); + auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) + auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{}); + + // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and + // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. 
+ Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) + Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class TensorScaleA, class TensorScaleB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + // Blockscaling: Tma loads for load_input and CpAsync for load_scale + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) + Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mScaleA_mkl = get<2>(load_inputs); + Tensor mScaleB_nkl = get<3>(load_inputs); + auto scales_m = get<0>(mScaleA_mkl.shape()); + + Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); + + Tensor gScaleA = local_tile( + mScaleA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) + Tensor cScaleA = local_tile( + cScaleA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); + Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) + + // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, Layout>{}); // (1,1,1) + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, + Layout>{}, Layout>{}); // (1,1,1) + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); + + Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); + Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA); + Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); + + Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); + Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Allocate predicate tensors for a_scales (since we can't guarantee that + // all scales are valid, since we could have a partial tiles along M) + Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_,_,0))); + #pragma unroll + for (int i = 0; i < size(tApA_ScaleA); ++i) { + tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + int write_stage = smem_pipe_write.index(); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + // Copy operands A and B from global memory to shared memory + if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + 
// Copy scale tensors from global memory to shared memory + copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); + copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, + Stride, _0, Int> + >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) + Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. 
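
Functionally, the consumer loop below forms a per-K-tile partial product in the tensor cores and then folds it into the FP32 accumulator scaled by the A row-block scale and the per-tile B scale (the values staged through sScaleAViewAsC and sScaleB). A scalar reference of that accumulation order, with scale-array layouts and names of my own choosing rather than the kernel's actual memory layout:

#include <vector>

// Scalar model: each K tile's FP8 products are summed into a temporary
// accumulator (the tensor-core GMMAs), then promoted into the FP32 result
// multiplied by that tile's A row-block scale and B scale, as
// GmmaFP8AccumulationWithScale::scale_if_needed does.
void reference_block_scaled_gemm(
    const std::vector<float>& A,        // M x K operand, treated as already-decoded FP8 values
    const std::vector<float>& B,        // N x K operand
    const std::vector<float>& scale_A,  // (M / ScaleGranularityM) x k_tiles, row-major here
    const std::vector<float>& scale_B,  // n_tiles x k_tiles, row-major here
    std::vector<float>& C,              // M x N output accumulator
    int M, int N, int K, int TileN, int TileK, int ScaleGranularityM) {
  const int k_tiles = K / TileK;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int kt = 0; kt < k_tiles; ++kt) {
        float partial = 0.0f;  // per-K-tile tensor-core accumulator (accum_temp_)
        for (int k = kt * TileK; k < (kt + 1) * TileK; ++k) {
          partial += A[(size_t)m * K + k] * B[(size_t)n * K + k];
        }
        const float sa = scale_A[(size_t)(m / ScaleGranularityM) * k_tiles + kt];
        const float sb = scale_B[(size_t)(n / TileN) * k_tiles + kt];
        acc += partial * sa * sb;  // scale_core(): FFMA into the main accumulator
      }
      C[(size_t)m * N + n] = acc;
    }
  }
}
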
+ + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Per block scale values for operand A and B + + using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. + using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above + + Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) + ElementBlockScale scale_b; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers. + scale_b = sScaleB[read_stage]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + } + if constexpr (ScaleMsPerTile == 1) { + static_assert(size(RegLayoutScaleAEssential{}) == 1); + tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
+ } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` + accumulation.scale_if_needed(tCrScaleAViewAsC); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) + scale_b = sScaleB[read_stage]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + } + if constexpr (ScaleMsPerTile == 1) { + static_assert(size(RegLayoutScaleAEssential{}) == 1); + tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
+ } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation()); + + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` + accumulation.scale_if_needed(tCrScaleAViewAsC); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation.scale_residue_if_needed(tCrScaleAViewAsC); + + warpgroup_fence_operand(accumulation()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/csrc/cutlass_extensions/gemm/dispatch_policy.hpp new file mode 100644 index 0000000000000..df809e27a3efe --- /dev/null +++ b/csrc/cutlass_extensions/gemm/dispatch_policy.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "cutlass/gemm/dispatch_policy.hpp" + +namespace cutlass::gemm { + +////////////////////////////////////////////////////////////////////////////// + +// FP8 related policies (including Blocked Scaled Accumulation) +// `ScaleGranularityM` specifies scaling granularity along M, while zero-value +// `ScaleGranularityM` indicates that scaling granularity is +// `size<0>(TileShape_MNK{})` along M. +template +struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum + : KernelTmaWarpSpecializedCooperative {}; + +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp +// specialized dynamic schedule For FP8 kernels with Block Scaling +template , + class KernelSchedule = KernelTmaWarpSpecialized, + int ScaleGranularityM = + 0 // `ScaleGranularityM` specifies scaling granularity along M, + // while zero-value `ScaleGranularityM` indicates that scaling + // granularity is `size<0>(TileShape_MNK{})` along M. 
+ > +struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 + : MainloopSm90TmaGmmaWarpSpecialized { + static_assert( + cute::is_same_v< + KernelSchedule, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + ScaleGranularityM>>, + "KernelSchedule must be one of the warp specialized policies"); +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm \ No newline at end of file diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index 2c78572521eec..a1ff933cce63f 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -97,7 +97,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor, template static inline auto maybe_make_cute_layout( - c10::optional const& tensor, + std::optional const& tensor, std::string_view name = "tensor") { using Layout = decltype(make_cute_layout(*tensor)); diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh index 085ee1290031f..e7fbba4cd4b0d 100644 --- a/csrc/cutlass_extensions/vllm_collective_builder.cuh +++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -1,6 +1,6 @@ #pragma once -#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" namespace cutlass::gemm::collective { using namespace cute; diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index a5beea1a35e49..b401736c9824b 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum): class MixedInputKernelScheduleType(enum.Enum): - TmaWarpSpecializedMixedInput = enum_auto() - TmaWarpSpecializedPingpongMixedInput = enum_auto() - TmaWarpSpecializedCooperativeMixedInput = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedPingpong = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { @@ -68,11 +68,11 @@ class MixedInputKernelScheduleType(enum.Enum): MixedInputKernelScheduleType, KernelScheduleType], str] = { **KernelScheduleTag, # type: ignore **{ - MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedMixedInput", - MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput", - MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecialized: + "cutlass::gemm::KernelTmaWarpSpecialized", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: + "cutlass::gemm::KernelTmaWarpSpecializedPingpong", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: + "cutlass::gemm::KernelTmaWarpSpecializedCooperative", } } diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index dd1e6de2e0180..f0e5533bcae60 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -53,12 +53,12 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, const at::Tensor x, const at::Tensor weight, const at::Tensor out, - const c10::optional& bias, + const std::optional& bias, bool silu_activation, 
int64_t pad_slot_id, - const c10::optional& query_start_loc = std::nullopt, - const c10::optional& cache_indices = std::nullopt, - const c10::optional& has_initial_state = std::nullopt) { + const std::optional& query_start_loc = std::nullopt, + const std::optional& cache_indices = std::nullopt, + const std::optional& has_initial_state = std::nullopt) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -93,11 +93,11 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, - const c10::optional &bias_, - const c10::optional &conv_states, - const c10::optional &query_start_loc, - const c10::optional &cache_indices, - const c10::optional &has_initial_state, + const std::optional &bias_, + const std::optional &conv_states, + const std::optional &query_start_loc, + const std::optional &cache_indices, + const std::optional &has_initial_state, bool silu_activation, // used to identify padding entries if cache_indices provided // in case of padding, the kernel will return early @@ -194,10 +194,10 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, void causal_conv1d_update(const at::Tensor &x, const at::Tensor &conv_state, const at::Tensor &weight, - const c10::optional &bias_, + const std::optional &bias_, bool silu_activation, - const c10::optional &cache_seqlens_, - const c10::optional &conv_state_indices_, + const std::optional &cache_seqlens_, + const std::optional &conv_state_indices_, // used to identify padding entries if cache_indices provided // in case of padding, the kernel will return early int64_t pad_slot_id) { diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 71624696338d0..bd0a34119c82b 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -402,14 +402,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const torch::Tensor out, const torch::Tensor z, const torch::Tensor out_z, - const c10::optional& D, - const c10::optional& delta_bias, + const std::optional& D, + const std::optional& delta_bias, const torch::Tensor ssm_states, bool has_z, bool delta_softplus, - const c10::optional& query_start_loc, - const c10::optional& cache_indices, - const c10::optional& has_initial_state, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, bool varlen, int64_t pad_slot_id) { @@ -504,13 +504,13 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, - const c10::optional &D_, - const c10::optional &z_, - const c10::optional &delta_bias_, + const std::optional &D_, + const std::optional &z_, + const std::optional &delta_bias_, bool delta_softplus, - const c10::optional &query_start_loc, - const c10::optional &cache_indices, - const c10::optional &has_initial_state, + const std::optional &query_start_loc, + const std::optional &cache_indices, + const std::optional &has_initial_state, const torch::Tensor &ssm_states, // used to identify padding entries if cache_indices provided // in case of padding, the kernel will return early diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h index a217401b3d7c2..47ecf109d0f53 100644 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel.h +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h @@ -138,8 +138,8 @@ __device__ 
inline FragB dequant(int q) { const int HI = 0x00f000f0; const int EX = 0x64006400; // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point // directly into `SUB` and `ADD`. const int SUB = 0x64086408; @@ -182,8 +182,8 @@ __device__ inline FragB dequant(int q) { const int HI = 0x00f000f0; const int EX = 0x64006400; // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); const int SUB = 0x64006400; const int MUL = 0x2c002c00; diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 24341d63fb1f8..8b6fe72ad743b 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, } } // namespace -template +template __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, int32_t* expert_ids, @@ -32,12 +32,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, const size_t start_idx = threadIdx.x * tokens_per_thread; extern __shared__ int32_t shared_mem[]; - - int32_t* tokens_cnts = - shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts) - int32_t* cumsum = - shared_mem + - (blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1) + int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1) + token_cnts_t* tokens_cnts = + (token_cnts_t*)(shared_mem + num_experts + + 1); // 2d tensor with shape (blockDim.x + 1, num_experts) for (int i = 0; i < num_experts; ++i) { tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; @@ -74,7 +72,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, block_size) * block_size; } - *total_tokens_post_pad = cumsum[num_experts]; + *total_tokens_post_pad = static_cast(cumsum[num_experts]); } __syncthreads(); @@ -224,26 +222,46 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - // If we have very large number of experts, we can no longer use shared - // memory. - // TODO(simon): the right solution should be calculating the exact right - // amount of shared memory and use that. The num_experts >= 256 is just a - // temporary solution to unblock Deepseek V3. - if (num_experts >= 256) { + int device_max_shared_mem; + auto dev = topk_ids.get_device(); + cudaDeviceGetAttribute(&device_max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_i32 = + ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + const int32_t shared_mem_i16 = + ((num_thread + 1) * num_experts) * sizeof(uint16_t) + + (num_experts + 1) * sizeof(int32_t); + + bool use_global_memory = false; + bool use_i16 = false; // Use uint16_t for shared memory token counts + if (shared_mem_i32 < device_max_shared_mem) { + // Do nothing in this case. 
We're all set to use int32_t token counts + } else if (shared_mem_i16 < device_max_shared_mem && + topk_ids.numel() <= 65535) { + // when nelements of topk_ids is smaller than 65535 (max value of uint16), + // element value of token_cnts would also smaller than 65535, + // so we can use uint16 as dtype of token_cnts + use_i16 = true; + } else { + use_global_memory = true; + } + + if (use_global_memory) { VLLM_DISPATCH_INTEGRAL_TYPES( topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - const int32_t mem_tokens_cnts = - ((num_experts + 1) * num_experts) * sizeof(int32_t); - const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t); - // allocate global memory - int32_t* tokens_cnts; - int32_t* cumsum; - cudaMalloc(&tokens_cnts, mem_tokens_cnts); - cudaMalloc(&cumsum, mem_cumsum); + auto options_int = torch::TensorOptions() + .dtype(torch::kInt) + .device(topk_ids.device()); + torch::Tensor token_cnts_buffer = + torch::empty({(num_experts + 1) * num_experts}, options_int); + torch::Tensor cumsum_buffer = + torch::empty({num_experts + 1}, options_int); auto kernel = vllm::moe::moe_align_block_size_global_mem_kernel; @@ -252,25 +270,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, sorted_token_ids.data_ptr(), experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel(), tokens_cnts, cumsum); - cudaFree(tokens_cnts); - cudaFree(cumsum); + topk_ids.numel(), token_cnts_buffer.data_ptr(), + cumsum_buffer.data_ptr()); }); - } else { + } else if (use_i16) { VLLM_DISPATCH_INTEGRAL_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` - // tensors - const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - const int32_t shared_mem = - ((num_thread + 1) * num_experts + (num_experts + 1)) * - sizeof(int32_t); - // set dynamic shared mem - auto kernel = vllm::moe::moe_align_block_size_kernel; + auto kernel = + vllm::moe::moe_align_block_size_kernel; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem_i16)); + kernel<<<1, num_thread, shared_mem_i16, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel()); + }); + } else { + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + auto kernel = + vllm::moe::moe_align_block_size_kernel; AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem)); - kernel<<<1, num_thread, shared_mem, stream>>>( + (void*)kernel, shared_mem_i32)); + kernel<<<1, num_thread, shared_mem_i32, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), diff --git a/csrc/ops.h b/csrc/ops.h index 347c502845d8f..e39d4ef3188a3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -33,9 +33,10 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + 
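Note: the three-way dispatch above (int32 shared counts, uint16 shared counts, global-memory workspace tensors) replaces the old hard num_experts >= 256 cutoff with an exact shared-memory budget check against cudaDevAttrMaxSharedMemoryPerBlockOptin. A small host-side restatement of the selection arithmetic, with purely illustrative numbers:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // tokens_cnts is (num_thread + 1) x num_experts, cumsum is (num_experts + 1).
  const int64_t num_experts = 256, warp_size = 32, topk_numel = 16384;
  const int64_t device_max_shared_mem = 101376;  // ~99 KiB opt-in, illustrative

  const int64_t num_thread = std::max(num_experts, warp_size);
  const int64_t shared_mem_i32 =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
  const int64_t shared_mem_i16 =
      (num_thread + 1) * num_experts * sizeof(uint16_t) +
      (num_experts + 1) * sizeof(int32_t);

  if (shared_mem_i32 < device_max_shared_mem) {
    std::printf("int32 shared counts: %lld bytes\n", (long long)shared_mem_i32);
  } else if (shared_mem_i16 < device_max_shared_mem && topk_numel <= 65535) {
    // Every count is bounded by topk_ids.numel() <= 65535, so uint16_t
    // cannot overflow on this path.
    std::printf("uint16 shared counts: %lld bytes\n", (long long)shared_mem_i16);
  } else {
    std::printf("fall back to global-memory workspace tensors\n");
  }
  return 0;
}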
int64_t max_seq_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -44,9 +45,10 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -86,6 +88,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void silu_and_mul(torch::Tensor& out, torch::Tensor& input); +void mul_and_silu(torch::Tensor& out, torch::Tensor& input); + void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -149,19 +153,20 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, #ifndef USE_ROCM bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); +bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, - c10::optional const& bias); + std::optional const& bias); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& azp_adj, - c10::optional const& azp, - c10::optional const& bias); + std::optional const& azp, + std::optional const& bias); bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability); @@ -169,7 +174,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& e, torch::Tensor const& a_scales, torch::Tensor const& b_scales, - c10::optional const& bias); + std::optional const& bias); bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, torch::Tensor& e, torch::Tensor const& a); @@ -177,11 +182,11 @@ bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, - c10::optional const& azp); + std::optional const& azp); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, - c10::optional const& azp); + std::optional const& azp); torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, @@ -198,34 +203,34 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, - c10::optional const& scale_ub); + std::optional const& scale_ub); void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const 
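Note: in the paged_attention_v1/v2 declarations above, k_scale and v_scale change from host doubles to torch::Tensor&, presumably so the FP8 KV-cache scales can stay device-resident and be updated without rebinding scalar kernel arguments. A hedged caller-side sketch (helper name and values are illustrative, not from the patch):

#include <torch/all.h>

// Build a 1-element, GPU-resident scale tensor; the kernel reads the scalar
// on the device instead of receiving it as a host double.
at::Tensor make_kv_scale(float value) {
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  return at::full({1}, value, opts);
}

// Usage (sketch): pass make_kv_scale(0.02f) / make_kv_scale(0.04f) where the
// old API took the k_scale / v_scale doubles.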
torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& C, - const c10::optional& D_, - const c10::optional& z_, - const c10::optional& delta_bias_, + const std::optional& D_, + const std::optional& z_, + const std::optional& delta_bias_, bool delta_softplus, - const c10::optional& query_start_loc, - const c10::optional& cache_indices, - const c10::optional& has_initial_state, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, const torch::Tensor& ssm_states, int64_t pad_slot_id); void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, - const c10::optional& bias_, + const std::optional& bias_, bool silu_activation, - const c10::optional& cache_seqlens_, - const c10::optional& conv_state_indices_, + const std::optional& cache_seqlens_, + const std::optional& conv_state_indices_, int64_t pad_slot_id); void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, - const c10::optional& bias_, - const c10::optional& conv_states, - const c10::optional& query_start_loc, - const c10::optional& cache_indices, - const c10::optional& has_initial_state, + const std::optional& bias_, + const std::optional& conv_states, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, bool silu_activation, int64_t pad_slot_id); #ifndef USE_ROCM diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index bd184ee22682e..c3902f4c2a163 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel( long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, int64_t const block_tables_stride, int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) { + int const n_pad = num_seqs - num_queries; + if (n_pad && blockIdx.x == 0) { + // Handle cuda graph padding + int const offset = num_queries; + for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { + input_tokens_ptr[offset + i] = 0; + input_positions_ptr[offset + i] = 0; + slot_mapping_ptr[offset + i] = -1; + } + } int num_query_blocks = div_ceil(num_queries, num_threads); if (blockIdx.x < num_query_blocks) { diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index e9987535bd3ea..e79785827189d 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -226,7 +226,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] torch::Tensor const& scale, - c10::optional const& azp) { + std::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); @@ -257,7 +257,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales, c10::optional const& azp) { + torch::Tensor& scales, std::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scales.is_contiguous()); diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh 
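Note: the new block at the top of advance_step_flashinfer_kernel above fills the trailing num_seqs - num_queries entries, which exist only as CUDA-graph padding. A CPU restatement of that fill for clarity (-1 is the padding slot id, which the cache-update kernels are expected to ignore):

#include <cstdint>
#include <vector>

void pad_decode_inputs(std::vector<int64_t>& input_tokens,
                       std::vector<int64_t>& input_positions,
                       std::vector<int64_t>& slot_mapping,
                       int num_queries, int num_seqs) {
  // Entries [num_queries, num_seqs) are graph padding, not real requests.
  for (int i = num_queries; i < num_seqs; ++i) {
    input_tokens[i] = 0;
    input_positions[i] = 0;
    slot_mapping[i] = -1;  // negative slot => skipped downstream
  }
}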
b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh new file mode 100644 index 0000000000000..9ac7eee7204ec --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -0,0 +1,93 @@ +#pragma once + +// clang-format will break include orders +// clang-format off +#include + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" +// clang-format on + +namespace vllm::c3x { + +static inline cute::Shape get_problem_shape( + torch::Tensor const& a, torch::Tensor const& b) { + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + return {m, n, k, 1}; +} + +template +void cutlass_gemm_caller(torch::Device device, + cute::Shape prob_shape, + typename GemmKernel::MainloopArguments mainloop_args, + typename GemmKernel::EpilogueArguments epilogue_args) { + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(device); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +template +void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... 
epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + using GemmKernel = typename Gemm::GemmKernel; + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = cute::Stride, int64_t>; + using StrideB = cute::Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, cute::Int<1>{}, 0}; + StrideB b_stride{ldb, cute::Int<1>{}, 0}; + StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}}; + + typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +} // namespace vllm::c3x \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh similarity index 51% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh rename to csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh index d4bc2f0ade50d..9227ebb735245 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -2,9 +2,6 @@ // clang-format will break include orders // clang-format off -#include - -#include #include "cutlass/cutlass.h" @@ -32,21 +29,6 @@ using namespace cute; namespace vllm { -// A wrapper for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 - Kernel::operator()(std::forward(args)...); -#endif - } -}; - template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, @@ -101,60 +83,4 @@ struct cutlass_3x_gemm { struct GemmKernel : public KernelType {}; }; -template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... 
epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; - - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, - prob_shape, mainloop_args, epilogue_args}; - - // Launch the CUTLASS GEMM kernel. - using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu new file mode 100644 index 0000000000000..4cd38f4975df7 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu @@ -0,0 +1,24 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm90_int8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& azp, + std::optional const& bias) { + if (azp) { + return cutlass_scaled_mm_sm90_int8_epilogue< + c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj, + *azp, bias); + } else { + return cutlass_scaled_mm_sm90_int8_epilogue( + out, a, b, a_scales, b_scales, azp_adj, bias); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu new file mode 100644 index 0000000000000..0501e6da160e2 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu @@ -0,0 +1,24 @@ + +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + if (out.dtype() == torch::kBFloat16) { + cutlass_gemm_blockwise_sm90_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + 
TORCH_CHECK(out.dtype() == torch::kFloat16); + cutlass_gemm_blockwise_sm90_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh new file mode 100644 index 0000000000000..fb7a82b80ee65 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -0,0 +1,168 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +template > +struct cutlass_3x_gemm_fp8_blockwise { + using GroupSizeM = Int; + using GroupSizeN = Int; + using GroupSizeK = Int; + using TileSizeM = Int; + + static_assert(TileSizeM_ % GroupSizeM_ == 0, + "TileSizeM must be a multiple of GroupSizeM"); + + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementD = OutType; + using StrideD = Stride, Int<0>>; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using StrideC = StrideD; + static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementBlockScale = float; + using ElementCompute = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = Shape; + + using KernelSchedule = cutlass::gemm:: + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + GroupSizeM_>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90AccFetch>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC, + ElementD, StrideD, AlignmentD, EpilogueSchedule, + StoreEpilogueCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, + LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>>; + + struct GemmKernel 
: public KernelType {}; + + using StrideA = typename GemmKernel::StrideA; + using StrideB = typename GemmKernel::StrideB; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + auto prob_shape = c3x::get_problem_shape(a, b); + int32_t m = get<0>(prob_shape), n = get<1>(prob_shape), + k = get<2>(prob_shape); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + // Check is the t is contiguous and is 1D or 2D with one of the dimensions + // being 1 (i.e. a row or column vector) + auto is_contiguous_vector = [](const torch::Tensor& t) { + auto t_sizes = t.sizes(); + return t.is_contiguous() && + (t.dim() == 1 || + (t.dim() == 2 && + *std::min_element(t_sizes.begin(), t_sizes.end()) == 1)); + }; + + // TODO(lucas): lets clean-up the kernel so that we pass in Strides so + // we don't have to deal with enforcing implicit layouts + TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value); + TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value); + TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales), + "a_scales must be M major"); + TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value); + TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value); + TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales), + "b_scales must be K major"); + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template +void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + cutlass_gemm_caller_blockwise< + cutlass_3x_gemm_fp8_blockwise>(out, a, b, a_scales, + b_scales); +} + +} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp new file mode 100644 index 0000000000000..7ede9e067477b --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include + +namespace vllm { + +void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, 
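Note: the TORCH_CHECKs in cutlass_gemm_caller_blockwise above pin down the scale layouts for the blockwise FP8 path: a_scales has shape (m / GroupSizeM, k / GroupSizeK) and must be M-major, b_scales has shape (k / GroupSizeK, n / GroupSizeN) and must be K-major. A small host-side sketch of that shape arithmetic; the 1x128 activation groups and 128x128 weight blocks are assumptions for illustration, since the instantiated group sizes come from the cutlass_3x_gemm_fp8_blockwise config:

#include <cstdio>

int main() {
  const int group_m = 1, group_n = 128, group_k = 128;  // assumed group sizes
  const int m = 512, n = 4096, k = 7168;                // illustrative GEMM shape

  std::printf("a_scales: (%d, %d), M-major\n", m / group_m, k / group_k);
  std::printf("b_scales: (%d, %d), K-major\n", k / group_k, n / group_n);
  return 0;
}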
torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& azp, + std::optional const& bias); + +void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu new file mode 100644 index 0000000000000..e092c61abc249 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu @@ -0,0 +1,24 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm90_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias) { + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (bias) { + TORCH_CHECK(bias->dtype() == out.dtype(), + "currently bias dtype must match output dtype ", out.dtype()); + return cutlass_scaled_mm_sm90_fp8_epilogue( + out, a, b, a_scales, b_scales, *bias); + } else { + return cutlass_scaled_mm_sm90_fp8_epilogue( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh similarity index 76% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh rename to csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh index f08419b3122b2..32ea5db3321bc 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh @@ -1,6 +1,7 @@ #pragma once -#include "scaled_mm_c3x.cuh" +#include "scaled_mm.cuh" +#include "cutlass_gemm_caller.cuh" /** * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm @@ -9,6 +10,8 @@ namespace vllm { +using c3x::cutlass_gemm_caller; + template typename Epilogue> struct sm90_fp8_config_default { @@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, } } +template