diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 90a5e54736cf3..75ad094fa1382 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,7 +1,7 @@ import os import zipfile -MAX_SIZE_MB = 100 +MAX_SIZE_MB = 200 def print_top_10_largest_files(zip_file): diff --git a/.buildkite/download-images.sh b/.buildkite/download-images.sh index 389a12956c3c3..360a7584bccf1 100644 --- a/.buildkite/download-images.sh +++ b/.buildkite/download-images.sh @@ -8,10 +8,6 @@ set -o pipefail # aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/ mkdir -p images cd images -wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt -wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt -wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt -wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml new file mode 100644 index 0000000000000..fa6ea236ef04f --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-70B-Instruct -b 32 -l 250 -f 5 +model_name: "meta-llama/Meta-Llama-3-70B-Instruct" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.892 + - name: "exact_match,flexible-extract" + value: 0.892 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml new file mode 100644 index 0000000000000..02668702b83af --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1 +model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.756 + - name: "exact_match,flexible-extract" + value: 0.752 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml new file mode 100644 index 0000000000000..fb4b4915ab955 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 -t 1 +model_name: "meta-llama/Meta-Llama-3-8B-Instruct" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.756 + - name: "exact_match,flexible-extract" + value: 0.752 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml new file mode 100644 index 0000000000000..dec9164d1b84e --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml @@ 
-0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5 -t 4 +model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.616 + - name: "exact_match,flexible-extract" + value: 0.632 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt new file mode 100644 index 0000000000000..127ec5d97bcff --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -0,0 +1,2 @@ +Meta-Llama-3-70B-Instruct.yaml +Mixtral-8x7B-Instruct-v0.1.yaml diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt new file mode 100644 index 0000000000000..273c5482db264 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -0,0 +1,2 @@ +Meta-Llama-3-8B-Instruct.yaml +Meta-Llama-3-8B-Instruct-FP8.yaml diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh new file mode 100644 index 0000000000000..fdb8ec5393b36 --- /dev/null +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# We can use this script to compute baseline accuracy on GSM for transformers. +# +# Make sure you have lm-eval-harness installed: +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@9516087b81a61d0e220b22cc1b75be76de23bc10 + +usage() { + echo`` + echo "Runs lm eval harness on GSM8k using huggingface transformers." + echo "This pathway is intended to be used to create baselines for " + echo "our automated nm-test-accuracy workflow" + echo + echo "usage: ${0} " + echo + echo " -m - huggingface stub or local directory of the model" + echo " -b - batch size to run the evaluation at" + echo " -l - limit number of samples to run" + echo " -f - number of fewshot samples to use" + echo +} + +while getopts "m:b:l:f:" OPT; do + case ${OPT} in + m ) + MODEL="$OPTARG" + ;; + b ) + BATCH_SIZE="$OPTARG" + ;; + l ) + LIMIT="$OPTARG" + ;; + f ) + FEWSHOT="$OPTARG" + ;; + \? ) + usage + exit 1 + ;; + esac +done + +lm_eval --model hf \ + --model_args pretrained=$MODEL,parallelize=True \ + --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ + --batch_size $BATCH_SIZE diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh new file mode 100644 index 0000000000000..a2876bade8893 --- /dev/null +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# We can use this script to compute baseline accuracy on GSM for vllm. +# We use this for fp8, which HF does not support. +# +# Make sure you have lm-eval-harness installed: +# pip install lm-eval==0.4.2 + +usage() { + echo`` + echo "Runs lm eval harness on GSM8k using huggingface transformers." 
+ echo "This pathway is intended to be used to create baselines for " + echo "our automated nm-test-accuracy workflow" + echo + echo "usage: ${0} " + echo + echo " -m - huggingface stub or local directory of the model" + echo " -b - batch size to run the evaluation at" + echo " -l - limit number of samples to run" + echo " -f - number of fewshot samples to use" + echo " -t - tensor parallel size to run at" + echo +} + +while getopts "m:b:l:f:t:" OPT; do + case ${OPT} in + m ) + MODEL="$OPTARG" + ;; + b ) + BATCH_SIZE="$OPTARG" + ;; + l ) + LIMIT="$OPTARG" + ;; + f ) + FEWSHOT="$OPTARG" + ;; + t ) + TP_SIZE="$OPTARG" + ;; + \? ) + usage + exit 1 + ;; + esac +done + +lm_eval --model vllm \ + --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE \ + --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ + --batch_size $BATCH_SIZE diff --git a/.buildkite/lm-eval-harness/run-tests.sh b/.buildkite/lm-eval-harness/run-tests.sh new file mode 100644 index 0000000000000..b4fdde6dab425 --- /dev/null +++ b/.buildkite/lm-eval-harness/run-tests.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +usage() { + echo`` + echo "Runs lm eval harness on GSM8k using vllm and compares to " + echo "precomputed baseline (measured by HF transformers.)" + echo + echo "usage: ${0} " + echo + echo " -c - path to the test data config (e.g. configs/small-models.txt)" + echo " -t - tensor parallel size" + echo +} + +SUCCESS=0 + +while getopts "c:t:" OPT; do + case ${OPT} in + c ) + CONFIG="$OPTARG" + ;; + t ) + TP_SIZE="$OPTARG" + ;; + \? ) + usage + exit 1 + ;; + esac +done + +# Parse list of configs. +IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG + +for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" +do + LOCAL_SUCCESS=0 + + echo "=== RUNNING MODEL: $MODEL_CONFIG WITH TP SIZE: $TP_SIZE===" + + export LM_EVAL_TEST_DATA_FILE=$PWD/configs/${MODEL_CONFIG} + export LM_EVAL_TP_SIZE=$TP_SIZE + pytest -s test_lm_eval_correctness.py || LOCAL_SUCCESS=$? + + if [[ $LOCAL_SUCCESS == 0 ]]; then + echo "=== PASSED MODEL: ${MODEL_CONFIG} ===" + else + echo "=== FAILED MODEL: ${MODEL_CONFIG} ===" + fi + + SUCCESS=$((SUCCESS + LOCAL_SUCCESS)) + +done + +if [ "${SUCCESS}" -eq "0" ]; then + exit 0 +else + exit 1 +fi diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py new file mode 100644 index 0000000000000..975841dad1c29 --- /dev/null +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -0,0 +1,54 @@ +""" +LM eval harness on model to compare vs HF baseline computed offline. +Configs are found in configs/$MODEL.yaml + +* export LM_EVAL_TEST_DATA_FILE=configs/Meta-Llama-3-70B-Instruct.yaml +* export LM_EVAL_TP_SIZE=4 +* pytest -s test_lm_eval_correctness.py +""" + +import os +from pathlib import Path + +import lm_eval +import numpy +import yaml + +RTOL = 0.02 +TEST_DATA_FILE = os.environ.get( + "LM_EVAL_TEST_DATA_FILE", + ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml") + +TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1) + + +def launch_lm_eval(eval_config): + model_args = f"pretrained={eval_config['model_name']}," \ + f"tensor_parallel_size={TP_SIZE}" + + results = lm_eval.simple_evaluate( + model="vllm", + model_args=model_args, + tasks=[task["name"] for task in eval_config["tasks"]], + num_fewshot=eval_config["num_fewshot"], + limit=eval_config["limit"], + batch_size="auto") + + return results + + +def test_lm_eval_correctness(): + eval_config = yaml.safe_load( + Path(TEST_DATA_FILE).read_text(encoding="utf-8")) + + # Launch eval requests. 
+ results = launch_lm_eval(eval_config) + + # Confirm scores match ground truth. + for task in eval_config["tasks"]: + for metric in task["metrics"]: + ground_truth = metric["value"] + measured_value = results["results"][task["name"]][metric["name"]] + print(f'{task["name"]} | {metric["name"]}: ' + f'ground_truth={ground_truth} | measured={measured_value}') + assert numpy.isclose(ground_truth, measured_value, rtol=RTOL) diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md new file mode 100644 index 0000000000000..4036b32a46bf7 --- /dev/null +++ b/.buildkite/nightly-benchmarks/README.md @@ -0,0 +1,103 @@ +# vLLM benchmark suite + +## Introduction + +This directory contains the performance benchmarking CI for vLLM. +The goal is to help developers understand the impact of their PRs on the performance of vLLM. + +This benchmark will be *triggered* upon: +- A PR being merged into vLLM. +- Every commit of PRs that carry the `perf-benchmarks` label. + +**Benchmarking Coverage**: latency, throughput and fixed-QPS serving on A100 (support for more GPUs is coming later), with different models. + +**Benchmarking Duration**: about 1hr. + +**For benchmarking developers**: please try to keep the benchmarking duration below 1.5 hr so that it won't take forever to run. + + +## Configuring the workload + +The benchmarking workload contains three parts: +- Latency tests in `latency-tests.json`. +- Throughput tests in `throughput-tests.json`. +- Serving tests in `serving-tests.json`. + +See [descriptions.md](tests/descriptions.md) for detailed descriptions. + +### Latency test + +Here is an example of one test inside `latency-tests.json`: + +```json +[ + { + "test_name": "latency_llama8B_tp1", + "parameters": { + "model": "meta-llama/Meta-Llama-3-8B", + "tensor_parallel_size": 1, + "load_format": "dummy", + "num_iters_warmup": 5, + "num_iters": 15 + } + }, +] +``` + +In this example: +- The `test_name` attribute is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. +- The `parameters` attribute controls the command line arguments used for `benchmark_latency.py`. Note that you should use an underscore `_` instead of a dash `-` when specifying the command line arguments; `run-benchmarks-suite.sh` converts the underscores to dashes when feeding the arguments to `benchmark_latency.py` (see the short conversion sketch at the end of this document). For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`. + +Note that the performance numbers are highly sensitive to the values of the parameters. Please make sure the parameters are set correctly. + +WARNING: The benchmarking script saves json results by itself, so please do not configure the `--output-json` parameter in the json file. + + +### Throughput test +The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except that the parameters are fed to `benchmark_throughput.py`. + +The numbers from this test are also fairly stable, but keep in mind that a slight change in the parameter values can vary the performance numbers by a lot. + +### Serving test +We test the serving throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead.
The corresponding parameters are in `serving-tests.json`, and here is an example: + +```json +[ + { + "test_name": "serving_llama8B_tp1_sharegpt", + "qps_list": [1, 4, 16, "inf"], + "server_parameters": { + "model": "meta-llama/Meta-Llama-3-8B", + "tensor_parallel_size": 1, + "swap_space": 16, + "disable_log_stats": "", + "disable_log_requests": "", + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Meta-Llama-3-8B", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, +] +``` + +Inside this example: +- The `test_name` attribute is also a unique identifier for the test. It must start with `serving_`. +- The `server_parameters` attribute includes the command line arguments for the vLLM server. +- The `client_parameters` attribute includes the command line arguments for `benchmark_serving.py`. +- The `qps_list` attribute controls the list of QPS values to test. It is used to configure the `--request-rate` parameter of `benchmark_serving.py`. + +The numbers from this test are less stable than those of the latency and throughput benchmarks (due to the randomized ShareGPT dataset sampling inside `benchmark_serving.py`), but a large change (e.g. a 5% change) still indicates a real difference in performance. + +WARNING: The benchmarking script saves json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`. + +## Visualizing the results +The `convert-results-json-to-markdown.py` script puts the benchmarking results into markdown tables by formatting [descriptions.md](tests/descriptions.md) with the real benchmarking results. +You can find the result presented as a table on the `buildkite/performance-benchmark` job page. +If you do not see the table, please wait until the benchmark finishes running. +The json version of the tables (together with the json version of the benchmark results) will also be attached to the markdown file. +The raw benchmarking results (as json files) are in the `Artifacts` tab of the benchmarking job.
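## Appendix: parameter name conversion

A minimal Python sketch of the underscore-to-dash conversion mentioned in the latency test section above. This is an illustration only -- the actual conversion is performed by the jq-based `json2args` helper inside `run-benchmarks-suite.sh` -- and the helper function name below is made up for this example.

```python
# Illustration only: mimics what the jq-based json2args helper in
# run-benchmarks-suite.sh does. Keys in the test json use underscores;
# the benchmark scripts expect dashed command line flags.
import json


def params_to_cli_args(params: dict) -> str:
    """Turn {"tensor_parallel_size": 1} into "--tensor-parallel-size 1"."""
    return " ".join(
        # Empty string values (e.g. "disable_log_stats": "") become bare flags.
        f"--{key.replace('_', '-')} {value}".rstrip()
        for key, value in params.items())


example = json.loads("""
{
    "model": "meta-llama/Meta-Llama-3-8B",
    "tensor_parallel_size": 1,
    "load_format": "dummy",
    "num_iters_warmup": 5,
    "num_iters": 15
}
""")
print(params_to_cli_args(example))
# --model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15
```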
diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml new file mode 100644 index 0000000000000..2b25c954b5c5c --- /dev/null +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -0,0 +1,62 @@ +steps: + - label: "Wait for container to be ready" + agents: + queue: A100 + plugins: + - kubernetes: + podSpec: + containers: + - image: badouralix/curl-jq + command: + - sh + - .buildkite/nightly-benchmarks/scripts/wait-for-image.sh + - wait + - label: "A100 Benchmark" + agents: + queue: A100 + plugins: + - kubernetes: + podSpec: + priorityClassName: perf-benchmark + containers: + - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + command: + - bash .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + resources: + limits: + nvidia.com/gpu: 8 + volumeMounts: + - name: devshm + mountPath: /dev/shm + env: + - name: VLLM_USAGE_SOURCE + value: ci-test + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB + volumes: + - name: devshm + emptyDir: + medium: Memory + # - label: "H100: NVIDIA SMI" + # agents: + # queue: H100 + # plugins: + # - docker#v5.11.0: + # image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + # command: + # - bash + # - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + # mount-buildkite-agent: true + # propagate-environment: true + # propagate-uid-gid: false + # ipc: host + # gpus: all + # environment: + # - VLLM_USAGE_SOURCE + # - HF_TOKEN + diff --git a/.buildkite/nightly-benchmarks/kickoff-pipeline.sh b/.buildkite/nightly-benchmarks/kickoff-pipeline.sh new file mode 100755 index 0000000000000..15d411febcee1 --- /dev/null +++ b/.buildkite/nightly-benchmarks/kickoff-pipeline.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# NOTE(simon): this script runs inside a buildkite agent with CPU only access. +set -euo pipefail + +# Install system packages +apt update +apt install -y curl jq + +# Install minijinja for templating +curl -sSfL https://github.com/mitsuhiko/minijinja/releases/latest/download/minijinja-cli-installer.sh | sh +source $HOME/.cargo/env + +# If BUILDKITE_PULL_REQUEST != "false", then we check the PR labels using curl and jq +if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then + PR_LABELS=$(curl -s "https://api.github.com/repos/vllm-project/vllm/pulls/$BUILDKITE_PULL_REQUEST" | jq -r '.labels[].name') + + if [[ $PR_LABELS == *"perf-benchmarks"* ]]; then + echo "This PR has the 'perf-benchmarks' label. Proceeding with the nightly benchmarks." + else + echo "This PR does not have the 'perf-benchmarks' label. Skipping the nightly benchmarks." + exit 0 + fi +fi + +# Upload sample.yaml +buildkite-agent pipeline upload .buildkite/nightly-benchmarks/benchmark-pipeline.yaml diff --git a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh b/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh new file mode 100644 index 0000000000000..021473f76d0e5 --- /dev/null +++ b/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh @@ -0,0 +1,358 @@ +#!/bin/bash + +# This script should be run inside the CI process +# This script assumes that we are already inside the vllm/ directory +# Benchmarking results will be available inside vllm/benchmarks/results/ + +# Do not set -e, as the mixtral 8x22B model tends to crash occasionally +# and we still want to see other benchmarking results even when mixtral crashes. 
+set -o pipefail + +check_gpus() { + # check the number of GPUs and GPU type. + declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) + if [[ $gpu_count -gt 0 ]]; then + echo "GPU found." + else + echo "Need at least 1 GPU to run benchmarking." + exit 1 + fi + declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') + echo "GPU type is $gpu_type" +} + +check_hf_token() { + # check if HF_TOKEN is available and valid + if [[ -z "$HF_TOKEN" ]]; then + echo "Error: HF_TOKEN is not set." + exit 1 + elif [[ ! "$HF_TOKEN" =~ ^hf_ ]]; then + echo "Error: HF_TOKEN does not start with 'hf_'." + exit 1 + else + echo "HF_TOKEN is set and valid." + fi +} + +json2args() { + # transforms the JSON string to command line args, and '_' is replaced to '-' + # example: + # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } + # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 + local json_string=$1 + local args=$( + echo "$json_string" | jq -r ' + to_entries | + map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | + join(" ") + ' + ) + echo "$args" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + timeout 1200 bash -c ' + until curl localhost:8000/v1/completions; do + sleep 1 + done' && return 0 || return 1 +} + +kill_gpu_processes() { + # kill all processes on GPU. + pids=$(nvidia-smi --query-compute-apps=pid --format=csv,noheader) + if [ -z "$pids" ]; then + echo "No GPU processes found." + else + for pid in $pids; do + kill -9 "$pid" + echo "Killed process with PID: $pid" + done + + echo "All GPU processes have been killed." + fi + + # waiting for GPU processes to be fully killed + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. + gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. + echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +upload_to_buildkite() { + # upload the benchmarking results to buildkite + + # if the agent binary is not found, skip uploading the results, exit 0 + if [ ! -f /workspace/buildkite-agent ]; then + echo "buildkite-agent binary not found. Skip uploading the results." + return 0 + fi + /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < $RESULTS_FOLDER/benchmark_results.md + /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" +} + +run_latency_tests() { + # run latency tests using `benchmark_latency.py` + # $1: a json file specifying latency test cases + + local latency_test_file + latency_test_file=$1 + + # Iterate over latency tests + jq -c '.[]' "$latency_test_file" | while read -r params; do + # get the test name, and append the GPU type back to it. + test_name=$(echo "$params" | jq -r '.test_name') + if [[ ! "$test_name" =~ ^latency_ ]]; then + echo "In latency-test.json, test_name must start with \"latency_\"." + exit 1 + fi + + # if TEST_SELECTOR is set, only run the test cases that match the selector + if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then + echo "Skip test case $test_name." 
+ continue + fi + + # get arguments + latency_params=$(echo "$params" | jq -r '.parameters') + latency_args=$(json2args "$latency_params") + + # check if there is enough GPU to run the test + tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size') + if [[ $gpu_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname." + continue + fi + + latency_command="python3 benchmark_latency.py \ + --output-json $RESULTS_FOLDER/${test_name}.json \ + $latency_args" + + echo "Running test case $test_name" + echo "Latency command: $latency_command" + + # recoding benchmarking command ang GPU command + jq_output=$(jq -n \ + --arg latency "$latency_command" \ + --arg gpu "$gpu_type" \ + '{ + latency_command: $latency, + gpu_type: $gpu + }') + echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" + + # run the benchmark + eval "$latency_command" + + kill_gpu_processes + + done +} + + +run_throughput_tests() { + # run throughput tests using `benchmark_throughput.py` + # $1: a json file specifying throughput test cases + + local throughput_test_file + throughput_test_file=$1 + + # Iterate over throughput tests + jq -c '.[]' "$throughput_test_file" | while read -r params; do + # get the test name, and append the GPU type back to it. + test_name=$(echo "$params" | jq -r '.test_name') + if [[ ! "$test_name" =~ ^throughput_ ]]; then + echo "In throughput-test.json, test_name must start with \"throughput_\"." + exit 1 + fi + + # if TEST_SELECTOR is set, only run the test cases that match the selector + if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then + echo "Skip test case $test_name." + continue + fi + + # get arguments + throughput_params=$(echo "$params" | jq -r '.parameters') + throughput_args=$(json2args "$throughput_params") + + # check if there is enough GPU to run the test + tp=$(echo $throughput_params | jq -r '.tensor_parallel_size') + if [[ $gpu_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname." + continue + fi + + throughput_command="python3 benchmark_throughput.py \ + --output-json $RESULTS_FOLDER/${test_name}.json \ + $throughput_args" + + echo "Running test case $test_name" + echo "Throughput command: $throughput_command" + # recoding benchmarking command ang GPU command + jq_output=$(jq -n \ + --arg command "$throughput_command" \ + --arg gpu "$gpu_type" \ + '{ + throughput_command: $command, + gpu_type: $gpu + }') + echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" + + # run the benchmark + eval "$throughput_command" + + kill_gpu_processes + + done +} + +run_serving_tests() { + # run serving tests using `benchmark_serving.py` + # $1: a json file specifying serving test cases + + local serving_test_file + serving_test_file=$1 + + # Iterate over serving tests + jq -c '.[]' "$serving_test_file" | while read -r params; do + # get the test name, and append the GPU type back to it. + test_name=$(echo "$params" | jq -r '.test_name') + if [[ ! "$test_name" =~ ^serving_ ]]; then + echo "In serving-test.json, test_name must start with \"serving_\"." + exit 1 + fi + + # if TEST_SELECTOR is set, only run the test cases that match the selector + if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then + echo "Skip test case $test_name." 
+ continue + fi + + + # get client and server arguments + server_params=$(echo "$params" | jq -r '.server_parameters') + client_params=$(echo "$params" | jq -r '.client_parameters') + server_args=$(json2args "$server_params") + client_args=$(json2args "$client_params") + qps_list=$(echo "$params" | jq -r '.qps_list') + qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') + echo "Running over qps list $qps_list" + + # check if there is enough GPU to run the test + tp=$(echo "$server_params" | jq -r '.tensor_parallel_size') + if [[ $gpu_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname." + continue + fi + + # check if server model and client model is aligned + server_model=$(echo "$server_params" | jq -r '.model') + client_model=$(echo "$client_params" | jq -r '.model') + if [[ $server_model != "$client_model" ]]; then + echo "Server model and client model must be the same. Skip testcase $testname." + continue + fi + + server_command="python3 \ + -m vllm.entrypoints.openai.api_server \ + $server_args" + + # run the server + echo "Running test case $test_name" + echo "Server command: $server_command" + eval "$server_command" & + + # wait until the server is alive + wait_for_server + if [ $? -eq 0 ]; then + echo "" + echo "vllm server is up and running." + else + echo "" + echo "vllm failed to start within the timeout period." + fi + + # iterate over different QPS + for qps in $qps_list; do + # remove the surrounding single quote from qps + if [[ "$qps" == *"inf"* ]]; then + echo "qps was $qps" + qps="inf" + echo "now qps is $qps" + fi + + new_test_name=$test_name"_qps_"$qps + + client_command="python3 benchmark_serving.py \ + --save-result \ + --result-dir $RESULTS_FOLDER \ + --result-filename ${new_test_name}.json \ + --request-rate $qps \ + $client_args" + + echo "Running test case $test_name with qps $qps" + echo "Client command: $client_command" + + eval "$client_command" + + # record the benchmarking commands + jq_output=$(jq -n \ + --arg server "$server_command" \ + --arg client "$client_command" \ + --arg gpu "$gpu_type" \ + '{ + server_command: $server, + client_command: $client, + gpu_type: $gpu + }') + echo "$jq_output" > "$RESULTS_FOLDER/${new_test_name}.commands" + + done + + # clean up + kill_gpu_processes + done +} + +main() { + check_gpus + check_hf_token + + # dependencies + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get update && apt-get -y install jq) + + # get the current IP address, required by benchmark_serving.py + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + # turn of the reporting of the status of each request, to clean up the terminal output + export VLLM_LOG_LEVEL="WARNING" + + # prepare for benchmarking + cd benchmarks || exit 1 + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + declare -g RESULTS_FOLDER=results/ + mkdir -p $RESULTS_FOLDER + QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ + + # benchmarking + run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json + run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json + run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json + + + # postprocess benchmarking results + pip install tabulate pandas + python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py + + upload_to_buildkite +} + +main "$@" diff --git 
a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py new file mode 100644 index 0000000000000..534ecf17930e9 --- /dev/null +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -0,0 +1,192 @@ +import json +import os +from pathlib import Path + +import pandas as pd +from tabulate import tabulate + +results_folder = Path("results/") + +# latency results and the keys that will be printed into markdown +latency_results = [] +latency_column_mapping = { + "test_name": "Test name", + "gpu_type": "GPU", + "avg_latency": "Mean latency (ms)", + # "P10": "P10 (s)", + # "P25": "P25 (s)", + "P50": "Median latency (ms)", + # "P75": "P75 (s)", + # "P90": "P90 (s)", + "P99": "P99 latency (ms)", +} + +# throughput tests and the keys that will be printed into markdown +throughput_results = [] +throughput_results_column_mapping = { + "test_name": "Test name", + "gpu_type": "GPU", + # "num_requests": "# of req.", + # "total_num_tokens": "Total # of tokens", + # "elapsed_time": "Elapsed time (s)", + "requests_per_second": "Tput (req/s)", + # "tokens_per_second": "Tput (tok/s)", +} + +# serving results and the keys that will be printed into markdown +serving_results = [] +serving_column_mapping = { + "test_name": "Test name", + "gpu_type": "GPU", + # "completed": "# of req.", + "request_throughput": "Tput (req/s)", + # "input_throughput": "Input Tput (tok/s)", + # "output_throughput": "Output Tput (tok/s)", + "mean_ttft_ms": "Mean TTFT (ms)", + "median_ttft_ms": "Median TTFT (ms)", + "p99_ttft_ms": "P99 TTFT (ms)", + # "mean_tpot_ms": "Mean TPOT (ms)", + # "median_tpot_ms": "Median", + # "p99_tpot_ms": "P99", + "mean_itl_ms": "Mean ITL (ms)", + "median_itl_ms": "Median ITL (ms)", + "p99_itl_ms": "P99 ITL (ms)", +} + + +def read_markdown(file): + if os.path.exists(file): + with open(file, "r") as f: + return f.read() + "\n" + else: + return f"{file} not found.\n" + + +def results_to_json(latency, throughput, serving): + return json.dumps({ + 'latency': latency.to_dict(), + 'throughput': throughput.to_dict(), + 'serving': serving.to_dict() + }) + + +if __name__ == "__main__": + + # collect results + for test_file in results_folder.glob("*.json"): + + with open(test_file, "r") as f: + raw_result = json.loads(f.read()) + + if "serving" in str(test_file): + # this result is generated via `benchmark_serving.py` + + # attach the benchmarking command to raw_result + with open(test_file.with_suffix(".commands"), "r") as f: + command = json.loads(f.read()) + raw_result.update(command) + + # update the test name of this result + raw_result.update({"test_name": test_file.stem}) + + # add the result to raw_result + serving_results.append(raw_result) + continue + + elif "latency" in f.name: + # this result is generated via `benchmark_latency.py` + + # attach the benchmarking command to raw_result + with open(test_file.with_suffix(".commands"), "r") as f: + command = json.loads(f.read()) + raw_result.update(command) + + # update the test name of this result + raw_result.update({"test_name": test_file.stem}) + + # get different percentiles + for perc in [10, 25, 50, 75, 90, 99]: + # Multiply 1000 to convert the time unit from s to ms + raw_result.update( + {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]}) + raw_result["avg_latency"] = raw_result["avg_latency"] * 1000 + + # add the result to raw_result + latency_results.append(raw_result) + continue + + elif "throughput" in f.name: + # this 
result is generated via `benchmark_throughput.py` + + # attach the benchmarking command to raw_result + with open(test_file.with_suffix(".commands"), "r") as f: + command = json.loads(f.read()) + raw_result.update(command) + + # update the test name of this result + raw_result.update({"test_name": test_file.stem}) + + # add the result to raw_result + throughput_results.append(raw_result) + continue + + print(f"Skipping {test_file}") + + latency_results = pd.DataFrame.from_dict(latency_results) + serving_results = pd.DataFrame.from_dict(serving_results) + throughput_results = pd.DataFrame.from_dict(throughput_results) + + raw_results_json = results_to_json(latency_results, throughput_results, + serving_results) + + # remapping the key, for visualization purpose + if not latency_results.empty: + latency_results = latency_results[list( + latency_column_mapping.keys())].rename( + columns=latency_column_mapping) + if not serving_results.empty: + serving_results = serving_results[list( + serving_column_mapping.keys())].rename( + columns=serving_column_mapping) + if not throughput_results.empty: + throughput_results = throughput_results[list( + throughput_results_column_mapping.keys())].rename( + columns=throughput_results_column_mapping) + + processed_results_json = results_to_json(latency_results, + throughput_results, + serving_results) + + # get markdown tables + latency_md_table = tabulate(latency_results, + headers='keys', + tablefmt='pipe', + showindex=False) + serving_md_table = tabulate(serving_results, + headers='keys', + tablefmt='pipe', + showindex=False) + throughput_md_table = tabulate(throughput_results, + headers='keys', + tablefmt='pipe', + showindex=False) + + # document the result + with open(results_folder / "benchmark_results.md", "w") as f: + + results = read_markdown( + "../.buildkite/nightly-benchmarks/tests/descriptions.md") + results = results.format( + latency_tests_markdown_table=latency_md_table, + throughput_tests_markdown_table=throughput_md_table, + serving_tests_markdown_table=serving_md_table, + benchmarking_results_in_json_string=processed_results_json) + f.write(results) + + # document benchmarking results in json + with open(results_folder / "benchmark_results.json", "w") as f: + + results = latency_results.to_dict( + orient='records') + throughput_results.to_dict( + orient='records') + serving_results.to_dict(orient='records') + f.write(json.dumps(results)) diff --git a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh new file mode 100644 index 0000000000000..c785e6a0da628 --- /dev/null +++ b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh @@ -0,0 +1,17 @@ +#!/bin/sh +TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-test-repo:pull" | jq -r .token) +URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT" + +retries=0 +while [ $retries -lt 1000 ]; do + if [ $(curl -s -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then + exit 0 + fi + + echo "Waiting for image to be available..." 
+ + retries=$((retries + 1)) + sleep 5 +done + +exit 1 \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/tests/descriptions.md b/.buildkite/nightly-benchmarks/tests/descriptions.md new file mode 100644 index 0000000000000..891e4917070d9 --- /dev/null +++ b/.buildkite/nightly-benchmarks/tests/descriptions.md @@ -0,0 +1,67 @@ + +## Latency tests + +This test suite aims to test vllm's end-to-end latency under a controlled setup. + +- Input length: 32 tokens. +- Output length: 128 tokens. +- Batch size: fixed (8). +- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. +- Evaluation metrics: end-to-end latency (mean, median, p99). + +### Latency benchmarking results + +{latency_tests_markdown_table} + +## Throughput tests + +This test suite aims to test vllm's throughput. + +- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). +- Output length: the corresponding output length of these 200 prompts. +- Batch size: dynamically determined by vllm to achieve maximum throughput. +- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. +- Evaluation metrics: throughput. + +### Throughput benchmarking results + +{throughput_tests_markdown_table} + +## Serving tests + +This test suite aims to test vllm's real serving metrics. + +- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). +- Output length: the corresponding output length of these 200 prompts. +- Batch size: dynamically determined by vllm and the arrival pattern of the requests. +- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed). +- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. +- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). + +### Serving benchmarking results + +{serving_tests_markdown_table} + +## json version of the benchmarking tables + +This section contains the data of the markdown tables above in JSON format. +You can load the benchmarking tables into pandas dataframes as follows: + +```python +import json +import pandas as pd + +benchmarking_results_json = """The json string""" +benchmarking_results = json.loads(benchmarking_results_json) +latency_results = pd.DataFrame.from_dict(benchmarking_results["latency"]) +throughput_results = pd.DataFrame.from_dict(benchmarking_results["throughput"]) +serving_results = pd.DataFrame.from_dict(benchmarking_results["serving"]) +``` + +The json string for all benchmarking tables: +```json +{benchmarking_results_in_json_string} +``` + +You can also check the raw experiment data in the Artifact tab of the Buildkite page. 
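As a small follow-up to the loading snippet above, here is a sketch (assumption: the column names are the remapped display headers produced by `convert-results-json-to-markdown.py`, such as "Test name" and "Mean TTFT (ms)") of pulling a single serving metric out of the loaded table:

```python
# Sketch: inspect one serving metric after running the loading snippet above.
# "Test name" and "Mean TTFT (ms)" are the remapped display headers written
# by convert-results-json-to-markdown.py; adjust if the mapping changes.
ttft = serving_results.set_index("Test name")["Mean TTFT (ms)"]
print(ttft.sort_values())
```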
+ diff --git a/.buildkite/nightly-benchmarks/tests/latency-tests.json b/.buildkite/nightly-benchmarks/tests/latency-tests.json new file mode 100644 index 0000000000000..06488cd79110a --- /dev/null +++ b/.buildkite/nightly-benchmarks/tests/latency-tests.json @@ -0,0 +1,32 @@ +[ + { + "test_name": "latency_llama8B_tp1", + "parameters": { + "model": "meta-llama/Meta-Llama-3-8B", + "tensor_parallel_size": 1, + "load_format": "dummy", + "num_iters_warmup": 5, + "num_iters": 15 + } + }, + { + "test_name": "latency_llama70B_tp4", + "parameters": { + "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "tensor_parallel_size": 4, + "load_format": "dummy", + "num-iters-warmup": 5, + "num-iters": 15 + } + }, + { + "test_name": "latency_mixtral8x7B_tp2", + "parameters": { + "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "tensor_parallel_size": 2, + "load_format": "dummy", + "num-iters-warmup": 5, + "num-iters": 15 + } + } +] \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests.json b/.buildkite/nightly-benchmarks/tests/serving-tests.json new file mode 100644 index 0000000000000..86a0fefa339f7 --- /dev/null +++ b/.buildkite/nightly-benchmarks/tests/serving-tests.json @@ -0,0 +1,59 @@ +[ + { + "test_name": "serving_llama8B_tp1_sharegpt", + "qps_list": [1, 4, 16, "inf"], + "server_parameters": { + "model": "meta-llama/Meta-Llama-3-8B", + "tensor_parallel_size": 1, + "swap_space": 16, + "disable_log_stats": "", + "disable_log_requests": "", + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Meta-Llama-3-8B", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama70B_tp4_sharegpt", + "qps_list": [1, 4, 16, "inf"], + "server_parameters": { + "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "tensor_parallel_size": 4, + "swap_space": 16, + "disable_log_stats": "", + "disable_log_requests": "", + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_mixtral8x7B_tp2_sharegpt", + "qps_list": [1, 4, 16, "inf"], + "server_parameters": { + "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "tensor_parallel_size": 2, + "swap_space": 16, + "disable_log_stats": "", + "disable_log_requests": "", + "load_format": "dummy" + }, + "client_parameters": { + "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + } +] \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/tests/throughput-tests.json b/.buildkite/nightly-benchmarks/tests/throughput-tests.json new file mode 100644 index 0000000000000..41ac135748704 --- /dev/null +++ b/.buildkite/nightly-benchmarks/tests/throughput-tests.json @@ -0,0 +1,35 @@ +[ + { + "test_name": "throughput_llama8B_tp1", + "parameters": { + "model": "meta-llama/Meta-Llama-3-8B", + "tensor_parallel_size": 1, + "load_format": "dummy", + "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200, + "backend": "vllm" + } + }, + { + "test_name": "throughput_llama70B_tp4", + "parameters": { + "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "tensor_parallel_size": 4, + "load_format": "dummy", + "dataset": 
"./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200, + "backend": "vllm" + } + }, + { + "test_name": "throughput_mixtral8x7B_tp2", + "parameters": { + "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "tensor_parallel_size": 2, + "load_format": "dummy", + "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200, + "backend": "vllm" + } + } +] \ No newline at end of file diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml new file mode 100644 index 0000000000000..1959f9752069f --- /dev/null +++ b/.buildkite/release-pipeline.yaml @@ -0,0 +1,21 @@ +steps: + - block: "Build wheels" + + - label: "Build wheel - Python {{matrix.python_version}}, CUDA {{matrix.cuda_version}}" + agents: + queue: cpu_queue + commands: + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --build-arg PYTHON_VERSION={{matrix.python_version}} --tag vllm-ci:build-image --target build --progress plain ." + - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image cp -r dist /artifacts_host" + - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/" + matrix: + setup: + cuda_version: + - "11.8.0" + - "12.1.0" + python_version: + - "3.8" + - "3.9" + - "3.10" + - "3.11" diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index ce508e4748aba..bde8ab6184d3c 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -1,10 +1,38 @@ -# This script build the ROCm docker image and runs test inside it. +# This script runs test inside the corresponding ROCm docker container. set -ex # Print ROCm version echo "--- ROCm info" rocminfo +# cleanup older docker images +cleanup_docker() { + # Get Docker's root directory + docker_root=$(docker info -f '{{.DockerRootDir}}') + if [ -z "$docker_root" ]; then + echo "Failed to determine Docker root directory." + exit 1 + fi + echo "Docker root directory: $docker_root" + # Check disk usage of the filesystem where Docker's root directory is located + disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//') + # Define the threshold + threshold=70 + if [ "$disk_usage" -gt "$threshold" ]; then + echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." + # Remove dangling images (those that are not tagged and not used by any container) + docker image prune -f + # Remove unused volumes + docker volume prune -f + echo "Docker images and volumes cleanup completed." + else + echo "Disk usage is below $threshold%. No cleanup needed." + fi +} + +# Call the cleanup docker function +cleanup_docker + echo "--- Resetting GPUs" echo "reset" > /opt/amdgpu/etc/gpu_state @@ -19,15 +47,16 @@ done echo "--- Building container" sha=$(git rev-parse --short HEAD) -container_name=rocm_${sha} +image_name=rocm_${sha} +container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo) docker build \ - -t ${container_name} \ + -t ${image_name} \ -f Dockerfile.rocm \ --progress plain \ . 
remove_docker_container() { - docker rm -f ${container_name} || docker image rm -f ${container_name} || true + docker rm -f ${container_name} || docker image rm -f ${image_name} || true } trap remove_docker_container EXIT @@ -39,6 +68,6 @@ docker run \ --rm \ -e HF_TOKEN \ --name ${container_name} \ - ${container_name} \ + ${image_name} \ /bin/bash -c "${@}" diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index 7fbad1c4bd950..cbf6dda677c53 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.." (which wget && which curl) || (apt-get update && apt-get install -y wget curl) # run python-based benchmarks and upload the result to buildkite -python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt +python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt bench_latency_exit_code=$? -python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt +python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt bench_throughput_exit_code=$? # run server-based benchmarks and upload the result to buildkite @@ -50,16 +50,16 @@ echo "### Serving Benchmarks" >> benchmark_results.md sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line echo "" >> benchmark_results.md echo '```' >> benchmark_results.md -tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines +tail -n 24 benchmark_serving.txt >> benchmark_results.md # last 24 lines echo '```' >> benchmark_results.md # if the agent binary is not found, skip uploading the results, exit 0 -if [ ! -f /workspace/buildkite-agent ]; then +if [ ! -f /usr/bin/buildkite-agent ]; then exit 0 fi # upload the results to buildkite -/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md +buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md # exit with the exit code of the benchmarks if [ $bench_latency_exit_code -ne 0 ]; then @@ -74,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then exit $bench_serving_exit_code fi -/workspace/buildkite-agent artifact upload openai-*.json +rm ShareGPT_V3_unfiltered_cleaned_split.json +buildkite-agent artifact upload "*.json" diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f187d1f181724..f4fa24be1f20f 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -4,11 +4,23 @@ set -ex # Try building the docker image docker build -t cpu-test -f Dockerfile.cpu . +docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu . 
# Setup cleanup -remove_docker_container() { docker rm -f cpu-test || true; } +remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; } trap remove_docker_container EXIT remove_docker_container -# Run the image and launch offline inference -docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py +# Run the image +docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test +docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test-avx2 cpu-test-avx2 + +# offline inference +docker exec cpu-test bash -c "python3 examples/offline_inference.py" +docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" + +# Run basic model test +docker exec cpu-test bash -c "cd tests; + pip install pytest Pillow protobuf + cd ../ + pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py" diff --git a/.buildkite/run-openvino-test.sh b/.buildkite/run-openvino-test.sh new file mode 100755 index 0000000000000..70e56596c4a86 --- /dev/null +++ b/.buildkite/run-openvino-test.sh @@ -0,0 +1,14 @@ +# This script build the OpenVINO docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t openvino-test -f Dockerfile.openvino . + +# Setup cleanup +remove_docker_container() { docker rm -f openvino-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and launch offline inference +docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh new file mode 100644 index 0000000000000..22a7e76937a76 --- /dev/null +++ b/.buildkite/run-xpu-test.sh @@ -0,0 +1,14 @@ +# This script build the CPU docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t xpu-test -f Dockerfile.xpu . + +# Setup cleanup +remove_docker_container() { docker rm -f xpu-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and launch offline inference +docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path xpu-test python3 examples/offline_inference.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index cee5e7e9d2a73..d96e3c6d192e2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -1,17 +1,23 @@ # In this file, you can add more tests to run either by adding a new step or # adding a new command to an existing step. See different options here for examples. -# This script will be feed into Jinja template in `test-template.j2` to generate -# the final pipeline yaml file. + +# This script will be feed into Jinja template in `test-template-aws.j2` at +# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 +# to generate the final pipeline yaml file. 
+ steps: - label: Regression Test + mirror_hardwares: [amd] command: pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional - label: AsyncEngine Test + #mirror_hardwares: [amd] command: pytest -v -s async_engine - label: Basic Correctness Test + mirror_hardwares: [amd] commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py @@ -21,68 +27,99 @@ steps: - label: Core Test mirror_hardwares: [amd] - command: pytest -v -s core + commands: + - pytest -v -s core + - pytest -v -s distributed/test_parallel_state.py - label: Distributed Comm Ops Test - command: pytest -v -s test_comm_ops.py - working_dir: "/vllm-workspace/tests/distributed" + #mirror_hardwares: [amd] + working_dir: "/vllm-workspace/tests" num_gpus: 2 + commands: + - pytest -v -s distributed/test_comm_ops.py + - pytest -v -s distributed/test_shm_broadcast.py -- label: Distributed Tests - working_dir: "/vllm-workspace/tests/distributed" - - num_gpus: 2 # only support 1 or 2 for now. +- label: Distributed Tests (2 GPUs) mirror_hardwares: [amd] - + working_dir: "/vllm-workspace/tests" + num_gpus: 2 commands: - - pytest -v -s test_pynccl_library.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py - -- label: Distributed Tests (Multiple Groups) - working_dir: "/vllm-workspace/tests/distributed" + - bash ../.buildkite/download-images.sh + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py + - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py + - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py + - pytest -v -s 
spec_decode/e2e/test_integration_dist_tp2.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py + +- label: Distributed Tests (4 GPUs) + #mirror_hardwares: [amd] + working_dir: "/vllm-workspace/tests" num_gpus: 4 commands: - - pytest -v -s test_pynccl.py + - pytest -v -s distributed/test_pynccl.py + # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. + # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py - label: Engine Test - #mirror_hardwares: [amd] + mirror_hardwares: [amd] command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test + mirror_hardwares: [amd] + commands: - # these tests have to be separated, because each one will allocate all posible GPU memory - - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - - pytest -v -s entrypoints/test_server_oot_registration.py + - pytest -v -s entrypoints/llm + - pytest -v -s entrypoints/openai - label: Examples Test working_dir: "/vllm-workspace/examples" mirror_hardwares: [amd] commands: # install aws cli for llava_example.py - - pip install awscli + # install tensorizer for tensorize_vllm_model.py + - pip install awscli tensorizer - python3 offline_inference.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 llava_example.py + - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + +- label: Inputs Test + #mirror_hardwares: [amd] + commands: + - bash ../.buildkite/download-images.sh + - pytest -v -s test_inputs.py + - pytest -v -s multimodal - label: Kernels Test %N + #mirror_hardwares: [amd] command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Models Test #mirror_hardwares: [amd] commands: - - bash ../.buildkite/download-images.sh - - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py + - pytest -v -s models -m \"not vlm\" -- label: Llava Test - #mirror_hardwares: [amd] +- label: Vision Language Models Test + mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - - pytest -v -s models/test_llava.py + - pytest -v -s models -m vlm - label: Prefix Caching Test mirror_hardwares: [amd] @@ -90,33 +127,63 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test + #mirror_hardwares: [amd] command: pytest -v -s samplers - label: LogitsProcessor Test mirror_hardwares: [amd] command: pytest -v -s test_logits_processor.py +- label: Utils Test + command: pytest -v -s test_utils.py + - label: Worker Test mirror_hardwares: [amd] command: pytest -v -s worker - label: Speculative decoding tests #mirror_hardwares: [amd] - command: pytest -v -s spec_decode + commands: + # See https://github.com/vllm-project/vllm/issues/5152 + - export VLLM_ATTENTION_BACKEND=XFORMERS + - pytest -v -s spec_decode - label: LoRA Test %N - 
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + #mirror_hardwares: [amd] + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 +- label: LoRA Long Context (Distributed) + #mirror_hardwares: [amd] + num_gpus: 4 + # This test runs llama 13B, so it is required to run on 4 GPUs. + commands: + # FIXIT: find out which code initialize cuda before running the test + # before the fix, we need to use spawn to test it + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s -x lora/test_long_context.py + - label: Tensorizer Test + #mirror_hardwares: [amd] command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader - label: Metrics Test + mirror_hardwares: [amd] command: pytest -v -s metrics - label: Quantization Test + #mirror_hardwares: [amd] command: pytest -v -s quantization +- label: Tracing Test + commands: + - "pip install \ + opentelemetry-sdk \ + opentelemetry-api \ + opentelemetry-exporter-otlp \ + opentelemetry-semantic-conventions-ai" + - pytest -v -s tracing + - label: Benchmarks working_dir: "/vllm-workspace/.buildkite" mirror_hardwares: [amd] @@ -124,9 +191,39 @@ steps: - pip install aiohttp - bash run-benchmarks.sh +- label: LM Eval Small Models + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + commands: + - pip install lm-eval + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - bash ./run-tests.sh -c configs/models-small.txt -t 1 + +- label: LM Eval Large Models + gpu: a100 + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + commands: + - pip install lm-eval + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - bash ./run-tests.sh -c configs/models-large.txt -t 4 + - label: Documentation Build working_dir: "/vllm-workspace/test_docs/docs" no_gpu: True commands: - pip install -r requirements-docs.txt - SPHINXOPTS=\"-W\" make html + +- label: Distributed Tests (A100) + gpu: a100 + num_gpus: 4 + commands: + # NOTE: don't test llama model here, it seems hf implementation is buggy + # see https://github.com/vllm-project/vllm/pull/5689 for details + - pytest -v -s distributed/test_custom_all_reduce.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - pytest -v -s -x lora/test_mixtral.py diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 deleted file mode 100644 index 174c756ae74a3..0000000000000 --- a/.buildkite/test-template.j2 +++ /dev/null @@ -1,94 +0,0 @@ -{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %} -{% set default_num_gpu = 1 %} -{% set default_working_dir = "/vllm-workspace/tests" %} - -steps: - - - label: ":docker: build image" - commands: - - "docker build --build-arg max_jobs=16 --tag 
{{ docker_image }} --target test --progress plain ." - - "docker push {{ docker_image }}" - env: - DOCKER_BUILDKIT: "1" - retry: - automatic: - - exit_status: -1 # Agent was lost - limit: 5 - - exit_status: -10 # Agent was lost - limit: 5 - - wait - - - group: "AMD Tests" - depends_on: ~ - steps: - {% for step in steps %} - {% if step.mirror_hardwares and "amd" in step.mirror_hardwares %} - - label: "AMD: {{ step.label }}" - agents: - queue: amd - command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}" - env: - DOCKER_BUILDKIT: "1" - {% endif %} - {% endfor %} - - - label: "Neuron Test" - depends_on: ~ - agents: - queue: neuron - command: bash .buildkite/run-neuron-test.sh - soft_fail: true - - - label: "Intel Test" - depends_on: ~ - command: bash .buildkite/run-cpu-test.sh - - {% for step in steps %} - - label: "{{ step.label }}" - agents: - queue: kubernetes - soft_fail: {{ step.soft_fail or false }} - {% if step.parallelism %} - parallelism: {{ step.parallelism }} - {% endif %} - retry: - automatic: - - exit_status: -1 # Agent was lost - limit: 5 - - exit_status: -10 # Agent was lost - limit: 5 - plugins: - - kubernetes: - podSpec: - {% if step.num_gpus %} - priorityClassName: gpu-priority-cls-{{ step.num_gpus }} - {% endif %} - volumes: - - name: dshm - emptyDir: - medium: Memory - containers: - - image: "{{ docker_image }}" - command: ["bash"] - args: - - '-c' - - "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'" - {% if not step.no_gpu %} - resources: - requests: - nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}" - limits: - nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}" - {% endif %} - env: - - name: VLLM_USAGE_SOURCE - value: ci-test - - name: HF_TOKEN - valueFrom: - secretKeyRef: - name: hf-token-secret - key: token - volumeMounts: - - mountPath: /dev/shm - name: dshm - {% endfor %} diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000..7f9e6d720fae5 --- /dev/null +++ b/.clang-format @@ -0,0 +1,26 @@ +BasedOnStyle: Google +UseTab: Never +IndentWidth: 2 +ColumnLimit: 80 + +# Force pointers to the type for C++. +DerivePointerAlignment: false +PointerAlignment: Left + +# Reordering #include statements can (and currently will) introduce errors +SortIncludes: false + +# Style choices +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +IndentPPDirectives: BeforeHash + +IncludeCategories: + - Regex: '^<' + Priority: 4 + - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/' + Priority: 3 + - Regex: '^"(qoda|\.\.)/' + Priority: 2 + - Regex: '.*' + Priority: 1 diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml index 08120ad8e5a60..ce980c3f4a01d 100644 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug report.yml @@ -59,6 +59,8 @@ body: Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues. + If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . 
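The debugging advice added to the bug-report template above can also be applied from Python. The sketch below is illustrative only (the model and prompt are placeholders, not part of this diff): it sets `VLLM_LOGGING_LEVEL` and `VLLM_TRACE_FUNCTION` before importing vLLM so that a minimal reproduction emits the extra logs the template asks for.

```python
# Illustrative only: set the debug-related environment variables mentioned in
# the bug-report template before importing vLLM, then run a tiny reproduction.
# The model and prompt below are placeholders, not taken from this diff.
import os

os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"   # more verbose logging
os.environ["VLLM_TRACE_FUNCTION"] = "1"      # record function calls for crash/hang triage

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")         # small placeholder model
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```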
All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. placeholder: | A clear and concise description of what the bug is. diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml new file mode 100644 index 0000000000000..e9b6e28fa6bcb --- /dev/null +++ b/.github/workflows/clang-format.yml @@ -0,0 +1,42 @@ +name: clang-format + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + clang-format: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install clang-format==18.1.5 + - name: Running clang-format + run: | + EXCLUDES=( + 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' + 'csrc/punica/bgmv/bgmv_config.h' + 'csrc/punica/bgmv/bgmv_impl.cuh' + 'csrc/punica/bgmv/vec_dtypes.cuh' + 'csrc/punica/punica_ops.cu' + 'csrc/punica/type_convert.h' + ) + find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ + | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ + | xargs clang-format --dry-run --Werror \ No newline at end of file diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index a20753d8a7702..62f0dbcd93eff 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -37,6 +37,7 @@ jobs: mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml + mypy vllm/multimodal --config-file pyproject.toml mypy vllm/usage --config-file pyproject.toml mypy vllm/*.py --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml @@ -46,5 +47,5 @@ jobs: mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml mypy vllm/logging --config-file pyproject.toml - mypy vllm/model_executor --config-file pyproject.toml + mypy tests --config-file pyproject.toml diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index e71033f828006..773def58fd966 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1 isort==5.13.2 + pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2 - name: Analysing the code with ruff run: | ruff . diff --git a/CMakeLists.txt b/CMakeLists.txt index f817f3382c5e1..ede9192cd1dbb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.21) project(vllm_extensions LANGUAGES CXX) -option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda") +# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... 
(used by setup.py) +set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") @@ -32,8 +33,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # versions are derived from Dockerfile.rocm # set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0") -set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1") -set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1") +set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0") # # Try to find python package with an executable that exactly matches @@ -66,19 +66,6 @@ endif() # find_package(Torch REQUIRED) -# -# Normally `torch.utils.cpp_extension.CUDAExtension` would add -# `libtorch_python.so` for linking against an extension. Torch's cmake -# configuration does not include this library (presumably since the cmake -# config is used for standalone C++ binaries that link against torch). -# The `libtorch_python.so` library defines some of the glue code between -# torch/python via pybind and is required by VLLM extensions for this -# reason. So, add it by manually with `find_library` using torch's -# installed library path. -# -find_library(torch_python_LIBRARY torch_python PATHS - "${TORCH_INSTALL_PREFIX}/lib") - # # Forward the non-CUDA device extensions to external CMake scripts. # @@ -111,18 +98,11 @@ elseif(HIP_FOUND) # .hip extension automatically, HIP must be enabled explicitly. enable_language(HIP) - # ROCm 5.x - if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND - NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X}) - message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} " - "expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.") - endif() - - # ROCm 6.x - if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND - NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X}) - message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} " - "expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.") + # ROCm 5.X and 6.X + if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND + NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM}) + message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} " + "expected for ROCm build, saw ${Torch_VERSION} instead.") endif() else() message(FATAL_ERROR "Can't find CUDA or HIP installation.") @@ -167,19 +147,47 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/fp8/fp8_cuda_kernels.cu" + "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" + "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" - "csrc/pybind.cpp") + "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") + include(FetchContent) + SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) + FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # CUTLASS 3.5.0 + GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc + ) + FetchContent_MakeAvailable(cutlass) + list(APPEND VLLM_EXT_SRC "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/marlin/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/custom_all_reduce.cu") + "csrc/custom_all_reduce.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + 
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + + # + # The CUTLASS kernels for Hopper require sm90a to be enabled. + # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. + # That adds an extra 17MB to compiled binary, so instead we selectively enable it. + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() + endif() define_gpu_extension_target( @@ -189,6 +197,8 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + USE_SABI 3 WITH_SOABI) # @@ -196,7 +206,7 @@ define_gpu_extension_target( # set(VLLM_MOE_EXT_SRC - "csrc/moe/moe_ops.cpp" + "csrc/moe/torch_bindings.cpp" "csrc/moe/topk_softmax_kernels.cu") define_gpu_extension_target( @@ -206,6 +216,7 @@ define_gpu_extension_target( SOURCES ${VLLM_MOE_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 WITH_SOABI) # @@ -219,7 +230,8 @@ set(VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/punica_ops.cc") + "csrc/punica/punica_ops.cu" + "csrc/punica/torch_bindings.cpp") # # Copy GPU compilation flags+update for punica @@ -243,6 +255,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA") endif() endforeach() message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") +elseif(${VLLM_GPU_LANG} STREQUAL "HIP") + set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) + message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") endif() if (VLLM_PUNICA_GPU_ARCHES) @@ -253,6 +268,7 @@ if (VLLM_PUNICA_GPU_ARCHES) SOURCES ${VLLM_PUNICA_EXT_SRC} COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} + USE_SABI 3 WITH_SOABI) else() message(WARNING "Unable to create _punica_C target because none of the " @@ -277,9 +293,7 @@ add_custom_target(default) if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) -endif() -if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Enabling moe extension.") add_dependencies(default _moe_C) diff --git a/Dockerfile b/Dockerfile index 90be3a30f89b1..d031d98c5b7e4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,18 +5,35 @@ # docs/source/dev/dockerfile/dockerfile.rst and # docs/source/assets/dev/dockerfile-stages-dependency.png +ARG CUDA_VERSION=12.4.1 #################### BASE BUILD IMAGE #################### # prepare basic build environment -FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base + +ARG CUDA_VERSION=12.4.1 +ARG PYTHON_VERSION=3 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ + && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ + && apt-get update -y \ + && apt-get install -y ccache software-properties-common \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update -y \ + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv python3-pip \ + && if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 
/usr/bin/python${PYTHON_VERSION} 1; fi \ + && python3 --version \ + && python3 -m pip --version RUN apt-get update -y \ - && apt-get install -y python3-pip git + && apt-get install -y python3-pip git curl sudo # Workaround for https://github.com/openai/triton/issues/2507 and # https://github.com/pytorch/pytorch/issues/107960 -- hopefully # this won't be needed for future versions of this docker image # or future versions of triton. -RUN ldconfig /usr/local/cuda-12.4/compat/ +RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ WORKDIR /workspace @@ -24,12 +41,7 @@ WORKDIR /workspace COPY requirements-common.txt requirements-common.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements-cuda.txt - -# install development dependencies -COPY requirements-dev.txt requirements-dev.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements-dev.txt + python3 -m pip install -r requirements-cuda.txt # cuda arch list used by torch # can be useful for both `dev` and `test` @@ -39,14 +51,16 @@ ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} #################### BASE BUILD IMAGE #################### - #################### WHEEL BUILD IMAGE #################### -FROM dev AS build +FROM base AS build + +ARG PYTHON_VERSION=3 # install build dependencies COPY requirements-build.txt requirements-build.txt + RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements-build.txt + python3 -m pip install -r requirements-build.txt # install compiler cache to speed up compilation leveraging local or remote caching RUN apt-get update -y && apt-get install -y ccache @@ -70,43 +84,50 @@ ENV NVCC_THREADS=$nvcc_threads # make sure punica kernels are built (for LoRA) ENV VLLM_INSTALL_PUNICA_KERNELS=1 +ARG USE_SCCACHE +# if USE_SCCACHE is set, use sccache to speed up compilation +RUN --mount=type=cache,target=/root/.cache/pip \ + if [ "$USE_SCCACHE" = "1" ]; then \ + echo "Installing sccache..." \ + && curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \ + && tar -xzf sccache.tar.gz \ + && sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \ + && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ + && export SCCACHE_BUCKET=vllm-build-sccache \ + && export SCCACHE_REGION=us-west-2 \ + && sccache --show-stats \ + && python3 setup.py bdist_wheel --dist-dir=dist \ + && sccache --show-stats; \ + fi + ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/pip \ - python3 setup.py bdist_wheel --dist-dir=dist + if [ "$USE_SCCACHE" != "1" ]; then \ + python3 setup.py bdist_wheel --dist-dir=dist; \ + fi # check the size of the wheel, we cannot upload wheels larger than 100MB COPY .buildkite/check-wheel-size.py check-wheel-size.py RUN python3 check-wheel-size.py dist -# the `vllm_nccl` package must be installed from source distribution -# pip is too smart to store a wheel in the cache, and other CI jobs -# will directly use the wheel from the cache, which is not what we want. 
-# we need to remove it manually -RUN --mount=type=cache,target=/root/.cache/pip \ - pip cache remove vllm_nccl* #################### EXTENSION Build IMAGE #################### -#################### FLASH_ATTENTION Build IMAGE #################### -FROM dev as flash-attn-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} -# flash attention version -ARG flash_attn_version=v2.5.8 -ENV FLASH_ATTN_VERSION=${flash_attn_version} - -WORKDIR /usr/src/flash-attention-v2 +#################### DEV IMAGE #################### +FROM base as dev -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ - --no-build-isolation --no-deps --no-cache-dir +COPY requirements-lint.txt requirements-lint.txt +COPY requirements-test.txt requirements-test.txt +COPY requirements-dev.txt requirements-dev.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install -r requirements-dev.txt -#################### FLASH_ATTENTION Build IMAGE #################### +#################### DEV IMAGE #################### #################### vLLM installation IMAGE #################### # image with vLLM installed -FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base +FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base +ARG CUDA_VERSION=12.4.1 WORKDIR /vllm-workspace RUN apt-get update -y \ @@ -116,16 +137,12 @@ RUN apt-get update -y \ # https://github.com/pytorch/pytorch/issues/107960 -- hopefully # this won't be needed for future versions of this docker image # or future versions of triton. -RUN ldconfig /usr/local/cuda-12.4/compat/ +RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ # install vllm wheel first, so that torch etc will be installed RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ - pip install dist/*.whl --verbose - -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - --mount=type=cache,target=/root/.cache/pip \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir + python3 -m pip install dist/*.whl --verbose #################### vLLM installation IMAGE #################### @@ -138,7 +155,7 @@ ADD . /vllm-workspace/ # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements-dev.txt + python3 -m pip install -r requirements-dev.txt # doc requires source code # we hide them inside `test_docs/` , so that this source code @@ -155,7 +172,7 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer modelscope + pip install accelerate hf_transfer 'modelscope!=1.15.0' ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 4251fddd6cc3b..6e55203decc56 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -1,13 +1,19 @@ # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform. 
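The Dockerfile above derives the CUDA compat directory from `CUDA_VERSION` with `cut -d. -f1,2` in both the build and runtime stages. As a hedged illustration of that trimming (not code from this diff), the same mapping in Python:

```python
# Illustrative sketch of the `cut -d. -f1,2` trimming used by the Dockerfile's
# `ldconfig /usr/local/cuda-.../compat/` lines: keep only MAJOR.MINOR.
def cuda_compat_dir(cuda_version: str) -> str:
    major_minor = ".".join(cuda_version.split(".")[:2])
    return f"/usr/local/cuda-{major_minor}/compat/"

assert cuda_compat_dir("12.4.1") == "/usr/local/cuda-12.4/compat/"
assert cuda_compat_dir("12.1") == "/usr/local/cuda-12.1/compat/"
```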
-FROM ubuntu:22.04 +FROM ubuntu:22.04 AS cpu-test-1 RUN apt-get update -y \ - && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ + && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 +RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc + +RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl + RUN pip install --upgrade pip \ - && pip install wheel packaging ninja setuptools>=49.4.0 numpy + && pip install wheel packaging ninja "setuptools>=49.4.0" numpy + +FROM cpu-test-1 AS build COPY ./ /workspace/vllm @@ -15,6 +21,14 @@ WORKDIR /workspace/vllm RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu +# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ... +ARG VLLM_CPU_DISABLE_AVX512 +ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} + RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +WORKDIR /workspace/ + +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks + CMD ["/bin/bash"] diff --git a/Dockerfile.neuron b/Dockerfile.neuron index fe42b4ef393f1..010f23a143010 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -28,7 +28,7 @@ COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt RUN cd /app/vllm \ && python3 -m pip install -U -r requirements-neuron.txt -ENV VLLM_BUILD_WITH_NEURON 1 +ENV VLLM_TARGET_DEVICE neuron RUN cd /app/vllm \ && pip install -e . \ && cd .. diff --git a/Dockerfile.openvino b/Dockerfile.openvino new file mode 100644 index 0000000000000..9861997b451a9 --- /dev/null +++ b/Dockerfile.openvino @@ -0,0 +1,26 @@ +# The vLLM Dockerfile is used to construct vLLM image that can be directly used +# to run the OpenAI compatible server. 
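Several of these Dockerfiles steer the build through environment variables such as `VLLM_TARGET_DEVICE` and `VLLM_CPU_DISABLE_AVX512`. The helper below is a hedged sketch of how such flags are typically read at build time; it is not taken from vLLM's setup.py.

```python
# Illustrative helper, not from this diff: read build-steering environment
# variables such as VLLM_TARGET_DEVICE and VLLM_CPU_DISABLE_AVX512.
import os

def env_flag(name: str, default: bool = False) -> bool:
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

target_device = os.getenv("VLLM_TARGET_DEVICE", "cuda")   # e.g. cpu, neuron, openvino, tpu, xpu
disable_avx512 = env_flag("VLLM_CPU_DISABLE_AVX512")
print(f"target={target_device} disable_avx512={disable_avx512}")
```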
+ +FROM ubuntu:22.04 AS dev + +RUN apt-get update -y && \ + apt-get install -y python3-pip git +WORKDIR /workspace + +# copy requirements +COPY requirements-build.txt /workspace/vllm/ +COPY requirements-common.txt /workspace/vllm/ +COPY requirements-openvino.txt /workspace/vllm/ + +COPY vllm/ /workspace/vllm/vllm +COPY setup.py /workspace/vllm/ + +# install build requirements +RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt +# build vLLM with OpenVINO backend +RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ + +COPY examples/ /workspace/vllm/examples +COPY benchmarks/ /workspace/vllm/benchmarks + +CMD ["/bin/bash"] diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le new file mode 100644 index 0000000000000..d4e4c483cada8 --- /dev/null +++ b/Dockerfile.ppc64le @@ -0,0 +1,22 @@ +FROM mambaorg/micromamba +ARG MAMBA_DOCKERFILE_ACTIVATE=1 +USER root + +RUN apt-get update -y && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +# Some packages in requirements-cpu are installed here +# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba +# Currently these may not be available for venv or pip directly +RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 pytorch-cpu=2.1.2 torchvision-cpu=0.16.2 && micromamba clean --all --yes + +COPY ./ /workspace/vllm + +WORKDIR /workspace/vllm + +# These packages will be in rocketce eventually +RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing + +RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install + +WORKDIR /vllm-workspace +ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/Dockerfile.rocm b/Dockerfile.rocm index d04bb9915e2ab..1b89b892bbf1c 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,35 +1,35 @@ -# default base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" - -FROM $BASE_IMAGE - -ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" - -RUN echo "Base image is $BASE_IMAGE" - -# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" -# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" - - +# Default ROCm 6.1 base image +ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" + +# Tested and supported base rocm/pytorch images +ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \ + ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \ + ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" + +# Default ROCm ARCHes to build vLLM for. +ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" + +# Whether to build CK-based flash-attention +# If 0, will not build flash attention +# This is useful for gfx target where flash-attention is not supported +# (i.e. those that do not appear in `FA_GFX_ARCHS`) +# Triton FA is used by default on ROCm now so this is unnecessary. 
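The comment above explains when the CK flash-attention build is worthwhile: only if `BUILD_FA=1` and the requested gfx targets all appear in `FA_GFX_ARCHS`. A small, hedged sketch of that decision (illustrative only; the Dockerfile expresses the same idea in shell):

```python
# Illustrative sketch of the BUILD_FA / FA_GFX_ARCHS gating described above.
# The default arch set mirrors ARG FA_GFX_ARCHS="gfx90a;gfx942" in this diff.
def should_build_fa(build_fa: str, gpu_archs: str,
                    fa_gfx_archs: str = "gfx90a;gfx942") -> bool:
    supported = {arch for arch in fa_gfx_archs.split(";") if arch}
    requested = {arch for arch in gpu_archs.split(";") if arch}
    return build_fa == "1" and requested <= supported

print(should_build_fa("1", "gfx90a;gfx942"))  # True: all targets supported
print(should_build_fa("1", "gfx1100"))        # False: CK FA not supported here
print(should_build_fa("0", "gfx90a"))         # False: build disabled
```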
+ARG BUILD_FA="1" ARG FA_GFX_ARCHS="gfx90a;gfx942" -RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" - ARG FA_BRANCH="ae7928c" -RUN echo "FA_BRANCH is $FA_BRANCH" -# whether to build flash-attention -# if 0, will not build flash attention -# this is useful for gfx target where flash-attention is not supported -# In that case, we need to use the python reference attention implementation in vllm -ARG BUILD_FA="1" - -# whether to build triton on rocm +# Whether to build triton on rocm ARG BUILD_TRITON="1" +ARG TRITON_BRANCH="0ef1848" -# Install some basic utilities -RUN apt-get update && apt-get install python3 python3-pip -y +### Base image build stage +FROM $BASE_IMAGE AS base + +# Import arg(s) defined before this build stage +ARG PYTORCH_ROCM_ARCH # Install some basic utilities +RUN apt-get update && apt-get install python3 python3-pip -y RUN apt-get update && apt-get install -y \ curl \ ca-certificates \ @@ -40,68 +40,165 @@ RUN apt-get update && apt-get install -y \ build-essential \ wget \ unzip \ - nvidia-cuda-toolkit \ tmux \ + ccache \ && rm -rf /var/lib/apt/lists/* -### Mount Point ### -# When launching the container, mount the code directory to /app +# When launching the container, mount the code directory to /vllm-workspace ARG APP_MOUNT=/vllm-workspace -VOLUME [ ${APP_MOUNT} ] WORKDIR ${APP_MOUNT} -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas +RUN pip install --upgrade pip +# Remove sccache so it doesn't interfere with ccache +# TODO: implement sccache support across components +RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)" +# Install torch == 2.4.0 on ROCm +RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ + *"rocm-5.7"*) \ + pip uninstall -y torch torchaudio torchvision \ + && pip install --no-cache-dir --pre \ + torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \ + torchvision==0.19.0.dev20240612 \ + --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \ + *"rocm-6.0"*) \ + pip uninstall -y torch torchaudio torchvision \ + && pip install --no-cache-dir --pre \ + torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \ + torchvision==0.19.0.dev20240612 \ + --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \ + *"rocm-6.1"*) \ + pip uninstall -y torch torchaudio torchvision \ + && pip install --no-cache-dir --pre \ + torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \ + torchvision==0.19.0.dev20240612 \ + --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ + *) ;; esac ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: -# Install ROCm flash-attention -RUN if [ "$BUILD_FA" = "1" ]; then \ - mkdir libs \ +ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} +ENV CCACHE_DIR=/root/.cache/ccache + + +### AMD-SMI build stage +FROM base AS build_amdsmi +# Build amdsmi wheel always +RUN cd /opt/rocm/share/amd_smi \ + && pip wheel . 
--wheel-dir=/install + + +### Flash-Attention wheel build stage +FROM base AS build_fa +ARG BUILD_FA +ARG FA_GFX_ARCHS +ARG FA_BRANCH +# Build ROCm flash-attention wheel if `BUILD_FA = 1` +RUN --mount=type=cache,target=${CCACHE_DIR} \ + if [ "$BUILD_FA" = "1" ]; then \ + mkdir -p libs \ && cd libs \ && git clone https://github.com/ROCm/flash-attention.git \ && cd flash-attention \ - && git checkout ${FA_BRANCH} \ + && git checkout "${FA_BRANCH}" \ && git submodule update --init \ - && export GPU_ARCHS=${FA_GFX_ARCHS} \ - && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \ - patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \ - && python3 setup.py install \ - && cd ..; \ + && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ + *"rocm-5.7"*) \ + export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \ + && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \ + *) ;; esac \ + && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ + # Create an empty directory otherwise as later build stages expect one + else mkdir -p /install; \ fi -# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. -# Manually removed it so that later steps of numpy upgrade can continue -RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ - rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi -# build triton -RUN if [ "$BUILD_TRITON" = "1" ]; then \ +### Triton wheel build stage +FROM base AS build_triton +ARG BUILD_TRITON +ARG TRITON_BRANCH +# Build triton wheel if `BUILD_TRITON = 1` +RUN --mount=type=cache,target=${CCACHE_DIR} \ + if [ "$BUILD_TRITON" = "1" ]; then \ mkdir -p libs \ && cd libs \ - && pip uninstall -y triton \ - && git clone https://github.com/ROCm/triton.git \ - && cd triton/python \ - && pip3 install . \ - && cd ../..; \ + && git clone https://github.com/OpenAI/triton.git \ + && cd triton \ + && git checkout "${TRITON_BRANCH}" \ + && cd python \ + && python3 setup.py bdist_wheel --dist-dir=/install; \ + # Create an empty directory otherwise as later build stages expect one + else mkdir -p /install; \ fi -WORKDIR /vllm-workspace + +### Final vLLM build stage +FROM base AS final +# Import the vLLM development directory from the build context COPY . . -RUN python3 -m pip install --upgrade pip numba +# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. 
+# Manually remove it so that later steps of numpy upgrade can continue +RUN case "$(which python3)" in \ + *"/opt/conda/envs/py_3.9"*) \ + rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \ + *) ;; esac +# Package upgrades for useful functionality or to avoid dependency issues RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --upgrade numba scipy huggingface-hub[cli] + +# Make sure punica kernels are built (for LoRA) +ENV VLLM_INSTALL_PUNICA_KERNELS=1 +# Workaround for ray >= 2.10.0 +ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 +# Silences the HF Tokenizers warning +ENV TOKENIZERS_PARALLELISM=false + +RUN --mount=type=cache,target=${CCACHE_DIR} \ + --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ - && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ - && python3 setup.py install \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ - && cd .. + && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ + *"rocm-6.0"*) \ + patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \ + *"rocm-6.1"*) \ + # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM + wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \ + && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \ + # Prevent interference if torch bundles its own HIP runtime + && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \ + *) ;; esac \ + && python3 setup.py clean --all \ + && python3 setup.py develop + +# Copy amdsmi wheel into final image +RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \ + mkdir -p libs \ + && cp /install/*.whl libs \ + # Preemptively uninstall to avoid same-version no-installs + && pip uninstall -y amdsmi; + +# Copy triton wheel(s) into final image if they were built +RUN --mount=type=bind,from=build_triton,src=/install,target=/install \ + mkdir -p libs \ + && if ls /install/*.whl; then \ + cp /install/*.whl libs \ + # Preemptively uninstall to avoid same-version no-installs + && pip uninstall -y triton; fi -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3 +# Copy flash-attn wheel(s) into final image if they were built +RUN --mount=type=bind,from=build_fa,src=/install,target=/install \ + mkdir -p libs \ + && if ls /install/*.whl; then \ + cp /install/*.whl libs \ + # Preemptively uninstall to avoid same-version no-installs + && pip uninstall -y flash-attn; fi + +# Install wheels that were built to the final image +RUN --mount=type=cache,target=/root/.cache/pip \ + if ls libs/*.whl; then \ + pip install libs/*.whl; fi CMD ["/bin/bash"] diff --git a/Dockerfile.tpu b/Dockerfile.tpu new file mode 100644 index 0000000000000..931c844c08dce --- /dev/null +++ b/Dockerfile.tpu @@ -0,0 +1,19 @@ +ARG NIGHTLY_DATE="20240601" +ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" + +FROM $BASE_IMAGE + +WORKDIR /workspace +COPY . /workspace/vllm + +ENV VLLM_TARGET_DEVICE="tpu" +# Install aiohttp separately to avoid build errors. +RUN pip install aiohttp +# Install the TPU and Pallas dependencies. 
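The ROCm stages above branch on the installed ROCm release via `ls /opt | grep -Po 'rocm-[0-9]\.[0-9]'`. The snippet below is a hedged Python equivalent of that probe, shown only to make the branching logic explicit; it is not part of this diff.

```python
# Illustrative Python equivalent of `ls /opt | grep -Po 'rocm-[0-9]\.[0-9]'`,
# which the ROCm Dockerfile uses to branch per ROCm release. Not vLLM code.
import re
from pathlib import Path
from typing import Optional

def detect_rocm_version(opt_dir: str = "/opt") -> Optional[str]:
    opt = Path(opt_dir)
    if not opt.is_dir():
        return None
    pattern = re.compile(r"rocm-(\d\.\d)")
    for name in sorted(p.name for p in opt.iterdir()):
        match = pattern.search(name)
        if match:
            return match.group(1)  # e.g. "5.7", "6.0", "6.1"
    return None

print(detect_rocm_version() or "no ROCm installation found under /opt")
```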
+RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html +RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + +# Build vLLM. +RUN cd /workspace/vllm && python setup.py develop + +CMD ["/bin/bash"] diff --git a/Dockerfile.xpu b/Dockerfile.xpu new file mode 100644 index 0000000000000..c39e551672d20 --- /dev/null +++ b/Dockerfile.xpu @@ -0,0 +1,22 @@ +FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04 + +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \ + echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ + chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \ + rm /etc/apt/sources.list.d/intel-graphics.list && \ + wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \ + echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ + chmod 644 /usr/share/keyrings/intel-graphics.gpg + +RUN apt-get update -y \ +&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip + +COPY ./ /workspace/vllm + +WORKDIR /workspace/vllm + +RUN pip install -v -r requirements-xpu.txt + +RUN VLLM_TARGET_DEVICE=xpu python3 setup.py install + +CMD ["/bin/bash"] diff --git a/README.md b/README.md index 9b180877a5a82..d6957a7f5ee3a 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,19 @@ Easy, fast, and cheap LLM serving for everyone

+--- + +**Ray Summit CPF is Open (June 4th to June 20th)!** + +There will be a track for vLLM at the Ray Summit (09/30-10/02, SF) this year! +If you have cool projects related to vLLM or LLM inference, we would love to see your proposals. +This will be a great chance for everyone in the community to get together and learn. +Please submit your proposal [here](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/eventsite) + +--- + *Latest News* 🔥 +- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). - [2024/05] vLLM-fork specific: Added Intel® Gaudi® 2 support with SynapseAI 1.16.0. For more information, please refer to Intel® Gaudi® README. - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). @@ -48,45 +60,18 @@ vLLM is flexible and easy to use with: - Tensor parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs and AMD GPUs +- Support NVIDIA GPUs, AMD GPUs, Intel CPUs and GPUs - (Experimental) Prefix caching support - (Experimental) Multi-lora support -vLLM seamlessly supports many Hugging Face models, including the following architectures: - -- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) -- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) -- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) -- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) -- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.) -- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.) -- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) -- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) -- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.) -- GPT-2 (`gpt2`, `gpt2-xl`, etc.) -- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) -- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) -- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) -- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) -- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) -- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) -- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) -- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) -- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) -- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) -- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) 
-- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.) -- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) -- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) -- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) -- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.) -- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) -- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.) -- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) -- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) -- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.) -- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.) -- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) +vLLM seamlessly supports most popular open-source models on HuggingFace, including: +- Transformer-like LLMs (e.g., Llama) +- Mixture-of-Expert LLMs (e.g., Mixtral) +- Multi-modal LLMs (e.g., LLaVA) + +Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). + +## Getting Started Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): @@ -94,9 +79,7 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get pip install vllm ``` -## Getting Started - -Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started. +Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) @@ -106,6 +89,34 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started We welcome and value any contributions and collaborations. Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. +## Sponsors + +vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! + + + + +- a16z +- AMD +- Anyscale +- AWS +- Crusoe Cloud +- Databricks +- DeepInfra +- Dropbox +- Lambda Lab +- NVIDIA +- Replicate +- Roblox +- RunPod +- Sequoia Capital +- Trainy +- UC Berkeley +- UC San Diego +- ZhenFund + +We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. 
+ ## Citation If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index f9d167590fe47..fe29c67086158 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -4,10 +4,13 @@ import time import traceback from dataclasses import dataclass, field -from typing import List, Optional +from typing import List, Optional, Union import aiohttp +import huggingface_hub.constants from tqdm.asyncio import tqdm +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) @@ -68,9 +71,13 @@ async def async_request_tgi( chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue + chunk_bytes = chunk_bytes.decode("utf-8") - chunk = remove_prefix(chunk_bytes.decode("utf-8"), - "data:") + #NOTE: Sometimes TGI returns a ping response without + # any data, we should skip it. + if chunk_bytes.startswith(":"): + continue + chunk = remove_prefix(chunk_bytes, "data:") data = json.loads(chunk) timestamp = time.perf_counter() @@ -89,6 +96,9 @@ async def async_request_tgi( output.latency = most_recent_timestamp - st output.success = True output.generated_text = data["generated_text"] + else: + output.error = response.reason or "" + output.success = False except Exception: output.success = False exc_info = sys.exc_info() @@ -215,8 +225,8 @@ async def async_request_openai_completions( ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( - "v1/completions" - ), "OpenAI Completions API URL must end with 'v1/completions'." + "completions" + ), "OpenAI Completions API URL must end with 'completions'." async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search @@ -255,6 +265,9 @@ async def async_request_openai_completions( else: data = json.loads(chunk) + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated if data["choices"][0]["text"]: timestamp = time.perf_counter() # First token @@ -263,12 +276,8 @@ async def async_request_openai_completions( output.ttft = ttft # Decoding phase - # NOTE: Some completion API might have a last - # usage summary response without a token so we - # do not want to include as inter-token-latency - elif data.get("usage", None) is None: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - + most_recent_timestamp) most_recent_timestamp = timestamp generated_text += data["choices"][0]["text"] @@ -276,6 +285,9 @@ async def async_request_openai_completions( output.generated_text = generated_text output.success = True output.latency = latency + else: + output.error = response.reason or "" + output.success = False except Exception: output.success = False exc_info = sys.exc_info() @@ -292,8 +304,8 @@ async def async_request_openai_chat_completions( ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( - "v1/chat/completions" - ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'." + "chat/completions" + ), "OpenAI Chat Completions API URL must end with 'chat/completions'." 
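The streaming changes in `backend_request_func.py` above skip TGI ping frames that start with `:` and strip the SSE `data:` prefix before `json.loads`. The standalone helper below is a hedged summary of that handling, not the benchmark code itself:

```python
# Illustrative summary of the chunk handling added above: ignore empty lines
# and ":" ping frames, strip the SSE "data:" prefix, then parse JSON.
import json
from typing import Optional

def parse_stream_chunk(chunk_bytes: bytes) -> Optional[dict]:
    text = chunk_bytes.decode("utf-8").strip()
    if not text or text.startswith(":"):      # keep-alive / ping frame
        return None
    if text.startswith("data:"):
        text = text[len("data:"):]
    return json.loads(text)

print(parse_stream_chunk(b": ping"))                             # None
print(parse_stream_chunk(b'data: {"generated_text": "hello"}'))  # {'generated_text': 'hello'}
```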
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search @@ -378,6 +390,30 @@ def remove_prefix(text: str, prefix: str) -> str: return text +def get_model(pretrained_model_name_or_path: str): + if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': + from modelscope import snapshot_download + else: + from huggingface_hub import snapshot_download + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + return model_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, trust_remote_code: bool +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path): + pretrained_model_name_or_path = get_model( + pretrained_model_name_or_path) + return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, + trust_remote_code=trust_remote_code) + + ASYNC_REQUEST_FUNCS = { "tgi": async_request_tgi, "vllm": async_request_openai_completions, @@ -386,4 +422,5 @@ def remove_prefix(text: str, prefix: str) -> str: "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, "tensorrt-llm": async_request_trt_llm, + "scalellm": async_request_openai_completions, } diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e8530c2761acf..16802d879c0ca 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,15 +1,19 @@ """Benchmark the latency of processing a single batch of requests.""" import argparse +import json import time from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy as np import torch from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.inputs import PromptStrictInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.utils import FlexibleArgumentParser def main(args: argparse.Namespace): @@ -17,20 +21,33 @@ def main(args: argparse.Namespace): # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
- llm = LLM(model=args.model, - tokenizer=args.tokenizer, - quantization=args.quantization, - tensor_parallel_size=args.tensor_parallel_size, - trust_remote_code=args.trust_remote_code, - dtype=args.dtype, - enforce_eager=args.enforce_eager, - kv_cache_dtype=args.kv_cache_dtype, - quantization_param_path=args.quantization_param_path, - device=args.device, - ray_workers_use_nsight=args.ray_workers_use_nsight, - enable_chunked_prefill=args.enable_chunked_prefill, - download_dir=args.download_dir, - block_size=args.block_size) + llm = LLM( + model=args.model, + speculative_model=args.speculative_model, + num_speculative_tokens=args.num_speculative_tokens, + speculative_draft_tensor_parallel_size=\ + args.speculative_draft_tensor_parallel_size, + tokenizer=args.tokenizer, + quantization=args.quantization, + tensor_parallel_size=args.tensor_parallel_size, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + max_model_len=args.max_model_len, + enforce_eager=args.enforce_eager, + kv_cache_dtype=args.kv_cache_dtype, + quantization_param_path=args.quantization_param_path, + device=args.device, + ray_workers_use_nsight=args.ray_workers_use_nsight, + use_v2_block_manager=args.use_v2_block_manager, + enable_chunked_prefill=args.enable_chunked_prefill, + download_dir=args.download_dir, + block_size=args.block_size, + gpu_memory_utilization=args.gpu_memory_utilization, + load_format=args.load_format, + distributed_executor_backend=args.distributed_executor_backend, + otlp_traces_endpoint=args.otlp_traces_endpoint, + enable_prefix_caching=args.enable_prefix_caching, + ) sampling_params = SamplingParams( n=args.n, @@ -44,7 +61,9 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() + dummy_inputs: List[PromptStrictInputs] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: @@ -55,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() @@ -87,18 +106,34 @@ def run_to_completion(profile_dir: Optional[str] = None): for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): latencies.append(run_to_completion(profile_dir=None)) latencies = np.array(latencies) - percentages = [10, 25, 50, 75, 90] + percentages = [10, 25, 50, 75, 90, 99] percentiles = np.percentile(latencies, percentages) print(f'Avg latency: {np.mean(latencies)} seconds') for percentage, percentile in zip(percentages, percentiles): print(f'{percentage}% percentile latency: {percentile} seconds') + # Output JSON results if specified + if args.output_json: + results = { + "avg_latency": np.mean(latencies), + "latencies": latencies.tolist(), + "percentiles": dict(zip(percentages, percentiles.tolist())), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + if __name__ == '__main__': - parser = argparse.ArgumentParser( + parser = FlexibleArgumentParser( description='Benchmark the latency of processing a single 
batch of ' 'requests till completion.') parser.add_argument('--model', type=str, default='facebook/opt-125m') + parser.add_argument('--speculative-model', type=str, default=None) + parser.add_argument('--num-speculative-tokens', type=int, default=None) + parser.add_argument('--speculative-draft-tensor-parallel-size', + '-spec-draft-tp', + type=int, + default=None) parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', @@ -124,6 +159,12 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface') + parser.add_argument( + '--max-model-len', + type=int, + default=None, + help='Maximum length of a sequence (including prompt and output). ' + 'If None, will be derived from the model.') parser.add_argument( '--dtype', type=str, @@ -137,15 +178,13 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='enforce eager mode and disable CUDA graph') parser.add_argument( - "--kv-cache-dtype", + '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], - default='auto', - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, @@ -169,9 +208,10 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument( "--device", type=str, - default="cuda", - choices=["cuda", "cpu", "hpu"], - help='device type for vLLM execution, supporting CUDA, CPU and HPU.') + default="auto", + choices=["auto", "cuda", "cpu", "hpu", "openvino", "tpu", "xpu"], + help='device type for vLLM execution, supporting CUDA, HPU, ' + 'OpenVINO and CPU.') parser.add_argument('--block-size', type=int, default=16, @@ -181,6 +221,10 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') + parser.add_argument("--enable-prefix-caching", + action='store_true', + help="Enable automatic prefix caching") + parser.add_argument('--use-v2-block-manager', action='store_true') parser.add_argument( "--ray-workers-use-nsight", action='store_true', @@ -191,5 +235,51 @@ def run_to_completion(profile_dir: Optional[str] = None): default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the latency results in JSON format.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' 
+ 'If unspecified, will use the default value of 0.9.') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', + 'bitsandbytes' + ], + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave. See the Tensorize vLLM Model script in the Examples' + 'section for more information.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=None, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--otlp-traces-endpoint', + type=str, + default=None, + help='Target URL to which OpenTelemetry traces will be sent.') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 089966986984f..395107a5ec747 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -1,7 +1,7 @@ -import argparse import time from vllm import LLM, SamplingParams +from vllm.utils import FlexibleArgumentParser PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. 
Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 @@ -44,7 +44,7 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser( + parser = FlexibleArgumentParser( description='Benchmark the performance with or without automatic ' 'prefix caching.') parser.add_argument('--model', diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 2c2d69da4a7d1..42867fc40edd2 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -17,6 +17,10 @@ --dataset-path \ --request-rate \ # By default is inf --num-prompts # By default is 1000 + + when using tgi backend, add + --endpoint /generate_stream + to the end of the command above. """ import argparse import asyncio @@ -27,7 +31,7 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import AsyncGenerator, List, Optional, Tuple +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple import numpy as np from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, @@ -35,7 +39,15 @@ from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase -from vllm.transformers_utils.tokenizer import get_tokenizer +try: + from vllm.transformers_utils.tokenizer import get_tokenizer +except ImportError: + from backend_request_func import get_tokenizer + +try: + from vllm.utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser @dataclass @@ -52,6 +64,9 @@ class BenchmarkMetrics: mean_tpot_ms: float median_tpot_ms: float p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + p99_itl_ms: float def sample_sharegpt_requests( @@ -193,24 +208,37 @@ def calculate_metrics( dur_s: float, tokenizer: PreTrainedTokenizerBase, ) -> Tuple[BenchmarkMetrics, List[int]]: - actual_output_lens = [] + actual_output_lens: List[int] = [] total_input = 0 completed = 0 - tpots = [] - ttfts = [] + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] for i in range(len(outputs)): if outputs[i].success: - output_len = len(tokenizer(outputs[i].generated_text).input_ids) + # We use the tokenizer to count the number of output tokens for all + # serving backends instead of looking at len(outputs[i].itl) since + # multiple output tokens may be bundled together + # Note: this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, + add_special_tokens=False).input_ids) actual_output_lens.append(output_len) total_input += input_requests[i][1] if output_len > 1: tpots.append( (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + itls += outputs[i].itl ttfts.append(outputs[i].ttft) completed += 1 else: actual_output_lens.append(0) + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -222,9 +250,12 @@ def calculate_metrics( 1000, # ttfts is empty if streaming is not supported by backend median_ttft_ms=np.median(ttfts or 0) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, - mean_tpot_ms=np.mean(tpots) * 1000, - median_tpot_ms=np.median(tpots) * 1000, - p99_tpot_ms=np.percentile(tpots, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, ) return metrics, actual_output_lens @@ -242,16 +273,34 @@ async def benchmark( disable_tqdm: bool, ): if backend in ASYNC_REQUEST_FUNCS: - request_func = ASYNC_REQUEST_FUNCS.get(backend) + request_func = ASYNC_REQUEST_FUNCS[backend] else: raise ValueError(f"Unknown backend: {backend}") + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + best_of=best_of, + use_beam_search=use_beam_search, + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") print(f"Traffic request rate: {request_rate}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) benchmark_start_time = time.perf_counter() - tasks = [] + tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): prompt, prompt_len, output_len = request request_func_input = RequestFuncInput( @@ -269,7 +318,7 @@ async def benchmark( pbar=pbar))) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) - if not disable_tqdm: + if pbar is not None: pbar.close() benchmark_duration = time.perf_counter() - benchmark_start_time @@ -306,6 +355,10 @@ async def benchmark( print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-')) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) print("=" * 50) result = { @@ -322,6 +375,9 @@ async def benchmark( "mean_tpot_ms": metrics.mean_tpot_ms, "median_tpot_ms": metrics.median_tpot_ms, "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, "input_lens": [output.prompt_len for output in outputs], "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], @@ -418,7 +474,7 @@ def main(args: argparse.Namespace): # Save config and results to json if args.save_result: - result_json = {} + result_json: Dict[str, Any] = {} # Setup current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") @@ -451,6 +507,8 @@ def main(args: argparse.Namespace): # Save to file base_model_id = 
model_id.split("/")[-1] file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa + if args.result_filename: + file_name = args.result_filename if args.result_dir: file_name = os.path.join(args.result_dir, file_name) with open(file_name, "w") as outfile: @@ -458,7 +516,7 @@ def main(args: argparse.Namespace): if __name__ == "__main__": - parser = argparse.ArgumentParser( + parser = FlexibleArgumentParser( description="Benchmark the online serving throughput.") parser.add_argument( "--backend", @@ -591,6 +649,15 @@ def main(args: argparse.Namespace): help="Specify directory to save benchmark json results." "If not specified, results are saved in the current directory.", ) + parser.add_argument( + "--result-filename", + type=str, + default=None, + help="Specify the filename to save benchmark json results." + "If not specified, results will be saved in " + "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" + " format.", + ) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 2e8cfd3f2ca3e..ff33e3dced66f 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -10,7 +10,9 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) +from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.utils import FlexibleArgumentParser def sample_requests( @@ -78,8 +80,10 @@ def run_vllm( enable_prefix_caching: bool, enable_chunked_prefill: bool, max_num_batched_tokens: int, + distributed_executor_backend: Optional[str], gpu_memory_utilization: float = 0.9, download_dir: Optional[str] = None, + load_format: str = EngineArgs.load_format, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -100,11 +104,13 @@ def run_vllm( download_dir=download_dir, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + load_format=load_format, ) # Add the requests to the engine. 
- prompts = [] - sampling_params = [] + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] for prompt, _, output_len in requests: prompts.append(prompt) sampling_params.append( @@ -225,8 +231,8 @@ def main(args: argparse.Namespace): args.enforce_eager, args.kv_cache_dtype, args.quantization_param_path, args.device, args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.gpu_memory_utilization, - args.download_dir) + args.max_num_batched_tokens, args.distributed_executor_backend, + args.gpu_memory_utilization, args.download_dir, args.load_format) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -242,9 +248,21 @@ def main(args: argparse.Namespace): print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} tokens/s") + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Benchmark the throughput.") + parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser.add_argument("--backend", type=str, choices=["vllm", "hf", "mii"], @@ -311,15 +329,13 @@ def main(args: argparse.Namespace): action="store_true", help="enforce eager execution") parser.add_argument( - "--kv-cache-dtype", + '--kv-cache-dtype', type=str, - choices=["auto", "fp8"], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], default="auto", - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, @@ -333,9 +349,10 @@ def main(args: argparse.Namespace): parser.add_argument( "--device", type=str, - default="cuda", - choices=["cuda", "cpu", "hpu"], - help='device type for vLLM execution, supporting CUDA, CPU and HPU.') + default="auto", + choices=["auto", "cuda", "cpu", "hpu", "openvino", "tpu", "xpu"], + help='device type for vLLM execution, supporting CUDA, HPU, ' + 'OpenVINO and CPU.') parser.add_argument( "--enable-prefix-caching", action='store_true', @@ -353,6 +370,41 @@ def main(args: argparse.Namespace): default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=None, + help='Backend to use for distributed serving. 
When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', + 'bitsandbytes' + ], + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave. See the Tensorize vLLM Model script in the Examples' + 'section for more information.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py new file mode 100644 index 0000000000000..377f8683c021f --- /dev/null +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -0,0 +1,353 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + +# helpers + + +def to_fp8(tensor: torch.tensor) -> torch.tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.tensor) -> torch.tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.tensor, torch.tensor]: + + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +# impl + + +def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return torch.mm(a, b) + + +def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return torch._scaled_mm(a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype) + + +def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor, + scale_a: torch.tensor, scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return torch._scaled_mm(a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + use_fast_accum=True) + + +def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + 
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype) + + +# bench +def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, out_dtype: torch.dtype, label: str, + sub_label: str, fn: Callable, description: str) -> TMeasurement: + + min_run_time = 1 + + globals = { + "a": a, + "b": b, + "scale_a": scale_a, + "scale_b": scale_b, + "out_dtype": out_dtype, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(a, b, scale_a, scale_b, out_dtype)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + a, b = make_rand_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + timers = [] + # pytorch impl + timers.append( + bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, + torch.bfloat16, label, sub_label, pytorch_mm_impl, + "pytorch_bf16_bf16_bf16_matmul-no-scales")) + + # cutlass impl + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, + cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm")) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + timers = [] + + # pytorch impl w. 
bf16 + timers.append( + bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, + torch.bfloat16, label, sub_label, pytorch_mm_impl, + "pytorch_bf16_bf16_bf16_matmul-no-scales")) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, + pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm")) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, + pytorch_fp8_impl_fast_accum, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum")) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, + pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm")) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, + pytorch_fp8_impl_fast_accum, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum")) + + # cutlass impl: bf16 output + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, + cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm")) + # cutlass impl: fp16 output + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, + cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm")) + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim 
in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. 
+ """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py new file mode 100644 index 0000000000000..25ec9d6028627 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -0,0 +1,43 @@ +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-3-8b": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], +} diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py index 59392947b15c8..601c4ea439aea 100644 --- a/benchmarks/kernels/benchmark_aqlm.py +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -1,4 +1,3 @@ -import argparse import os import sys from typing import Optional @@ -10,6 +9,7 @@ from vllm.model_executor.layers.quantization.aqlm import ( dequantize_weight, generic_dequantize_gemm, get_int_dtype, optimized_dequantize_gemm) +from vllm.utils import FlexibleArgumentParser os.environ['CUDA_VISIBLE_DEVICES'] = '0' @@ -86,9 +86,9 @@ def dequant_no_scale( # Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against # the generic pytorch version. # Just visual comparison. 
-def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None: +def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: - n = parts.sum().item() + n = int(parts.sum().item()) device = torch.device('cuda:0') @@ -137,7 +137,7 @@ def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None: def main(): - parser = argparse.ArgumentParser(description="Benchmark aqlm performance.") + parser = FlexibleArgumentParser(description="Benchmark aqlm performance.") # Add arguments parser.add_argument("--nbooks", @@ -204,7 +204,7 @@ def main(): sys.stdout = sys.__stdout__ -def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, +def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods): # I didn't see visible improvements from increasing these, but feel free :) @@ -252,10 +252,10 @@ def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, print('') -def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor, +def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method) -> float: - n = parts.sum().item() + n = int(parts.sum().item()) device = torch.device('cuda:0') diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py new file mode 100644 index 0000000000000..261f5829631ee --- /dev/null +++ b/benchmarks/kernels/benchmark_marlin.py @@ -0,0 +1,235 @@ +from typing import List + +import torch +import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MarlinWorkspace, marlin_24_quantize, marlin_quantize) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, quantize_weights, sort_weights) +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] + +ACT_ORDER_OPTS = [False, True] +K_FULL_OPTS = [False, True] + + +def bench_run(results: List[benchmark.Measurement], model: str, + act_order: bool, is_k_full: bool, num_bits: int, group_size: int, + size_m: int, size_k: int, size_n: int): + label = "Quant Matmul" + + sub_label = ("{}, act={} k_full={}, b={}, g={}, " + "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, + group_size, size_m, size_k, size_n)) + + print(f"Testing: {sub_label}") + + a = torch.randn(size_m, size_k).to(torch.half).cuda() + b = torch.rand(size_k, size_n).to(torch.half).cuda() + + a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) + + # Marlin quant + ( + marlin_w_ref, + marlin_q_w, + marlin_s, + marlin_g_idx, + marlin_sort_indices, + marlin_rand_perm, + ) = marlin_quantize(b, num_bits, group_size, act_order) + + # Marlin_24 quant + (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, + marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) + + # GPTQ quant + (w_ref, q_w, s, g_idx, + rand_perm) = quantize_weights(b, num_bits, group_size, act_order) + q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + + # For 
act_order, sort the "weights" and "g_idx" + # so that group ids are increasing + repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) + if act_order: + (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) + + # Prepare + marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_MAX_PARALLEL) + + globals = { + # Gen params + "num_bits": num_bits, + "group_size": group_size, + "size_m": size_m, + "size_n": size_n, + "size_k": size_k, + "a": a, + "a_tmp": a_tmp, + # Marlin params + "marlin_w_ref": marlin_w_ref, + "marlin_q_w": marlin_q_w, + "marlin_s": marlin_s, + "marlin_g_idx": marlin_g_idx, + "marlin_sort_indices": marlin_sort_indices, + "marlin_rand_perm": marlin_rand_perm, + "marlin_workspace": marlin_workspace, + "is_k_full": is_k_full, + # Marlin_24 params + "marlin_24_w_ref": marlin_24_w_ref, + "marlin_24_q_w_comp": marlin_24_q_w_comp, + "marlin_24_meta": marlin_24_meta, + "marlin_24_s": marlin_24_s, + "marlin_24_workspace": marlin_24_workspace, + # GPTQ params + "q_w_gptq": q_w_gptq, + "repack_sort_indices": repack_sort_indices, + # Kernels + "gptq_marlin_gemm": ops.gptq_marlin_gemm, + "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, + "gptq_marlin_repack": ops.gptq_marlin_repack, + } + + min_run_time = 1 + + # Warmup pytorch + for i in range(5): + torch.matmul(a, marlin_w_ref) + + results.append( + benchmark.Timer( + stmt="torch.matmul(a, marlin_w_ref)", + globals=globals, + label=label, + sub_label=sub_label, + description="pytorch_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_24_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_repack", + ).blocked_autorange(min_run_time=min_run_time)) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: List[benchmark.Measurement] = [] + + for model in args.models: + for layer in WEIGHT_SHAPES[model]: + size_k = layer[0] + size_n = layer[1] + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for act_order in ACT_ORDER_OPTS: + if len(args.limit_act_order + ) > 0 and act_order not in args.limit_act_order: + continue + + for is_k_full in K_FULL_OPTS: + if len(args.limit_k_full + ) > 0 and is_k_full not in args.limit_k_full: + continue + + for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + if len(args.limit_num_bits + ) > 
0 and num_bits not in args.limit_num_bits: + continue + + for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: + if len( + args.limit_group_size + ) > 0 and group_size not in args.limit_group_size: + continue + + # For act_order, the group_size must be less than + # size_k + if act_order and (group_size == size_k + or group_size == -1): + continue + + for size_m in args.batch_sizes: + bench_run(results, model, act_order, is_k_full, + num_bits, group_size, size_m, size_k, + size_n) + + compare = benchmark.Compare(results) + compare.print() + + +# For quick benchmarking use: +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501 +# +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark Marlin across specified models/shapes/batches") + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[]) + parser.add_argument("--limit-act-order", nargs="+", type=int, default=[]) + parser.add_argument("--limit-k-full", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py deleted file mode 100644 index 5280b214144c9..0000000000000 --- a/benchmarks/kernels/benchmark_mixtral_moe.py +++ /dev/null @@ -1,215 +0,0 @@ -import argparse -import json -import os -import sys - -import torch -import torch.nn.functional as F -import triton -from tqdm import tqdm - -from vllm.model_executor.layers.fused_moe import (fused_moe, - get_config_file_name) - -os.environ['CUDA_VISIBLE_DEVICES'] = '0' - - -def main(dtype: str): - method = fused_moe - for bs in [ - 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, - 2048, 3072, 4096 - ]: - run_grid(bs, method=method, dtype=dtype) - - -def run_grid(bs, method, dtype: str): - d_model = 4096 - num_total_experts = 8 - top_k = 2 - tp_size = 2 - model_intermediate_size = 14336 - num_layers = 32 - num_calls = 100 - - num_warmup_trials = 1 - num_trials = 1 - - configs = [] - - for block_size_n in [32, 64, 128, 256]: - for block_size_m in [16, 32, 64, 128, 256]: - for block_size_k in [64, 128, 256]: - for group_size_m in [1, 16, 32, 64]: - for num_warps in [4, 8]: - for num_stages in [2, 3, 4, 5]: - configs.append({ - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - "num_warps": num_warps, - "num_stages": num_stages, - }) - - best_config = None - best_time_us = 1e20 - - print(f'{tp_size=} {bs=}') - - for config in tqdm(configs): - # warmup - try: - for _ in range(num_warmup_trials): - run_timing( - num_calls=num_calls, - bs=bs, - d_model=d_model, - num_total_experts=num_total_experts, - top_k=top_k, - tp_size=tp_size, - model_intermediate_size=model_intermediate_size, - method=method, - config=config, - dtype=dtype, - ) - except triton.runtime.autotuner.OutOfResources: - continue - - # trial - for _ in range(num_trials): - kernel_dur_ms = run_timing( - num_calls=num_calls, - bs=bs, - 
d_model=d_model, - num_total_experts=num_total_experts, - top_k=top_k, - tp_size=tp_size, - model_intermediate_size=model_intermediate_size, - method=method, - config=config, - dtype=dtype, - ) - - kernel_dur_us = 1000 * kernel_dur_ms - model_dur_ms = kernel_dur_ms * num_layers - - if kernel_dur_us < best_time_us: - best_config = config - best_time_us = kernel_dur_us - - tqdm.write( - f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' - f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' - f'{d_model=} {model_intermediate_size=} {num_layers=}') - - print("best_time_us", best_time_us) - print("best_config", best_config) - - # holds Dict[str, Dict[str, int]] - filename = get_config_file_name(num_total_experts, - model_intermediate_size // tp_size, - "float8" if dtype == "float8" else None) - print(f"writing config to file {filename}") - existing_content = {} - if os.path.exists(filename): - with open(filename, "r") as f: - existing_content = json.load(f) - existing_content[str(bs)] = best_config - with open(filename, "w") as f: - json.dump(existing_content, f, indent=4) - f.write("\n") - - -def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, - top_k: int, tp_size: int, model_intermediate_size: int, method, - config, dtype: str) -> float: - shard_intermediate_size = model_intermediate_size // tp_size - - hidden_states = torch.rand( - (bs, d_model), - device="cuda:0", - dtype=torch.float16, - ) - - w1 = torch.rand( - (num_total_experts, 2 * shard_intermediate_size, d_model), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - w2 = torch.rand( - (num_total_experts, d_model, shard_intermediate_size), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - w1_scale = None - w2_scale = None - a1_scale = None - a2_scale = None - - if dtype == "float8": - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) - w1_scale = torch.ones(num_total_experts, - device=hidden_states.device, - dtype=torch.float32) - w2_scale = torch.ones(num_total_experts, - device=hidden_states.device, - dtype=torch.float32) - a1_scale = torch.ones(1, - device=hidden_states.device, - dtype=torch.float32) - a2_scale = torch.ones(1, - device=hidden_states.device, - dtype=torch.float32) - - gating_output = F.softmax(torch.rand( - (num_calls, bs, num_total_experts), - device=hidden_states.device, - dtype=torch.float32, - ), - dim=-1) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for i in range(num_calls): - hidden_states = method( - hidden_states=hidden_states, - w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - gating_output=gating_output[i], - topk=2, - renormalize=True, - inplace=True, - override_config=config, - use_fp8=dtype == "float8", - ) - end_event.record() - end_event.synchronize() - - dur_ms = start_event.elapsed_time(end_event) / num_calls - return dur_ms - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - prog='benchmark_mixtral_moe', - description='Benchmark and tune the fused_moe kernel', - ) - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['float8', 'float16'], - help='Data type used for fused_moe kernel computations', - ) - args = parser.parse_args() - sys.exit(main(args.dtype)) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py new file mode 100644 index 0000000000000..e00696d6d43cb --- /dev/null +++ 
b/benchmarks/kernels/benchmark_moe.py @@ -0,0 +1,333 @@ +import argparse +import time +from datetime import datetime +from typing import Any, Dict, List, Tuple, TypedDict + +import ray +import torch +import triton +from ray.experimental.tqdm_ray import tqdm +from transformers import AutoConfig + +from vllm.model_executor.layers.fused_moe.fused_moe import * +from vllm.utils import FlexibleArgumentParser + + +class BenchmarkConfig(TypedDict): + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8: bool, + num_iters: int = 100, +) -> float: + init_dtype = torch.float16 if use_fp8 else dtype + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + w1 = torch.randn(num_experts, + shard_intermediate_size, + hidden_size, + dtype=init_dtype) + w2 = torch.randn(num_experts, + hidden_size, + shard_intermediate_size // 2, + dtype=init_dtype) + gating_output = torch.randn(num_iters, + num_tokens, + num_experts, + dtype=torch.float32) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + if use_fp8: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) + a2_scale = torch.randn(1, dtype=torch.float32) + + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + + input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + fused_moe( + x, + w1, + w2, + input_gating, + topk, + renormalize=True, + inplace=True, + override_config=config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + # JIT compilation & warmup + run() + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: List[float] = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +def get_configs_compute_bound() -> List[Dict[str, int]]: + # Reduced search space for faster tuning. + # TODO(woosuk): Increase the search space and use a performance model to + # prune the search space. 
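# Editorial note on the grid enumerated just below (not from the original
# file): 4 num_stages x 5 BLOCK_SIZE_M x 3 BLOCK_SIZE_K x 4 BLOCK_SIZE_N x
# 2 num_warps x 4 GROUP_SIZE_M = 1,920 candidate configs per batch size,
# each timed with num_iters=10 during tuning.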
+ configs: List[BenchmarkConfig] = [] + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128, 256]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + }) + return configs + + +@ray.remote(num_gpus=1) +class BenchmarkWorker: + + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(seed) + self.seed = seed + + def benchmark( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8: bool, + ) -> Tuple[Dict[str, int], float]: + torch.cuda.manual_seed_all(self.seed) + + dtype_str = "float8" if use_fp8 else None + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. + op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, + dtype_str) + if op_config is None: + config = get_default_config(num_tokens, num_experts, + shard_intermediate_size, hidden_size, + topk, dtype_str) + else: + config = op_config[min(op_config.keys(), + key=lambda x: abs(x - num_tokens))] + kernel_time = benchmark_config(config, num_tokens, num_experts, + shard_intermediate_size, hidden_size, + topk, dtype, use_fp8) + return config, kernel_time + + def tune( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8: bool, + search_space: List[BenchmarkConfig], + ) -> BenchmarkConfig: + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + try: + kernel_time = benchmark_config(config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8, + num_iters=10) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") + assert best_config is not None + return best_config + + +def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: + return { + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + } + + +def save_configs( + configs: Dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8: bool, +) -> None: + dtype_str = "float8" if use_fp8 else None + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. 
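# Worked example of this convention (editorial; assumes Mixtral-8x7B with
# intermediate_size=14336 and the default --tp-size 2 set up in main() below):
#   shard_intermediate_size = 2 * 14336 // 2 = 14336   # fused w1/w3 dim per GPU
#   shard_intermediate_size // 2 = 7168                # w2.shape[2], used in the file name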
+ filename = get_config_file_name(num_experts, shard_intermediate_size // 2, + dtype_str) + print(f"Writing best config to {filename}...") + with open(filename, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def main(args: argparse.Namespace): + print(args) + + config = AutoConfig.from_pretrained(args.model) + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + intermediate_size = config.ffn_config.ffn_hidden_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + else: + # Default: Mixtral. + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + + hidden_size = config.hidden_size + dtype = config.torch_dtype + use_fp8 = args.dtype == "fp8" + + if args.batch_size is None: + batch_sizes = [ + 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, + 2048, 3072, 4096 + ] + else: + batch_sizes = [args.batch_size] + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] + + def _distribute(method: str, inputs: List[Any]) -> List[Any]: + outputs = [] + worker_idx = 0 + for input_args in inputs: + worker = workers[worker_idx] + worker_method = getattr(worker, method) + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + return ray.get(outputs) + + if args.tune: + search_space = get_configs_compute_bound() + print(f"Start tuning over {len(search_space)} configurations...") + + start = time.time() + configs = _distribute( + "tune", [(batch_size, E, shard_intermediate_size, hidden_size, + topk, dtype, use_fp8, search_space) + for batch_size in batch_sizes]) + best_configs = { + M: sort_config(config) + for M, config in zip(batch_sizes, configs) + } + save_configs(best_configs, E, shard_intermediate_size, hidden_size, + topk, dtype, use_fp8) + end = time.time() + print(f"Tuning took {end - start:.2f} seconds") + else: + outputs = _distribute("benchmark", + [(batch_size, E, shard_intermediate_size, + hidden_size, topk, dtype, use_fp8) + for batch_size in batch_sizes]) + + for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): + print(f"Batch size: {batch_size}, config: {config}") + print(f"Kernel time: {kernel_time:.2f} us") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--model", + type=str, + default="mistralai/Mixtral-8x7B-Instruct-v0.1") + parser.add_argument("--tp-size", "-tp", type=int, default=2) + parser.add_argument("--dtype", + type=str, + choices=["auto", "fp8"], + default="auto") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--tune", action="store_true") + args = parser.parse_args() + + main(args) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index ca7967c1ab0d2..16de60477c305 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -1,12 +1,12 @@ -import argparse import random import time -from typing import Optional +from typing import List, Optional import torch from vllm import _custom_ops as ops -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, 
FlexibleArgumentParser, + create_kv_caches_with_random) NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -54,14 +54,17 @@ def main( # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables = [] + block_tables_lst: List[List[int]] = [] for _ in range(num_seqs): block_table = [ random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) + block_tables_lst.append(block_table) + + block_tables = torch.tensor(block_tables_lst, + dtype=torch.int, + device=device) # Create the KV cache. key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, @@ -158,19 +161,19 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: if __name__ == '__main__': - parser = argparse.ArgumentParser( + parser = FlexibleArgumentParser( description="Benchmark the paged attention kernel.") parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument("--seq_len", type=int, default=4096) + parser.add_argument("--seq-len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 256], + choices=[64, 80, 96, 112, 128, 192, 256], default=128) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") @@ -183,13 +186,11 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: parser.add_argument( "--kv-cache-dtype", type=str, - choices=["auto", "fp8"], + choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"], default="auto", - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + help="Data type for kv cache storage. If 'auto', will use model " + "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. 
" + "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") args = parser.parse_args() print(args) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 9188e811e2982..78736c7a7ba6f 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,11 +1,12 @@ -import argparse from itertools import accumulate -from typing import Optional +from typing import List, Optional import nvtx import torch -from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, + get_rope) +from vllm.utils import FlexibleArgumentParser def benchmark_rope_kernels_multi_lora( @@ -37,7 +38,7 @@ def benchmark_rope_kernels_multi_lora( }) # non-batched RoPE takes only one scaling factor, we create multiple # instances to simulate the same behavior - non_batched_ropes = [] + non_batched_ropes: List[RotaryEmbedding] = [] for scaling_factor in scaling_factors: non_batched_ropes.append( get_rope(head_size, rotary_dim, max_position, base, is_neox_style, @@ -85,7 +86,7 @@ def benchmark_rope_kernels_multi_lora( if __name__ == '__main__': - parser = argparse.ArgumentParser( + parser = FlexibleArgumentParser( description="Benchmark the rotary embedding kernels.") parser.add_argument("--is-neox-style", type=bool, default=True) parser.add_argument("--batch-size", type=int, default=16) @@ -93,7 +94,7 @@ def benchmark_rope_kernels_multi_lora( parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 256], + choices=[64, 80, 96, 112, 128, 192, 256], default=128) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--dtype", diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py new file mode 100644 index 0000000000000..4eeeca35a37cc --- /dev/null +++ b/benchmarks/kernels/benchmark_shapes.py @@ -0,0 +1,75 @@ +WEIGHT_SHAPES = { + "ideal": [[4 * 256 * 32, 256 * 32]], + "mistralai/Mistral-7B-v0.1/TP1": [ + [4096, 6144], + [4096, 4096], + [4096, 28672], + [14336, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP2": [ + [4096, 3072], + [2048, 4096], + [4096, 14336], + [7168, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP4": [ + [4096, 1536], + [1024, 4096], + [4096, 7168], + [3584, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP1": [ + [4096, 12288], + [4096, 4096], + [4096, 22016], + [11008, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP2": [ + [4096, 6144], + [2048, 4096], + [4096, 11008], + [5504, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP4": [ + [4096, 3072], + [1024, 4096], + [4096, 5504], + [2752, 4096], + ], + "meta-llama/Llama-2-13b-hf/TP1": [ + [5120, 15360], + [5120, 5120], + [5120, 27648], + [13824, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP2": [ + [5120, 7680], + [2560, 5120], + [5120, 13824], + [6912, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP4": [ + [5120, 3840], + [1280, 5120], + [5120, 6912], + [3456, 5120], + ], + "meta-llama/Llama-2-70b-hf/TP1": [ + [8192, 10240], + [8192, 8192], + [8192, 57344], + [28672, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP2": [ + [8192, 5120], + [4096, 8192], + [8192, 28672], + [14336, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP4": [ + [8192, 2560], + [2048, 8192], + [8192, 14336], + [7168, 8192], + ], +} diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh index 64d3c4f4b3889..f491c90d0683e 100755 --- a/benchmarks/launch_tgi_server.sh +++ b/benchmarks/launch_tgi_server.sh @@ -4,7 +4,7 @@ 
PORT=8000 MODEL=$1 TOKENS=$2 -docker run --gpus all --shm-size 1g -p $PORT:80 \ +docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \ -v $PWD/data:/data \ ghcr.io/huggingface/text-generation-inference:1.4.0 \ --model-id $MODEL \ diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py new file mode 100644 index 0000000000000..203699e9a8d06 --- /dev/null +++ b/benchmarks/overheads/benchmark_hashing.py @@ -0,0 +1,63 @@ +import cProfile +import pstats + +from vllm import LLM, SamplingParams +from vllm.utils import FlexibleArgumentParser + +# A very long prompt, total number of tokens is about 15k. +LONG_PROMPT = ["You are an expert in large language models, aren't you?" + ] * 1000 +LONG_PROMPT = ' '.join(LONG_PROMPT) + + +def main(args): + llm = LLM( + model=args.model, + enforce_eager=True, + enable_prefix_caching=True, + tensor_parallel_size=args.tensor_parallel_size, + use_v2_block_manager=args.use_v2_block_manager, + ) + + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + profiler = cProfile.Profile() + + print("------warm up------") + for i in range(3): + output = llm.generate(LONG_PROMPT, sampling_params) + print(output[0].outputs[0].text) + + print("------start generating------") + for i in range(3): + profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', + globals(), locals()) + + # analyze the runtime of hashing function + stats = pstats.Stats(profiler) + stats.sort_stats('cumulative') + total_time = 0 + total_calls = 0 + for func in stats.stats: + if 'hash_of_block' in func[2]: + total_time = stats.stats[func][3] + total_calls = stats.stats[func][0] + percentage = (total_time / stats.total_tt) * 100 + print(f"Hashing took {total_time:.2f} seconds," + f"{percentage:.2f}% of the total runtime.") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Benchmark the performance of hashing function in' + 'automatic prefix caching.') + parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') + parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--output-len', type=int, default=10) + parser.add_argument('--enable-prefix-caching', + action='store_true', + help='enable prefix caching') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='Use BlockSpaceMangerV2') + args = parser.parse_args() + main(args) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 0cf37769a6960..690559ee265e9 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc") # # Check the compile flags # -list(APPEND CXX_COMPILE_FLAGS +list(APPEND CXX_COMPILE_FLAGS "-fopenmp" "-DVLLM_CPU_EXTENSION") @@ -33,9 +33,23 @@ function (find_isa CPUINFO TARGET OUT) endif() endfunction() +function (is_avx512_disabled OUT) + set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512}) + if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true") + set(${OUT} ON PARENT_SCOPE) + else() + set(${OUT} OFF PARENT_SCOPE) + endif() +endfunction() + +is_avx512_disabled(AVX512_DISABLED) + +find_isa(${CPUINFO} "avx2" AVX2_FOUND) find_isa(${CPUINFO} "avx512f" AVX512_FOUND) +find_isa(${CPUINFO} "POWER10" POWER10_FOUND) +find_isa(${CPUINFO} "POWER9" POWER9_FOUND) -if (AVX512_FOUND) +if (AVX512_FOUND AND NOT AVX512_DISABLED) list(APPEND CXX_COMPILE_FLAGS "-mavx512f" "-mavx512vl" @@ -44,8 +58,8 @@ if (AVX512_FOUND) find_isa(${CPUINFO} "avx512_bf16" 
AVX512BF16_FOUND) if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) - if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND - CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") else() message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") @@ -53,8 +67,18 @@ if (AVX512_FOUND) else() message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") endif() +elseif (AVX2_FOUND) + list(APPEND CXX_COMPILE_FLAGS "-mavx2") + message(WARNING "vLLM CPU backend using AVX2 ISA") +elseif (POWER9_FOUND OR POWER10_FOUND) + message(STATUS "PowerPC detected") + # Check for PowerPC VSX support + list(APPEND CXX_COMPILE_FLAGS + "-mvsx" + "-mcpu=native" + "-mtune=native") else() - message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.") + message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.") endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") @@ -73,7 +97,7 @@ set(VLLM_EXT_SRC "csrc/cpu/cache.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/pybind.cpp") + "csrc/cpu/torch_bindings.cpp") define_gpu_extension_target( _C @@ -81,10 +105,10 @@ define_gpu_extension_target( LANGUAGE CXX SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${CXX_COMPILE_FLAGS} - WITH_SOABI + USE_SABI 3 + WITH_SOABI ) add_custom_target(default) message(STATUS "Enabling C extension.") add_dependencies(default _C) - diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 7c71673e36f29..4869cad541135 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -5,7 +5,7 @@ macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) file(REAL_PATH ${EXECUTABLE} EXECUTABLE) set(Python_EXECUTABLE ${EXECUTABLE}) - find_package(Python COMPONENTS Interpreter Development.Module) + find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) if (NOT Python_FOUND) message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") endif() @@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) "Failed to determine torch nvcc compiler flags") if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) - list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") + list(APPEND GPU_FLAGS "-DENABLE_FP8") endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(REMOVE_ITEM GPU_FLAGS @@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) list(APPEND GPU_FLAGS "-DUSE_ROCM" - "-DENABLE_FP8_E4M3" + "-DENABLE_FP8" "-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_OPERATORS__" "-fno-gpu-rdc") @@ -147,16 +147,23 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) if (${GPU_LANG} STREQUAL "HIP") # # `GPU_ARCHES` controls the `--offload-arch` flags. - # `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled - # via the `PYTORCH_ROCM_ARCH` env variable. 
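Note on the cpu_extension.cmake changes above: the build now falls back from AVX512 to AVX2, adds POWER9/POWER10 (VSX) support, and honors a VLLM_CPU_DISABLE_AVX512=true escape hatch. A minimal Python sketch of that selection order, assuming a Linux /proc/cpuinfo; the helper names are illustrative, not part of the build:

```python
import os


def pick_cpu_compile_flags(cpuinfo: str) -> list:
    """Mirror the ISA selection order in cpu_extension.cmake:
    AVX512 (unless disabled via env) -> AVX2 -> POWER9/POWER10 VSX -> error."""
    avx512_disabled = os.environ.get("VLLM_CPU_DISABLE_AVX512") == "true"
    flags = ["-fopenmp", "-DVLLM_CPU_EXTENSION"]
    if "avx512f" in cpuinfo and not avx512_disabled:
        flags += ["-mavx512f", "-mavx512vl"]
        # avx512_bf16 is added on top, but only with gcc/g++ >= 12.3
        # (the compiler-version check is done in CMake, not here).
        if "avx512_bf16" in cpuinfo or os.environ.get("VLLM_CPU_AVX512BF16") == "1":
            flags.append("-mavx512bf16")
    elif "avx2" in cpuinfo:
        flags.append("-mavx2")
    elif "POWER9" in cpuinfo or "POWER10" in cpuinfo:
        flags += ["-mvsx", "-mcpu=native", "-mtune=native"]
    else:
        raise RuntimeError(
            "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.")
    return flags


if __name__ == "__main__":
    try:
        with open("/proc/cpuinfo") as f:
            print(pick_cpu_compile_flags(f.read()))
    except OSError:
        print("no /proc/cpuinfo on this platform")
```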
# - + # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list, + # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling + # "rocm_agent_enumerator" in "enable_language(HIP)" + # (in file Modules/CMakeDetermineHIPCompiler.cmake) + # + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH}) + else() + set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES}) + endif() # # Find the intersection of the supported + detected architectures to # set the module architecture flags. # set(${GPU_ARCHES}) - foreach (_ARCH ${CMAKE_HIP_ARCHITECTURES}) + foreach (_ARCH ${HIP_ARCHITECTURES}) if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST) list(APPEND ${GPU_ARCHES} ${_ARCH}) endif() @@ -164,7 +171,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) if(NOT ${GPU_ARCHES}) message(FATAL_ERROR - "None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is" + "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") endif() @@ -294,6 +301,7 @@ endmacro() # INCLUDE_DIRECTORIES - Extra include directories. # LIBRARIES - Extra link libraries. # WITH_SOABI - Generate library with python SOABI suffix name. +# USE_SABI - Use python stable api # # Note: optimization level/debug info is set via cmake build type. # @@ -301,7 +309,7 @@ function (define_gpu_extension_target GPU_MOD_NAME) cmake_parse_arguments(PARSE_ARGV 1 GPU "WITH_SOABI" - "DESTINATION;LANGUAGE" + "DESTINATION;LANGUAGE;USE_SABI" "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") # Add hipify preprocessing step when building with HIP/ROCm. @@ -315,7 +323,11 @@ function (define_gpu_extension_target GPU_MOD_NAME) set(GPU_WITH_SOABI) endif() - Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI}) + if (GPU_USE_SABI) + Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}") + else() + Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}") + endif() if (GPU_LANGUAGE STREQUAL "HIP") # Make this target dependent on the hipify preprocessor step. diff --git a/collect_env.py b/collect_env.py index 1ecfeb8e22e2f..083cb768f5399 100644 --- a/collect_env.py +++ b/collect_env.py @@ -64,6 +64,7 @@ "triton", "optree", "nccl", + "transformers", } DEFAULT_PIP_PATTERNS = { @@ -75,6 +76,7 @@ "optree", "onnx", "nccl", + "transformers", } @@ -601,6 +603,11 @@ def get_version_or_na(cfg, prefix): {conda_packages} """.strip() +# both the above code and the following code use `strip()` to +# remove leading/trailing whitespaces, so we need to add a newline +# in between to separate the two sections +env_info_fmt += "\n" + env_info_fmt += """ ROCM Version: {rocm_version} Neuron SDK Version: {neuron_sdk_version} diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 24d972702c858..5ed1dc3b8f792 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include @@ -10,11 +10,11 @@ namespace vllm { // Activation and gating kernel template. 
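The cmake/utils.cmake hunk above lets PYTORCH_ROCM_ARCH override the architectures that rocm_agent_enumerator put into CMAKE_HIP_ARCHITECTURES, then intersects the result with the supported list. A hedged Python sketch of that resolution; the function name and list handling are ours:

```python
import os
from typing import List


def resolve_hip_archs(detected: List[str], supported: List[str]) -> List[str]:
    """Prefer the PYTORCH_ROCM_ARCH env list over the detected
    CMAKE_HIP_ARCHITECTURES, then keep only supported entries."""
    env = os.environ.get("PYTORCH_ROCM_ARCH")
    # CMake lists are semicolon separated; accept ';' or whitespace here.
    candidates = env.replace(";", " ").split() if env else list(detected)
    archs = [a for a in candidates if a in supported]
    if not archs:
        raise RuntimeError(
            f"None of the detected ROCm architectures: {candidates} is "
            f"supported. Supported ROCm architectures are: {supported}.")
    return archs


if __name__ == "__main__":
    print(resolve_hip_archs(["gfx90a", "gfx1100"], ["gfx90a", "gfx942"]))
```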
-template +template __global__ void act_and_mul_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., 2, d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); @@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel( } } -template +template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); + return (T)(((float)x) / (1.0f + expf((float)-x))); } -template +template __device__ __forceinline__ T gelu_kernel(const T& x) { // Equivalent to PyTorch GELU with 'none' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 - const float f = (float) x; + const float f = (float)x; constexpr float ALPHA = M_SQRT1_2; - return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); + return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); } -template +template __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { // Equivalent to PyTorch GELU with 'tanh' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 - const float f = (float) x; + const float f = (float)x; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float KAPPA = 0.044715; float x_cube = f * f * f; float inner = BETA * (f + KAPPA * x_cube); - return (T) (0.5f * f * (1.0f + ::tanhf(inner))); + return (T)(0.5f * f * (1.0f + ::tanhf(inner))); } -} // namespace vllm +} // namespace vllm // Launch activation and gating kernel. 
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "act_and_mul_kernel", \ - [&] { \ - vllm::act_and_mul_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); - -void silu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + +void silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } -void gelu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); } -void gelu_tanh_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); } @@ -96,11 +90,11 @@ void gelu_tanh_and_mul( namespace vllm { // Element-wise activation kernel template. -template +template __global__ void activation_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); @@ -108,54 +102,61 @@ __global__ void activation_kernel( } } -} // namespace vllm +} // namespace vllm // Launch element-wise activation kernel. 
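For reference, the *_and_mul entry points above implement a gated activation: the input packs two d-sized halves along its last dimension and the output is act(first_half) * second_half. A small NumPy sketch of that contract; the constants follow the kernel comments, while the helper itself is ours:

```python
import math

import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def gelu_exact(x):
    # PyTorch GELU with approximate='none' (erf form).
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x * math.sqrt(0.5)))


def gelu_tanh(x):
    # PyTorch GELU with approximate='tanh'; BETA/KAPPA as in the kernel.
    beta = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + np.tanh(beta * (x + 0.044715 * x**3)))


def act_and_mul(x: np.ndarray, act) -> np.ndarray:
    """Reference for {silu,gelu,gelu_tanh}_and_mul: x has shape [..., 2*d];
    the first half is activated and multiplied by the second half."""
    d = x.shape[-1] // 2
    return act(x[..., :d]) * x[..., d:]


if __name__ == "__main__":
    x = np.random.randn(4, 16).astype(np.float32)
    print(act_and_mul(x, silu).shape)  # (4, 8)
```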
-#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / d; \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "activation_kernel", \ - [&] { \ - vllm::activation_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + vllm::activation_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); namespace vllm { -template +template __device__ __forceinline__ T gelu_new_kernel(const T& x) { - const float x3 = (float) (x * x * x); - const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); - return ((T) 0.5) * x * (((T) 1.0) + t); + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); } -template +template __device__ __forceinline__ T gelu_fast_kernel(const T& x) { - const float f = (float) x; - const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); - return ((T) 0.5) * x * (((T) 1.0) + t); + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); } -} // namespace vllm +template +__device__ __forceinline__ T gelu_quick_kernel(const T& x) { + // x * sigmoid(1.702 * x) + return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x))); +} + +} // namespace vllm -void gelu_new( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_new(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_fast(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } + +void gelu_quick(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] +{ + LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel); +} diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh index 31fb401cbe2c1..62409c0cce93e 100644 --- a/csrc/attention/attention_generic.cuh +++ b/csrc/attention/attention_generic.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -22,31 +23,31 @@ namespace vllm { // A vector type to store Q, K, V elements. -template +template struct Vec {}; // A vector type to store FP32 accumulators. -template +template struct FloatVec {}; // Template vector operations. 
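gelu_quick, added here next to gelu_new and gelu_fast, is the sigmoid approximation x * sigmoid(1.702 * x). A quick numeric comparison against exact GELU, purely illustrative:

```python
import math


def gelu_quick(x: float) -> float:
    # x * sigmoid(1.702 * x), as in the CUDA kernel above.
    return x / (1.0 + math.exp(-1.702 * x))


def gelu_exact(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


if __name__ == "__main__":
    for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
        print(f"x={x:+.1f}  quick={gelu_quick(x):+.4f}  "
              f"exact={gelu_exact(x):+.4f}")
```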
-template +template inline __device__ Acc mul(A a, B b); -template +template inline __device__ float sum(T v); -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ void zero(T& dst) { constexpr int WORDS = sizeof(T) / 4; union { @@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { dst = tmp.raw; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8b1b5e098015f..91083481705cb 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -16,30 +17,26 @@ * limitations under the License. */ -#include +#include #include #include +#include #include "attention_dtypes.h" #include "attention_utils.cuh" -#if defined(ENABLE_FP8_E5M2) -#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "../quantization/fp8/amd_detail/quant_utils.cuh" -#endif - -#include - #ifdef USE_ROCM #include - typedef __hip_bfloat16 __nv_bfloat16; + #include "../quantization/fp8/amd/quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif #ifndef USE_ROCM -#define WARP_SIZE 32 + #define WARP_SIZE 32 #else -#define WARP_SIZE warpSize + #define WARP_SIZE warpSize #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -49,7 +46,7 @@ namespace vllm { // Utility function for attention softmax. -template +template inline __device__ float block_sum(float* red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; @@ -86,31 +83,31 @@ inline __device__ float block_sum(float* red_smem, float sum) { // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - bool IS_FP8_KV_CACHE, - int PARTITION_SIZE = 0> // Zero means no partitioning. +template // Zero means no partitioning. 
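The kernel declared above resolves every token through a per-sequence block table before touching the KV cache. A NumPy sketch of that indirection for the key-cache layout noted in the signature comments ([num_blocks, num_kv_heads, head_size/x, block_size, x]); the helper name and the demo shapes are assumptions:

```python
import numpy as np


def gather_keys(k_cache: np.ndarray, block_table: np.ndarray,
                kv_head: int, seq_len: int) -> np.ndarray:
    """Return the [seq_len, head_size] keys for one kv head by following
    block_table: logical block i -> physical block block_table[i]."""
    num_blocks, num_kv_heads, head_size_div_x, block_size, x = k_cache.shape
    head_size = head_size_div_x * x
    keys = np.empty((seq_len, head_size), dtype=k_cache.dtype)
    for token_idx in range(seq_len):
        logical_block = token_idx // block_size
        block_offset = token_idx % block_size
        physical_block = int(block_table[logical_block])
        # [head_size/x, x] slice for this token, flattened back to head_size
        # (element h lives at row h // x, column h % x, as in the kernel).
        keys[token_idx] = k_cache[physical_block, kv_head, :,
                                  block_offset, :].reshape(head_size)
    return keys


if __name__ == "__main__":
    k_cache = np.random.randn(32, 4, 16, 16, 8).astype(np.float32)
    block_table = np.random.randint(0, 32, size=(8,))
    print(gather_keys(k_cache, block_table, kv_head=1, seq_len=100).shape)
```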
__device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -122,22 +119,29 @@ __device__ void paged_attention_kernel( } const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. 
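In the v2 path each (head, sequence, partition) grid slot covers PARTITION_SIZE tokens, and the arithmetic above (continued just below) clamps the last partition to the sequence length. A small sketch of that bookkeeping, with no kernel semantics attached:

```python
def divide_round_up(a: int, b: int) -> int:
    return (a + b - 1) // b


def partition_ranges(seq_len: int, block_size: int, partition_size: int,
                     partition_idx: int):
    """Return (start_block, end_block, start_token, end_token) handled by
    one partition, mirroring the index math in paged_attention_kernel."""
    use_partitioning = partition_size > 0
    num_seq_blocks = divide_round_up(seq_len, block_size)
    blocks_per_partition = (partition_size // block_size
                            if use_partitioning else num_seq_blocks)
    start_block = (partition_idx * blocks_per_partition
                   if use_partitioning else 0)
    end_block = min(start_block + blocks_per_partition, num_seq_blocks)
    start_token = start_block * block_size
    end_token = min(start_token + (end_block - start_block) * block_size,
                    seq_len)
    return start_block, end_block, start_token, end_token


if __name__ == "__main__":
    # seq_len=1000, BLOCK_SIZE=16, PARTITION_SIZE=512 -> two partitions.
    for p in range(2):
        print(p, partition_ranges(1000, 16, 512, p))
```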
const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -147,19 +151,18 @@ __device__ void paged_attention_kernel( const int num_heads = gridDim.x; const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; - const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or compute 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using Quant_vec = typename Vec::Type; -#endif constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; @@ -169,18 +172,21 @@ __device__ void paged_attention_kernel( // Load the query to registers. // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. 
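As the comment above notes, the vector width is picked so that one thread group moves 16 bytes per fetch. A sketch of those constexpr values for a common configuration (WARP_SIZE is 32 on NVIDIA; on ROCm it is warpSize):

```python
def attention_constants(warp_size: int, block_size: int,
                        scalar_bytes: int, head_size: int):
    """Mirror the constexpr math at the top of paged_attention_kernel."""
    thread_group_size = max(warp_size // block_size, 1)
    # Each thread group fetches 16 bytes of the key/query at a time.
    vec_size = max(16 // (thread_group_size * scalar_bytes), 1)
    elems_per_thread = head_size // thread_group_size
    vecs_per_thread = elems_per_thread // vec_size
    return {
        "THREAD_GROUP_SIZE": thread_group_size,
        "VEC_SIZE": vec_size,
        "NUM_ELEMS_PER_THREAD": elems_per_thread,
        "NUM_VECS_PER_THREAD": vecs_per_thread,
    }


if __name__ == "__main__":
    # half precision (2 bytes), BLOCK_SIZE=16, HEAD_SIZE=128 on NVIDIA:
    print(attention_constants(32, 16, 2, 128))
    # -> THREAD_GROUP_SIZE=2, VEC_SIZE=4, 64 elems and 16 vecs per thread
```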
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; #pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs // Memory planning. extern __shared__ char shared_mem[]; @@ -199,51 +205,94 @@ __device__ void paged_attention_kernel( // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. + logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); // Load a key to registers. // Each thread in a thread group has a different part of the key. 
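The blocksparse branch added above keeps a KV block if it is either local (within blocksparse_local_blocks of the query block) or lands on a per-head vertical stride; everything else is skipped, with its logits pinned to -FLT_MAX so it contributes nothing to the softmax normalizer. A Python sketch of that selection rule; parameter names follow the kernel, the function itself is ours:

```python
def blocksparse_keep_block(block_idx: int, seq_len: int, block_size: int,
                           blocksparse_block_size: int,
                           blocksparse_local_blocks: int,
                           blocksparse_vert_stride: int,
                           blocksparse_head_sliding_step: int,
                           tp_rank: int, num_heads: int, num_kv_heads: int,
                           head_idx: int, kv_head_idx: int) -> bool:
    """Return True if this KV block is attended, following the kernel above."""
    q_bs_block_id = (seq_len - 1) // blocksparse_block_size
    if blocksparse_head_sliding_step >= 0:  # sliding on q heads
        offset = ((tp_rank * num_heads + head_idx)
                  * blocksparse_head_sliding_step + 1)
    else:  # sliding on kv heads
        offset = ((tp_rank * num_kv_heads + kv_head_idx)
                  * (-blocksparse_head_sliding_step) + 1)
    k_bs_block_id = block_idx * block_size // blocksparse_block_size
    is_remote = (k_bs_block_id + offset) % blocksparse_vert_stride == 0
    is_local = k_bs_block_id > q_bs_block_id - blocksparse_local_blocks
    return is_remote or is_local


if __name__ == "__main__":
    kept = [b for b in range(64)
            if blocksparse_keep_block(b, seq_len=1024, block_size=16,
                                      blocksparse_block_size=64,
                                      blocksparse_local_blocks=4,
                                      blocksparse_vert_stride=8,
                                      blocksparse_head_sliding_step=1,
                                      tp_rank=0, num_heads=32,
                                      num_kv_heads=8, head_idx=0,
                                      kv_head_idx=0)]
    print(kept)
```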
- // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; K_vec k_vecs[NUM_VECS_PER_THREAD]; #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. - k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); -#elif defined(ENABLE_FP8_E4M3) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k - // cache vec to k vec in higher precision (FP16, BFloat16, etc.) - k_vecs[j] = fp8_e4m3::scaled_vec_conversion(k_vec_quant, kv_scale); -#else - assert(false); -#endif + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); } else { - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, kv_scale); } } // Compute dot product. // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; @@ -298,13 +347,12 @@ __device__ void paged_attention_kernel( // If partitioning is enabled, store the max logit and exp_sum. 
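Each partition records its own max logit and exp-sum so the v2 reduce kernel can later merge partitions without revisiting the KV cache. In NumPy, the per-partition statistics amount to:

```python
import numpy as np


def partition_softmax_stats(logits: np.ndarray):
    """Return (max_logit, exp_sum, weights) for one partition's logits,
    computed in the numerically stable form the kernel uses. Tokens that
    blocksparse skipping pinned to -FLT_MAX simply end up with ~0 weight."""
    max_logit = float(np.max(logits))
    exp_logits = np.exp(logits - max_logit)
    exp_sum = float(np.sum(exp_logits))
    return max_logit, exp_sum, exp_logits / exp_sum


if __name__ == "__main__":
    m, s, w = partition_softmax_stats(np.array([1.0, 2.0, 3.0]))
    print(m, s, w.sum())  # 3.0, ~1.5032, 1.0
```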
if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *exp_sums_ptr = exp_sum; } @@ -312,14 +360,13 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using V_quant_vec = typename Vec::Type; -#endif using Float_L_vec = typename FloatVec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -330,44 +377,51 @@ __device__ void paged_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). 
+ // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); - const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); - // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); -#elif defined(ENABLE_FP8_E4M3) - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); - // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert - // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.) - v_vec = fp8_e4m3::scaled_vec_conversion(v_quant_vec, kv_scale); -#else - assert(false); -#endif - } else { + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { v_vec = *reinterpret_cast(v_ptr + offset); + } else { + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8::scaled_convert(v_quant_vec, + kv_scale); } if (block_idx == num_seq_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the context, - // we should explicitly zero out the values since they may contain NaNs. - // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { @@ -390,8 +444,8 @@ __device__ void paged_attention_kernel( accs[i] = acc; } - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. __syncthreads(); // Perform reduction across warps. @@ -428,9 +482,9 @@ __device__ void paged_attention_kernel( // Write the final output. 
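Both the key and value paths above now go through fp8::scaled_convert when the cache is quantized. The sketch below only emulates the scale-then-narrow round trip, with float16 standing in for fp8 and assuming the usual divide-on-store / multiply-on-load convention; the real conversion lives in the csrc/quantization/fp8 headers:

```python
import numpy as np


def quantize_kv(x: np.ndarray, kv_scale: float) -> np.ndarray:
    """Store x / kv_scale in a narrow type (float16 standing in for fp8)."""
    return (x / kv_scale).astype(np.float16)


def dequantize_kv(q: np.ndarray, kv_scale: float) -> np.ndarray:
    """Mirror of scaled conversion on load: widen, then multiply by scale."""
    return q.astype(np.float32) * kv_scale


if __name__ == "__main__":
    x = np.random.randn(8).astype(np.float32) * 3.0
    kv_scale = 0.05
    err = np.abs(dequantize_kv(quantize_kv(x, kv_scale), kv_scale) - x)
    print("max abs round-trip error:", err.max())
```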
if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -442,79 +496,84 @@ __device__ void paged_attention_kernel( } // Grid: (num_heads, num_seqs, 1). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - bool IS_FP8_KV_CACHE> +template __global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs, max_num_partitions). 
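paged_attention_v1 above is the unpartitioned case: one thread block produces the whole output for a (head, sequence) pair. As a mental model, here is a compact NumPy reference of what it computes for a single query head; this helper is ours, not the repo's test reference:

```python
import numpy as np


def ref_single_query_attention(q, keys, values, scale, alibi_slope=0.0):
    """q: [head_size]; keys/values: [seq_len, head_size].
    Returns the attention output [head_size] for one decode step."""
    seq_len = keys.shape[0]
    logits = scale * (keys @ q)                       # [seq_len]
    if alibi_slope != 0.0:
        positions = np.arange(seq_len)
        logits = logits + alibi_slope * (positions - seq_len + 1)
    logits -= logits.max()                            # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum()
    return probs @ values                             # [head_size]


if __name__ == "__main__":
    head_size, seq_len = 128, 100
    q = np.random.randn(head_size).astype(np.float32)
    k = np.random.randn(seq_len, head_size).astype(np.float32)
    v = np.random.randn(seq_len, head_size).astype(np.float32)
    print(ref_single_query_attention(q, k, v, scale=head_size ** -0.5).shape)
```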
-template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - bool IS_FP8_KV_CACHE, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride, kv_scale); + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, kv_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs). 
-template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_partitions) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; @@ -522,9 +581,11 @@ __global__ void paged_attention_v2_reduce_kernel( const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { out_ptr[i] = tmp_out_ptr[i]; } @@ -543,8 +604,9 @@ __global__ void paged_attention_v2_reduce_kernel( // Load max logits to shared memory. float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const float l = max_logits_ptr[i]; @@ -573,9 +635,11 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. - float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float l = shared_max_logits[i]; @@ -588,61 +652,52 @@ __global__ void paged_attention_v2_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out. 
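The reduce kernel above (its aggregation loop follows just below) rescales each partition's exp-sum by exp(max_i - global_max) and uses the result to weight the partial outputs. The whole reduction in plain NumPy, with a sanity check against a single-pass softmax:

```python
import numpy as np


def combine_partitions(max_logits, exp_sums, partial_outs):
    """Combine per-partition results the way paged_attention_v2_reduce_kernel
    does. max_logits: [P], exp_sums: [P], partial_outs: [P, head_size], where
    partial_outs[p] was normalized with partition p's own exp-sum."""
    max_logits = np.asarray(max_logits, dtype=np.float64)
    exp_sums = np.asarray(exp_sums, dtype=np.float64)
    global_max = max_logits.max()
    rescaled = exp_sums * np.exp(max_logits - global_max)  # shared_exp_sums
    weights = rescaled / (rescaled.sum() + 1e-6)
    return (weights[:, None] * np.asarray(partial_outs)).sum(axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    logits = [rng.normal(size=16), rng.normal(size=16)]
    values = [rng.normal(size=(16, 4)), rng.normal(size=(16, 4))]
    parts = []
    for l, v in zip(logits, values):
        m = l.max()
        s = np.exp(l - m).sum()
        parts.append((m, s, (np.exp(l - m) / s) @ v))
    out = combine_partitions([p[0] for p in parts], [p[1] for p in parts],
                             [p[2] for p in parts])
    all_l = np.concatenate(logits)
    all_v = np.concatenate(values)
    e = np.exp(all_l - all_l.max())
    ref = (e / e.sum()) @ all_v
    print(np.allclose(out, ref, atol=1e-6))  # True (up to the 1e-6 epsilon)
```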
- const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { float acc = 0.0f; for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; } from_float(out_ptr[i], acc); } } -} // namespace vllm - -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ - vllm::paged_attention_v1_kernel<<>>( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); +} // namespace vllm + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + kv_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. -template< - typename T, - typename CACHE_T, - int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, - int NUM_THREADS = 128> +template void paged_attention_v1_launcher( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int max_seq_len, - const c10::optional& alibi_slopes, - float kv_scale) { + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float kv_scale, + const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -655,9 +710,10 @@ void paged_attention_v1_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -667,7 +723,8 @@ void paged_attention_v1_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_seq_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len @@ -697,6 +754,9 @@ void paged_attention_v1_launcher( case 128: LAUNCH_PAGED_ATTENTION_V1(128); break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; case 256: LAUNCH_PAGED_ATTENTION_V1(256); break; @@ -706,128 +766,94 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ - kv_scale); +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) } -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - 
value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - seq_lens_ptr, \ - max_num_partitions); - -template< - typename T, - typename CACHE_T, - int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, - int NUM_THREADS = 128, - int PARTITION_SIZE = 512> +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template void paged_attention_v2_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int max_seq_len, - const c10::optional& alibi_slopes, - float kv_scale) { + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float kv_scale, + const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -840,9 +866,10 @@ void paged_attention_v2_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -888,6 +915,9 @@ void paged_attention_v2_launcher( case 128: LAUNCH_PAGED_ATTENTION_V2(128); break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; case 256: LAUNCH_PAGED_ATTENTION_V2(256); break; @@ -897,81 +927,66 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ - kv_scale); +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, 
__nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) } #undef WARP_SIZE diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index ff64c4bd8f80c..cdcee42748998 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -26,7 +27,7 @@ namespace vllm { // Q*K^T operation. -template +template inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { using A_vec = typename FloatVec::Type; // Compute the parallel products for Q*K^T (treat vector lanes separately). 
@@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { return qk; } -template +template struct Qk_dot { - template + template static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { return qk_dot_(q, k); } }; -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 31e0cee01d2e1..3cdcb95e08099 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -28,8 +30,8 @@ #include #include - typedef __hip_bfloat162 __nv_bfloat162; - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; #endif #include @@ -50,37 +52,37 @@ struct bf16_8_t { }; // BF16 vector types for Q, K, V. -template<> +template <> struct Vec<__nv_bfloat16, 1> { using Type = __nv_bfloat16; }; -template<> +template <> struct Vec<__nv_bfloat16, 2> { using Type = __nv_bfloat162; }; -template<> +template <> struct Vec<__nv_bfloat16, 4> { using Type = bf16_4_t; }; -template<> +template <> struct Vec<__nv_bfloat16, 8> { using Type = bf16_8_t; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec<__nv_bfloat16> { using Type = float; }; -template<> +template <> struct FloatVec<__nv_bfloat162> { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { assert(false); #else #ifndef USE_ROCM - return a + b; + return a + b; #else - return __hadd(a, b); + return __hadd(a, b); #endif #endif } @@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { } // Vector multiplication. 
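For orientation, a minimal usage sketch, not part of the patch: the Vec/FloatVec traits above pick a packed vector type and its fp32 accumulator, and the first template argument of the mul<> specializations that follow selects the result type, so packed bf16 operands can be widened to fp32 in one call. The helper name is hypothetical and assumes these dtype headers are included.
// Hypothetical helper, sketch only.
__device__ float2 widen_mul_bf16(__nv_bfloat162 a, __nv_bfloat162 b) {
  using Acc = typename vllm::FloatVec<__nv_bfloat162>::Type;    // float2, per the trait above
  return vllm::mul<Acc, __nv_bfloat162, __nv_bfloat162>(a, b);  // per-lane product, widened to fp32
}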
-template<> +template <> inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #endif } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #endif } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); } -template<> +template <> inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { bf16_4_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_4_t c; @@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { bf16_8_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_8_t c; @@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { float fa = __bfloat162float(a); float fb = __bfloat162float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { float2 fa = bf1622float2(a); float2 fb = bf1622float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul(bf162bf162(a), b); } -template<> +template <> inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); Float4_ fc; @@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); Float8_ fc; @@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { } // Vector fused multiply-add. 
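A brief note before the fused multiply-add block: the fma() overloads that follow compute a*b + c lane by lane and are the accumulate primitive these dtype headers expose; like the bf16 mul() above, the bf16 variants are guarded to sm_80+. A hypothetical one-liner, sketch only:
// Sketch: lane-wise a*x + y on packed bf16 (falls into the sm_80+ guarded path below).
__device__ __nv_bfloat162 axpy_bf16(__nv_bfloat162 a, __nv_bfloat162 x, __nv_bfloat162 y) {
  return vllm::fma(a, x, y);
}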
-inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf #endif } -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(__nv_bfloat16 v) { return __bfloat162float(v); } -template<> +template <> inline __device__ float sum(__nv_bfloat162 v) { float2 vf = bf1622float2(v); return vf.x + vf.y; } -template<> +template <> inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); } -template<> +template <> inline __device__ float sum(bf16_8_t v) { return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); } @@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { #endif } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index d3271e69cd69d..3a1815f0ed4fc 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -30,37 +32,37 @@ namespace vllm { // FP16 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = uint16_t; }; -template<> +template <> struct Vec { using Type = uint32_t; }; -template<> +template <> struct Vec { using Type = uint2; }; -template<> +template <> struct Vec { using Type = uint4; }; // FP32 accumulator vector types corresponding to Vec. 
-template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { return b; #else union { - uint32_t u32; - uint16_t u16[2]; + uint32_t u32; + uint16_t u16[2]; } tmp; tmp.u16[0] = a; tmp.u16[1] = a; @@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { } tmp; #ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif #else tmp.u16[0] = float_to_half(f.x); @@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { uint16_t c; #ifndef USE_ROCM @@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { uint32_t c; #ifndef USE_ROCM @@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ uint2 mul(uint2 a, uint2 b) { uint2 c; c.x = mul(a.x, b.x); @@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { return c; } -template<> +template <> inline __device__ uint2 mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); uint2 c; @@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint4 a, uint4 b) { uint4 c; c.x = mul(a.x, b.x); @@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); uint4 c; @@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { return c; } -template<> +template <> inline __device__ float mul(uint16_t a, uint16_t b) { float fa = half_to_float(a); float fb = half_to_float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(uint32_t a, uint32_t b) { float2 fa = half2_to_float2(a); float2 fb = half2_to_float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ Float4_ mul(uint2 a, uint2 b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); Float4_ fc; @@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint4 a, uint4 b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { return fc; } -template<> +template <> inline 
__device__ Float8_ mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); Float8_ fc; @@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; #ifndef USE_ROCM - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); #else - asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" + : "=v"(d) + : "v"(a), "v"(b), "v"(c)); #endif return d; } @@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(uint16_t v) { return half_to_float(v); } -template<> +template <> inline __device__ float sum(uint32_t v) { float2 tmp = half2_to_float2(v); return tmp.x + tmp.y; } -template<> +template <> inline __device__ float sum(uint2 v) { uint32_t c = add(v.x, v.y); return sum(c); } -template<> +template <> inline __device__ float sum(uint4 v) { uint32_t c = add(v.x, v.y); c = add(c, v.z); @@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { } // From float16 to float32. -inline __device__ float to_float(uint16_t u) { - return half_to_float(u); -} +inline __device__ float to_float(uint16_t u) { return half_to_float(u); } -inline __device__ float2 to_float(uint32_t u) { - return half2_to_float2(u); -} +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } inline __device__ Float4_ to_float(uint2 u) { Float4_ tmp; @@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { } // Zero-out a variable. -inline __device__ void zero(uint16_t& dst) { - dst = uint16_t(0); -} +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index b200d2d226eb0..7c6a686db3ba9 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -38,37 +40,35 @@ struct Float8_ { }; // FP32 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = float; }; -template<> +template <> struct Vec { using Type = float2; }; -template<> +template <> struct Vec { using Type = float4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = float4; }; // Vector addition. 
-inline __device__ float add(float a, float b) { - return a + b; -} +inline __device__ float add(float a, float b) { return a + b; } inline __device__ float2 add(float2 a, float2 b) { float2 c; @@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { } // Vector multiplication. -template<> +template <> inline __device__ float mul(float a, float b) { return a * b; } -template<> +template <> inline __device__ float2 mul(float2 a, float2 b) { float2 c; c.x = a.x * b.x; @@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { return c; } -template<> +template <> inline __device__ float2 mul(float a, float2 b) { float2 c; c.x = a * b.x; @@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { return c; } -template<> +template <> inline __device__ float4 mul(float4 a, float4 b) { float4 c; c.x = a.x * b.x; @@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { return c; } -template<> +template <> inline __device__ float4 mul(float a, float4 b) { float4 c; c.x = a * b.x; @@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { } // Vector fused multiply-add. -inline __device__ float fma(float a, float b, float c) { - return a * b + c; -} +inline __device__ float fma(float a, float b, float c) { return a * b + c; } inline __device__ float2 fma(float2 a, float2 b, float2 c) { float2 d; @@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { } // Vector sum. -template<> +template <> inline __device__ float sum(float v) { return v; } -template<> +template <> inline __device__ float sum(float2 v) { return v.x + v.y; } -template<> +template <> inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -template<> +template <> inline __device__ float sum(Float4_ v) { return v.x.x + v.x.y + v.y.x + v.y.y; } -template<> +template <> inline __device__ float sum(Float8_ v) { return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; } // Vector dot product. -inline __device__ float dot(float a, float b) { - return a * b; -} +inline __device__ float dot(float a, float b) { return a * b; } inline __device__ float dot(float2 a, float2 b) { float2 c = mul(a, b); @@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { } // From float to float. -inline __device__ void from_float(float& dst, float src) { - dst = src; -} +inline __device__ void from_float(float& dst, float src) { dst = src; } -inline __device__ void from_float(float2& dst, float2 src) { - dst = src; -} +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } -inline __device__ void from_float(float4& dst, float4 src) { - dst = src; -} +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } // From float to float. -inline __device__ float to_float(float u) { - return u; -} +inline __device__ float to_float(float u) { return u; } -inline __device__ float2 to_float(float2 u) { - return u; -} +inline __device__ float2 to_float(float2 u) { return u; } -inline __device__ float4 to_float(float4 u) { - return u; -} +inline __device__ float4 to_float(float4 u) { return u; } -inline __device__ Float4_ to_float(Float4_ u) { - return u; -} +inline __device__ Float4_ to_float(Float4_ u) { return u; } -inline __device__ Float8_ to_float(Float8_ u) { - return u; -} +inline __device__ Float8_ to_float(Float8_ u) { return u; } // Zero-out a variable. 
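For readers skimming the float32 helpers: dot() is simply sum() composed with mul(), which mirrors how qk_dot_ in attention_utils.cuh treats the vector lanes when accumulating Q*K^T products in fp32. A standalone sketch (hypothetical name, assumes this header is included):
// Sketch only: dot(a, b) == sum(mul(a, b)), spelled out for float2.
__device__ float dot2_sketch(float2 a, float2 b) {
  float2 c = vllm::mul<float2, float2, float2>(a, b);  // per-lane product
  return vllm::sum(c);                                 // a.x*b.x + a.y*b.y
}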
-inline __device__ void zero(float& dst) { - dst = 0.f; -} +inline __device__ void zero(float& dst) { dst = 0.f; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index d11dee91ebe87..e714e321b0beb 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -3,33 +3,39 @@ #include "attention_generic.cuh" #include -#ifdef ENABLE_FP8_E5M2 -#include -#endif +#ifdef ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 namespace vllm { -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) -// fp8 vector types for quantization of kv cache -template<> +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache +template <> struct Vec { - using Type = uint8_t; + using Type = uint8_t; }; -template<> +template <> struct Vec { - using Type = uint16_t; + using Type = uint16_t; }; -template<> +template <> struct Vec { - using Type = uint32_t; + using Type = uint32_t; }; -template<> +template <> struct Vec { - using Type = uint2; + using Type = uint2; }; -#endif // ENABLE_FP8_E5M2 -} // namespace vllm +} // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index 10871b3670bac..86caa9345361d 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,38 +1,32 @@ #pragma once -#include +#include #include #include -void swap_blocks( - torch::Tensor& src, - torch::Tensor& dst, - const std::map& block_mapping); +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping); -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - torch::Tensor& block_mapping); +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. 
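One practical consequence of the cache.h changes above: swap_blocks no longer takes a std::map but a [num_pairs, 2] int64 tensor in which each row is {src_block, dst_block}, and (per the note in cache_kernels.cu below) that tensor is expected to live on the CPU. An illustrative caller-side sketch with made-up block numbers and hypothetical cache names:
// Sketch only: build the new-style block_mapping argument and call swap_blocks.
torch::Tensor block_mapping =
    torch::tensor({0, 4, 1, 5}, torch::dtype(torch::kInt64)).reshape({2, 2});
swap_blocks(src_cache, dst_cache, block_mapping);  // src_cache/dst_cache: hypothetical caches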
+void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& block_mapping); -void reshape_and_cache( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - const float kv_scale); +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + const double kv_scale); -void reshape_and_cache_flash( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype); +void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); // Just for unittest -void convert_fp8( - torch::Tensor& src_cache, - torch::Tensor& dst_cache); +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const double scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1e02f7fcbae4c..72041076ae009 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,13 +1,14 @@ -#include +#include #include #include #include "cuda_compat.h" #include "dispatch_utils.h" -#if defined(ENABLE_FP8_E5M2) -#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "quantization/fp8/amd_detail/quant_utils.cuh" + +#ifdef USE_ROCM + #include "quantization/fp8/amd/quant_utils.cuh" +#else + #include "quantization/fp8/nvidia/quant_utils.cuh" #endif #include @@ -17,20 +18,17 @@ #ifdef USE_ROCM #include - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16 __nv_bfloat16; #endif -void swap_blocks( - torch::Tensor& src, - torch::Tensor& dst, - const std::map& block_mapping) { +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; if (src_device.is_cuda() && dst_device.is_cuda()) { - TORCH_CHECK( - src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); memcpy_type = cudaMemcpyDeviceToDevice; } else if (src_device.is_cuda() && dst_device.is_cpu()) { memcpy_type = cudaMemcpyDeviceToHost; @@ -40,41 +38,44 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - char *src_ptr = static_cast(src.data_ptr()); - char *dst_ptr = static_cast(dst.data_ptr()); + // NOTE(youkaichao): keep in mind that `block_mapping` should be + // a cpu tensor, otherwise every `item` call will require a gpu-cpu + // synchronization. + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + char* src_ptr = static_cast(src.data_ptr()); + char* dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); - const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); + const at::cuda::OptionalCUDAGuard device_guard( + src_device.is_cuda() ? src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. 
- for (const auto& pair : block_mapping) { - int64_t src_block_number = pair.first; - int64_t dst_block_number = pair.second; + const int64_t num_blocks = block_mapping.size(0); + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); int64_t src_offset = src_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes; - cudaMemcpyAsync( - dst_ptr + dst_offset, - src_ptr + src_offset, - block_size_in_bytes, - memcpy_type, - stream); + cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, + block_size_in_bytes, memcpy_type, stream); } } namespace vllm { // Grid: (num_layers, num_pairs) -template -__global__ void copy_blocks_kernel( - int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { +template +__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, + int64_t* value_cache_ptrs, + const int64_t* __restrict__ block_mapping, + const int numel_per_block) { const int layer_idx = blockIdx.x; const int pair_idx = blockIdx.y; scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); - scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); + scalar_t* value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; @@ -92,12 +93,14 @@ __global__ void copy_blocks_kernel( } } -} // namespace vllm +} // namespace vllm -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - torch::Tensor& block_mapping) { +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { @@ -111,8 +114,10 @@ void copy_blocks( int64_t key_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers]; for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { - key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); - value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); + key_cache_ptrs[layer_idx] = + reinterpret_cast(key_caches[layer_idx].data_ptr()); + value_cache_ptrs[layer_idx] = + reinterpret_cast(value_caches[layer_idx].data_ptr()); } // block_mapping is a 2D tensor with shape (num_pairs, 2). @@ -120,10 +125,12 @@ void copy_blocks( // Move the data structures to the GPU. // NOTE: This synchronizes the CPU and GPU. - torch::Tensor key_cache_ptrs_tensor = torch::from_blob( - key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor value_cache_ptrs_tensor = torch::from_blob( - value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); + torch::Tensor key_cache_ptrs_tensor = + torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + torch::Tensor value_cache_ptrs_tensor = + torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); // Launch the kernel. 
const int numel_per_block = key_caches[0][0].numel(); @@ -132,31 +139,28 @@ void copy_blocks( const at::cuda::OptionalCUDAGuard device_guard(cache_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( - key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { - vllm::copy_blocks_kernel<<>>( - key_cache_ptrs_tensor.data_ptr(), - value_cache_ptrs_tensor.data_ptr(), - block_mapping.data_ptr(), - numel_per_block); - })); + key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { + vllm::copy_blocks_kernel<<>>( + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); + })); } namespace vllm { -template +template __global__ void reshape_and_cache_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - const int x, - const float kv_scale) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, + // block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, + // block_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x, + const float kv_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -177,47 +181,39 @@ __global__ void reshape_and_cache_kernel( const int x_idx = head_offset / x; const int x_offset = head_offset % x; - const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; - if constexpr (is_fp8_kv_cache) { -#if defined(ENABLE_FP8_E5M2) - key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); - value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); -#elif defined(ENABLE_FP8_E4M3) - key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale); - value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale); -#else - assert(false); -#endif - } else { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { key_cache[tgt_key_idx] = tgt_key; value_cache[tgt_value_idx] = tgt_value; + } else { + key_cache[tgt_key_idx] = + 
fp8::scaled_convert(tgt_key, kv_scale); + value_cache[tgt_value_idx] = + fp8::scaled_convert(tgt_value, kv_scale); } } } -template +template __global__ void reshape_and_cache_flash_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] - scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, + // head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, + // head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, const int key_stride, const int value_stride, + const int num_heads, const int head_size, const int block_size) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -232,40 +228,37 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t src_value_idx = token_idx * value_stride + i; const int head_idx = i / head_size; const int head_offset = i % head_size; - const int64_t tgt_value_idx = block_idx * block_stride - + block_offset * num_heads * head_size - + head_idx * head_size - + head_offset; + const int64_t tgt_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; k_cache[tgt_value_idx] = key[src_key_idx]; v_cache[tgt_value_idx] = value[src_value_idx]; } } -} // namespace vllm - -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ - vllm::reshape_and_cache_kernel<<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), \ - key_stride, \ - value_stride, \ - num_heads, \ - head_size, \ - block_size, \ - x, \ - kv_scale); +} // namespace vllm + +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. 
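For orientation: the explicit dtype if/else chains removed in this file are folded into DISPATCH_BY_KV_CACHE_DTYPE, whose definition is not part of the hunks shown here. A rough sketch of what that dispatch amounts to when handed the CALL_RESHAPE_AND_CACHE macro defined next, reconstructed from the removed branches and the Fp8KVCacheDataType enum added in dtype_fp8.cuh (the fp8_e5m2 handling is assumed, not shown in this hunk):
// Sketch only -- not the actual macro expansion.
if (kv_cache_dtype == "auto") {
  if (key.dtype() == at::ScalarType::Half)
    CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
  // ... float and bfloat16 handled the same way ...
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
  if (key.dtype() == at::ScalarType::Half)
    CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
  // ... float and bfloat16 handled the same way; an fp8_e5m2 branch presumably maps to kFp8E5M2 ...
} else {
  TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}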
+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), key_stride, value_stride, \ + num_heads, head_size, block_size, x, kv_scale); void reshape_and_cache( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, - const float kv_scale) -{ + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, const double kv_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -279,35 +272,18 @@ void reshape_and_cache( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (kv_cache_dtype == "auto") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, float, false); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); - } - } else if (kv_cache_dtype == "fp8") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, uint8_t, true); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE) } void reshape_and_cache_flash( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype) -{ + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) { // FIXME: only support auto datatype, does not support fp8 if (kv_cache_dtype != "auto") { TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); @@ -327,63 +303,47 @@ void reshape_and_cache_flash( const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), - 
"reshape_and_cache_flash", - [&] { - vllm::reshape_and_cache_flash_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - k_cache.data_ptr(), - v_cache.data_ptr(), - slot_mapping.data_ptr(), - block_stride, - key_stride, - value_stride, - num_heads, - head_size, - block_size); - }); + key.scalar_type(), "reshape_and_cache_flash", [&] { + vllm::reshape_and_cache_flash_kernel + <<>>( + key.data_ptr(), value.data_ptr(), + k_cache.data_ptr(), v_cache.data_ptr(), + slot_mapping.data_ptr(), block_stride, key_stride, + value_stride, num_heads, head_size, block_size); + }); } namespace vllm { -template -__global__ void convert_fp8_kernel( - const Tin* __restrict__ src_cache, - Tout* __restrict__ dst_cache, - const int64_t block_stride) { +template +__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, + Tout* __restrict__ dst_cache, + const float kv_scale, + const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; -#if defined(ENABLE_FP8_E5M2) - dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); -#elif defined(ENABLE_FP8_E4M3) - dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); -#else - assert(false); -#endif + dst_cache[idx] = + fp8::scaled_convert(src_cache[idx], kv_scale); } } -} // namespace vllm +} // namespace vllm -#define CALL_CONVERT_FP8(Tout, Tin) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), \ - block_stride); +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), kv_scale, block_stride); -void convert_fp8( - torch::Tensor& src_cache, - torch::Tensor& dst_cache) -{ +// Only for testing. 
+void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const double kv_scale, const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") - TORCH_CHECK( - src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); at::cuda::OptionalCUDAGuard device_guard(src_device); int64_t num_blocks = src_cache.size(0); @@ -393,17 +353,37 @@ void convert_fp8( dim3 block(std::min(block_stride, int64_t(512))); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(uint8_t, float); - } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint8_t, uint16_t); - } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); - } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(float, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint16_t, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); + if (kv_cache_dtype == "auto") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } + } else { + TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); } } diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp index 1bd24eb79d129..039b8d5c30d46 100644 --- a/csrc/cpu/activation.cpp +++ b/csrc/cpu/activation.cpp @@ -1,10 +1,10 @@ #include "cpu_types.hpp" namespace { -template -void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, - scalar_t *__restrict__ output) { +void activation_kernel(int num_tokens, int d, 
scalar_t* __restrict__ input, + scalar_t* __restrict__ output) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); @@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, } } -FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 ones(1.0); return x / (ones + (zeros - x).exp()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -59,14 +59,21 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 zeros(0.0); + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(1.702f); + return x / (ones + (zeros - w1 * x).exp()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w2(0.5); return x * w2 * (ones + (x * w1).er()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w2(0.5); @@ -75,40 +82,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); return x * w2 * (ones + inner.tanh()); } -}; // namespace +}; // namespace -void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { +void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "silu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(silu_and_mul_impl) - activation_kernel(num_tokens, d, - input.data_ptr(), - out.data_ptr()); - CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); } -void gelu_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "gelu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) - activation_kernel(num_tokens, d, - input.data_ptr(), - out.data_ptr()); - CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] { 
+ CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); } -void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; @@ -123,7 +126,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] }); } -void gelu_new(torch::Tensor &out, torch::Tensor &input) { +void gelu_new(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); @@ -135,7 +138,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) { }); } -void gelu_fast(torch::Tensor &out, torch::Tensor &input) { +void gelu_fast(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); @@ -146,3 +149,15 @@ void gelu_fast(torch::Tensor &out, torch::Tensor &input) { CPU_KERNEL_GUARD_OUT(gelu_fast_impl) }); } + +void gelu_quick(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_quick_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_quick_impl) + }); +} diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index c1d765be05598..8367093325314 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -2,7 +2,8 @@ namespace { -template struct KernelVecType { +template +struct KernelVecType { using q_load_vec_type = void; using q_vec_type = void; using k_load_vec_type = void; @@ -11,7 +12,8 @@ template struct KernelVecType { using v_load_vec_type = void; }; -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::FP32Vec4; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16; @@ -21,7 +23,8 @@ template <> struct KernelVecType { }; #ifdef __AVX512BF16__ -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32; @@ -30,7 +33,8 @@ template <> struct KernelVecType { using v_load_vec_type = vec_op::BF16Vec16; }; #else -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::BF16Vec16; @@ -41,7 +45,7 @@ template <> struct KernelVecType { #endif template -FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, +FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, const int capacity) { T max = data[0]; for (int i = 1; i < size; ++i) { @@ -67,10 +71,11 @@ FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, } template -FORCE_INLINE std::pair -reduceSoftmaxAlibi(T *data, const int size, const int capacity, - const float alibi_slope, const int start_index, - const int seq_len) { +FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, + const int capacity, + const float alibi_slope, + const int start_index, + const int seq_len) { data[0] += alibi_slope * (start_index - seq_len + 1); T max = data[0]; for (int i = 1; i < size; ++i) { @@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int 
capacity, } template -FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, +FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, const int size) { T max = max_data[0]; for (int i = 1; i < size; ++i) { @@ -132,9 +137,9 @@ struct reduceQKBlockKernel { static_assert(k_load_vec_type::get_elem_num() % x == 0); static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - FORCE_INLINE static void call(const scalar_t *__restrict__ q, - const scalar_t *__restrict__ k_block, - float *__restrict__ logits, float scale, + FORCE_INLINE static void call(const scalar_t* __restrict__ q, + const scalar_t* __restrict__ k_block, + float* __restrict__ logits, float scale, const int token_num) { const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; @@ -196,8 +201,8 @@ struct reduceQKBlockKernel { template -FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, - acc_t &&acc) { +FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, + acc_t&& acc) { using v_load_vec_type = typename KernelVecType::v_load_vec_type; constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); static_assert(BLOCK_SIZE == ELEM_NUM); @@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; }); } -}; // namespace +}; // namespace // Paged attention v1 namespace { template struct paged_attention_v1_impl { - static void - call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + static void call( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] - const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size, block_size] - const int num_kv_heads, const float scale, - const int - *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float *__restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads) { + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { constexpr int x = 16 / sizeof(scalar_t); const int num_queries_per_kv = num_heads / num_kv_heads; @@ -243,32 +248,31 @@ struct paged_attention_v1_impl { size_t logits_bytes = parallel_work_item_num * max_seq_len_padded * sizeof(float); - float *logits = (float *)std::aligned_alloc( - 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_seq_len_padded] + float* logits = (float*)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. 
+ // [parallel_work_item_num, max_seq_len_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { int seq_len = seq_lens[seq_idx]; - const int *seq_block_table = + const int* seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t *__restrict__ q_vec_ptr = + const scalar_t* __restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - const int last_block_token_num = - seq_len - (block_num - 1) * BLOCK_SIZE; - float *__restrict__ thread_block_logits = + const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; + float* __restrict__ thread_block_logits = logits + omp_get_thread_num() * max_seq_len_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t *__restrict__ k_block_cache_ptr = + const scalar_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float *__restrict__ head_block_logits = + float* __restrict__ head_block_logits = thread_block_logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -282,8 +286,7 @@ struct paged_attention_v1_impl { block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, seq_len); } else { - reduceSoftmax(thread_block_logits, seq_len, - block_num * BLOCK_SIZE); + reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); } // Compute value @@ -293,14 +296,14 @@ struct paged_attention_v1_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t *__restrict__ out_ptr = + scalar_t* __restrict__ out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float *__restrict__ prob_vec_ptr = + const float* __restrict__ prob_vec_ptr = thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t *__restrict__ v_block_cache_ptr = + const scalar_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -311,7 +314,7 @@ struct paged_attention_v1_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t *__restrict__ next_v_block_cache_ptr = + const scalar_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -340,16 +343,16 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); template void paged_attention_v1_impl_launcher( - torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int 
num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &seq_lens, - int max_seq_len, const c10::optional &alibi_slopes) { + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -359,68 +362,74 @@ void paged_attention_v1_impl_launcher( int kv_head_stride = key_cache.stride(1); // NOTE: alibi_slopes is optional. - const float *alibi_slopes_ptr = + const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T *out_ptr = reinterpret_cast(out.data_ptr()); - T *query_ptr = reinterpret_cast(query.data_ptr()); - T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int *block_tables_ptr = block_tables.data_ptr(); - int *seq_lens_ptr = seq_lens.data_ptr(); + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 256: - LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 192: + LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); + break; + case 256: + LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ +#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ seq_lens, max_seq_len, alibi_slopes); -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace - -void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, - torch::Tensor &key_cache, torch::Tensor &value_cache, - int num_kv_heads, float scale, - torch::Tensor 
&block_tables, - torch::Tensor &seq_lens, int block_size, - int max_seq_len, - const c10::optional &alibi_slopes, - const std::string &kv_cache_dtype, float kv_scale) { +} // namespace + +void paged_attention_v1( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) @@ -434,23 +443,24 @@ namespace { template struct paged_attention_v2_impl { static void call( - scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] - float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float - *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, - const int - *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ seq_lens, // [num_seqs] + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float *__restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const int num_seqs, const int num_heads, const int max_num_partitions) { constexpr int x = 16 / sizeof(scalar_t); @@ -468,8 +478,7 @@ struct paged_attention_v2_impl { const int seq_len = seq_lens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= seq_len) - continue; + if (start_token_idx >= seq_len) continue; const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; @@ -477,15 +486,14 @@ struct paged_attention_v2_impl { const int token_num = (std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx); - const int block_num = - (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; const int last_block_token_num = 
token_num - (block_num - 1) * BLOCK_SIZE; - const int *seq_block_table = block_tables + + const int* seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx + start_token_idx / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t *__restrict__ q_vec_ptr = + const scalar_t* __restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; @@ -493,10 +501,10 @@ struct paged_attention_v2_impl { // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t *__restrict__ k_block_cache_ptr = + const scalar_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float *__restrict__ head_block_logits = + float* __restrict__ head_block_logits = logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -510,13 +518,13 @@ struct paged_attention_v2_impl { logits, token_num, block_num * BLOCK_SIZE, alibi_slopes[head_idx], start_token_idx, seq_len); } else { - max_and_sum = reduceSoftmax(logits, token_num, - block_num * BLOCK_SIZE); + max_and_sum = + reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); } - auto &&[max_logit, exp_sum] = max_and_sum; + auto&& [max_logit, exp_sum] = max_and_sum; - scalar_t *__restrict__ output_buffer = nullptr; + scalar_t* __restrict__ output_buffer = nullptr; if (!no_reduce) { auto idx = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; @@ -538,13 +546,13 @@ struct paged_attention_v2_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t *__restrict__ out_ptr = + scalar_t* __restrict__ out_ptr = output_buffer + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float *__restrict__ prob_vec_ptr = + const float* __restrict__ prob_vec_ptr = logits + block_idx * BLOCK_SIZE; - const scalar_t *__restrict__ v_block_cache_ptr = + const scalar_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -555,7 +563,7 @@ struct paged_attention_v2_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t *__restrict__ next_v_block_cache_ptr = + const scalar_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -587,8 +595,7 @@ struct paged_attention_v2_impl { const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) - continue; + if (partition_num == 1) continue; reducePartitonSoftmax( max_logits + seq_idx * num_heads * max_num_partitions + @@ -603,11 +610,11 @@ struct paged_attention_v2_impl { using v_load_vec_type = typename KernelVecType::v_load_vec_type; static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); constexpr int head_elem_num_per_group = - 16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE - // didn't align with 64 bytes + 16; // Note: didn't align with the cacheline size, due to some + // HEAD_SIZE didn't align with 64 bytes 
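// A brief sketch of what this reduction phase computes (illustrative only;
// the vectorized loops below are the authoritative implementation, and the
// exact normalization performed inside reducePartitonSoftmax is assumed here):
// after reducePartitonSoftmax, exp_sums[p] is expected to hold the per-partition
// softmax weight w_p = exp(max_p - global_max) * exp_sum_p / global_exp_sum,
// so the merged output for one (seq_idx, head_idx) pair is, in scalar form:
//
//   for (int i = 0; i < HEAD_SIZE; ++i) {
//     float acc = 0.0f;
//     for (int p = 0; p < partition_num; ++p)
//       // tmp_out_for_partition(p) is notational shorthand for the
//       // [HEAD_SIZE]-sized slice of tmp_out belonging to partition p.
//       acc += rescale_factors[p] * tmp_out_for_partition(p)[i];
//     out[i] = acc;
//   }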
static_assert(HEAD_SIZE % head_elem_num_per_group == 0); constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; - const float *__restrict__ rescale_factors = exp_sums; + const float* __restrict__ rescale_factors = exp_sums; #pragma omp parallel for collapse(3) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { @@ -616,17 +623,16 @@ struct paged_attention_v2_impl { const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) - continue; + if (partition_num == 1) continue; - const float *__restrict__ seq_head_rescale_factors = + const float* __restrict__ seq_head_rescale_factors = rescale_factors + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - const scalar_t *__restrict__ seq_head_tmp_out = + const scalar_t* __restrict__ seq_head_tmp_out = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE + group_idx * head_elem_num_per_group; - scalar_t *__restrict__ seq_head_output = + scalar_t* __restrict__ seq_head_output = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + group_idx * head_elem_num_per_group; @@ -645,21 +651,21 @@ struct paged_attention_v2_impl { } }; -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ +#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ + key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, num_seqs, num_heads, \ max_num_partitions); template void paged_attention_v2_impl_launcher( - torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, - torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, - int max_seq_len, const c10::optional &alibi_slopes) { + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -670,73 +676,79 @@ void paged_attention_v2_impl_launcher( int max_num_partitions = exp_sums.size(-1); // NOTE: alibi_slopes is optional. - const float *alibi_slopes_ptr = + const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T *out_ptr = reinterpret_cast(out.data_ptr()); - float *exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float *max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T *tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T *query_ptr = reinterpret_cast(query.data_ptr()); - T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int *block_tables_ptr = block_tables.data_ptr(); - int *seq_lens_ptr = seq_lens.data_ptr(); + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 192: + LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); + break; + case 256: + LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, block_size, \ - max_seq_len, alibi_slopes); - -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ + alibi_slopes); + +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V2_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace - -void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, - torch::Tensor &max_logits, torch::Tensor &tmp_out, - torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, - float scale, torch::Tensor &block_tables, - torch::Tensor &seq_lens, int block_size, - int max_seq_len, - const c10::optional &alibi_slopes, - const std::string &kv_cache_dtype, float 
kv_scale) { +} // namespace + +void paged_attention_v2( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 95e3f11900fde..2b5c3bd6ee70b 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -5,25 +5,26 @@ namespace { template -void copy_blocks_cpu_impl( - std::vector &key_caches, - std::vector &value_caches, - const torch::Tensor& mapping_pairs, - const int element_num_per_block, const int layer_num) { +void copy_blocks_cpu_impl(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& mapping_pairs, + const int element_num_per_block, + const int layer_num) { const size_t pair_num = mapping_pairs.size(0); const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; #pragma omp parallel for collapse(2) for (int layer = 0; layer < layer_num; ++layer) { for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item(); + int64_t source_offset = + element_num_per_block * mapping_pairs[pair][0].item(); int64_t target_offset = element_num_per_block * mapping_pairs[pair][1].item(); - scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); - scalar_t *source_ptr = key_cache_ptr + source_offset; - scalar_t *target_ptr = key_cache_ptr + target_offset; + scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t* source_ptr = key_cache_ptr + source_offset; + scalar_t* target_ptr = key_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); - scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); + scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); source_ptr = value_cache_ptr + source_offset; target_ptr = value_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); @@ -33,9 +34,9 @@ void copy_blocks_cpu_impl( template void reshape_and_cache_cpu_impl( - const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, - scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, - const int64_t *__restrict__ slot_mapping, const int num_tokens, + const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, + const int64_t* __restrict__ slot_mapping, const int num_tokens, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, const int x) { const int block_elem_num = num_heads * head_size * block_size; @@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl( int src_key_head_idx = token_idx * key_stride + head_idx * head_size; int src_value_head_idx = token_idx * value_stride + head_idx * head_size; - const scalar_t *src_key_head_ptr = key 
+ src_key_head_idx; - const scalar_t *src_value_head_ptr = value + src_value_head_idx; + const scalar_t* src_key_head_ptr = key + src_key_head_idx; + const scalar_t* src_value_head_ptr = value + src_value_head_idx; const int64_t block_index = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; - scalar_t *target_key_head_ptr = key_cache + + scalar_t* target_key_head_ptr = key_cache + block_elem_num * block_index + head_idx * block_size * head_size; - scalar_t *target_value_head_ptr = value_cache + + scalar_t* target_value_head_ptr = value_cache + block_elem_num * block_index + head_idx * block_size * head_size; @@ -79,12 +80,15 @@ void reshape_and_cache_cpu_impl( } } } -}; // namespace +}; // namespace -void copy_blocks(std::vector &key_caches, - std::vector &value_caches, - torch::Tensor& block_mapping) { - int num_layers = key_caches.size(); +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& block_mapping) { + unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { return; @@ -100,10 +104,10 @@ void copy_blocks(std::vector &key_caches, }); } -void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, - torch::Tensor &key_cache, torch::Tensor &value_cache, - torch::Tensor &slot_mapping, - const std::string &kv_cache_dtype, float kv_scale) { +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, double kv_scale) { TORCH_CHECK(kv_scale == 1.0f); int num_tokens = key.size(0); @@ -127,7 +131,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, }); } -void swap_blocks(torch::Tensor &src, torch::Tensor &dst, - const std::map &block_mapping) { +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") } diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index c1d3ec058b991..0213be09105ed 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -2,351 +2,14 @@ #ifndef CPU_TYPES_HPP #define CPU_TYPES_HPP -#include -#include - -namespace vec_op { - -// FIXME: FP16 is not fully supported in Torch-CPU -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) - -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) - -#ifndef CPU_OP_GUARD -#define CPU_KERNEL_GUARD_IN(NAME) -#define CPU_KERNEL_GUARD_OUT(NAME) +#if defined(__x86_64__) + //x86 implementation + #include "cpu_types_x86.hpp" +#elif defined(__POWER9_VECTOR__) + //ppc implementation + #include "cpu_types_vsx.hpp" #else -#define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." 
<< std::endl; -#endif - -#define FORCE_INLINE __attribute__((always_inline)) inline - -namespace { -template -constexpr void unroll_loop_item(std::integer_sequence, F &&f) { - (f(std::integral_constant{}), ...); -} -}; // namespace - -template >> -constexpr void unroll_loop(F &&f) { - unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); -} - -template struct Vec { - constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } -}; - -struct FP32Vec8; -struct FP32Vec16; - -#ifdef __AVX512FP16__ -struct FP16Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - - __m128h reg; - - explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} - - explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} - - explicit FP16Vec8(__m128h data) : reg(data) {} - - FP16Vec8 operator*(const FP16Vec8 &b) const { - return FP16Vec8(_mm_mul_ph(reg, b.reg)); - } - - FP16Vec8 operator+(const FP16Vec8 &b) const { - return FP16Vec8(_mm_add_ph(reg, b.reg)); - } - - FP16Vec8 operator-(const FP16Vec8 &b) const { - return FP16Vec8(_mm_sub_ph(reg, b.reg)); - } - - FP16Vec8 operator/(const FP16Vec8 &b) const { - return FP16Vec8(_mm_div_ph(reg, b.reg)); - } - - void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } -}; + #warning "unsupported vLLM cpu implementation" #endif -struct BF16Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - - __m128i reg; - - explicit BF16Vec8(const void *ptr) - : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} - - explicit BF16Vec8(const FP32Vec8 &); - - void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } -}; - -struct BF16Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - - __m256i reg; - - explicit BF16Vec16(const void *ptr) - : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} - - explicit BF16Vec16(const FP32Vec16 &); - - void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } -}; - -struct BF16Vec32 : public Vec { - constexpr static int VEC_ELEM_NUM = 32; - - __m512i reg; - - explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} - - explicit BF16Vec32(__m512i data) : reg(data) {} - - explicit BF16Vec32(BF16Vec8 &vec8_data) - : reg((__m512i)_mm512_inserti32x4( - _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( - (__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1), - (__m128i)vec8_data.reg, 2), - (__m128i)vec8_data.reg, 3)) {} - - void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } -}; - -struct FP32Vec4 : public Vec { - constexpr static int VEC_ELEM_NUM = 4; - union AliasReg { - __m128 reg; - float values[VEC_ELEM_NUM]; - }; - - __m128 reg; - - explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} - - explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} - - explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} - - explicit FP32Vec4(__m128 data) : reg(data) {} - - explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} -}; - -struct FP32Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - union AliasReg { - __m256 reg; - float values[VEC_ELEM_NUM]; - }; - - __m256 reg; - - explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} - - explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} - - explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} - - explicit FP32Vec8(__m256 data) : reg(data) {} - - explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} - -#ifdef __AVX512FP16__ - explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} -#endif - - explicit FP32Vec8(const BF16Vec8 &v) - : 
reg(_mm256_castsi256_ps( - _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} - - float reduce_sum() const { - AliasReg ar; - ar.reg = reg; - float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); - - return result; - } - - FP32Vec8 exp() const { - AliasReg ar; - ar.reg = reg; - return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), - expf(ar.values[5]), expf(ar.values[4]), - expf(ar.values[3]), expf(ar.values[2]), - expf(ar.values[1]), expf(ar.values[0]))); - } - - FP32Vec8 tanh() const { - AliasReg ar; - ar.reg = reg; - return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), - tanhf(ar.values[5]), tanhf(ar.values[4]), - tanhf(ar.values[3]), tanhf(ar.values[2]), - tanhf(ar.values[1]), tanhf(ar.values[0]))); - } - - FP32Vec8 er() const { - AliasReg ar; - ar.reg = reg; - return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), - erf(ar.values[5]), erf(ar.values[4]), - erf(ar.values[3]), erf(ar.values[2]), - erf(ar.values[1]), erf(ar.values[0]))); - } - - FP32Vec8 operator*(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_mul_ps(reg, b.reg)); - } - - FP32Vec8 operator+(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_add_ps(reg, b.reg)); - } - - FP32Vec8 operator-(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_sub_ps(reg, b.reg)); - } - - FP32Vec8 operator/(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_div_ps(reg, b.reg)); - } - - void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } -}; - -struct FP32Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - union AliasReg { - __m512 reg; - float values[VEC_ELEM_NUM]; - }; - - __m512 reg; - - explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} - - explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} - - explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} - - explicit FP32Vec16(__m512 data) : reg(data) {} - - explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} - - explicit FP32Vec16(const FP32Vec4 &data) - : reg((__m512)_mm512_inserti32x4( - _mm512_inserti32x4( - _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), - (__m128i)data.reg, 1), - (__m128i)data.reg, 2), - (__m128i)data.reg, 3)) {} - - explicit FP32Vec16(const FP32Vec8 &data) - : reg((__m512)_mm512_inserti32x8( - _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} - - explicit FP32Vec16(const BF16Vec16 &v) - : reg(_mm512_castsi512_ps( - _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} - - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} - - FP32Vec16 operator*(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_mul_ps(reg, b.reg)); - } - - FP32Vec16 operator+(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_add_ps(reg, b.reg)); - } - - FP32Vec16 operator-(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_sub_ps(reg, b.reg)); - } - - FP32Vec16 operator/(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_div_ps(reg, b.reg)); - } - - float reduce_sum() const { return _mm512_reduce_add_ps(reg); } - - template float reduce_sub_sum(int idx) { - static_assert(VEC_ELEM_NUM % group_size == 0); - constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); - __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); - return _mm512_mask_reduce_add_ps(mask, reg); - } - - void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } -}; - -template struct VecType { using vec_type = void; }; - -template using vec_t = typename VecType::vec_type; - -template <> struct VecType { using vec_type = 
FP32Vec8; }; - -#ifdef __AVX512FP16__ -template <> struct VecType { using vec_type = FP16Vec16; }; -#endif - -template <> struct VecType { using vec_type = BF16Vec8; }; - -template void storeFP32(float v, T *ptr) { *ptr = v; } - -#ifdef __AVX512FP16__ -template <> inline void storeFP32(float v, c10::Half *ptr) { - *reinterpret_cast<_Float16 *>(ptr) = v; -} -#endif - -inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { - acc = acc + a * b; -} - -#ifdef __AVX512BF16__ -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); -} - -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) - : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} - -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) - : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} - -inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { - acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); -} -#else -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = - reinterpret_cast(&v); - *ptr = *(v_ptr + 1); -} - -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) - : reg(_mm256_cvtepi32_epi16( - _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} - -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) - : reg(_mm512_cvtepi32_epi16( - _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} -#endif - -inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } - -}; // namespace vec_op - #endif diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp new file mode 100644 index 0000000000000..b50bdadc5713d --- /dev/null +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -0,0 +1,491 @@ + +#ifndef CPU_TYPES_VSX_HPP +#define CPU_TYPES_VSX_HPP + +#include +#include +#include + +namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." 
<< std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +typedef struct ss16x8x2_t { + __vector signed short val[2]; +} ss16x8x2_t; + +typedef struct ss16x8x4_t { + __vector signed short val[4]; +} ss16x8x4_t; + +typedef struct f32x4x2_t { + __vector float val[2]; +} f32x4x2_t; + +typedef struct f32x4x4_t { + __vector float val[4]; +} f32x4x4_t; + +struct FP32Vec8; +struct FP32Vec16; + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __vector signed short reg; + + explicit BF16Vec8(const void *ptr) + : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + ss16x8x2_t reg; + + explicit BF16Vec16(const void *ptr) { + // Load 256 bits in two parts + reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr); + reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr); + } + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { + // Save 256 bits in two parts + vec_xst(reg.val[0], 0, (signed short *)ptr); + vec_xst(reg.val[1], 16, (signed short *)ptr); + } +}; + +const static __vector signed short zero = vec_splats((signed short)0); + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + ss16x8x4_t reg; + explicit BF16Vec32(const void *ptr) + : reg(*reinterpret_cast(ptr)) {} + + explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} + + explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ + vec8_data.reg, + vec8_data.reg, + vec8_data.reg, + vec8_data.reg + }) {} + + void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __vector float reg; + float values[VEC_ELEM_NUM]; + }; + + __vector float reg; + + explicit FP32Vec4(float v) : reg(vec_splats(v)) {} + + explicit FP32Vec4() : reg(vec_splats(0.0f)) {} + + explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {} + + explicit FP32Vec4(__vector float data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + f32x4x2_t reg; + float values[VEC_ELEM_NUM]; + }; + + f32x4x2_t reg; + + explicit FP32Vec8(float v) { + reg.val[0] = vec_splats(v); + reg.val[1] = vec_splats(v); + } + + explicit FP32Vec8() { + reg.val[0] = vec_splats(0.0f); + reg.val[1] = vec_splats(0.0f); + } + + explicit FP32Vec8(const float *ptr) { + reg.val[0] = vec_xl(0, ptr); + reg.val[1] = vec_xl(16, ptr); + } + + explicit FP32Vec8(f32x4x2_t data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + } + + explicit FP32Vec8(const BF16Vec8 &v) { + reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); + reg.val[1] = (__vector float)vec_mergel(zero, v.reg); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { 
result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::exp(ar.values[0]); + ret.val[0][1] = std::exp(ar.values[1]); + ret.val[0][2] = std::exp(ar.values[2]); + ret.val[0][3] = std::exp(ar.values[3]); + ret.val[1][0] = std::exp(ar.values[4]); + ret.val[1][1] = std::exp(ar.values[5]); + ret.val[1][2] = std::exp(ar.values[6]); + ret.val[1][3] = std::exp(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 tanh() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::tanh(ar.values[0]); + ret.val[0][1] = std::tanh(ar.values[1]); + ret.val[0][2] = std::tanh(ar.values[2]); + ret.val[0][3] = std::tanh(ar.values[3]); + ret.val[1][0] = std::tanh(ar.values[4]); + ret.val[1][1] = std::tanh(ar.values[5]); + ret.val[1][2] = std::tanh(ar.values[6]); + ret.val[1][3] = std::tanh(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 er() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::erf(ar.values[0]); + ret.val[0][1] = std::erf(ar.values[1]); + ret.val[0][2] = std::erf(ar.values[2]); + ret.val[0][3] = std::erf(ar.values[3]); + ret.val[1][0] = std::erf(ar.values[4]); + ret.val[1][1] = std::erf(ar.values[5]); + ret.val[1][2] = std::erf(ar.values[6]); + ret.val[1][3] = std::erf(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); + } + + void save(float *ptr) const { + vec_xst(reg.val[0], 0, ptr); + vec_xst(reg.val[1], 16, ptr); + } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + f32x4x4_t reg; + float values[VEC_ELEM_NUM]; + }; + + f32x4x4_t reg; + + explicit FP32Vec16(float v) { + reg.val[0] = vec_splats(v); + reg.val[1] = vec_splats(v); + reg.val[2] = vec_splats(v); + reg.val[3] = vec_splats(v); + } + + explicit FP32Vec16() { + reg.val[0] = vec_splats(0.0f); + reg.val[1] = vec_splats(0.0f); + reg.val[2] = vec_splats(0.0f); + reg.val[3] = vec_splats(0.0f); + } + + explicit FP32Vec16(const float *ptr) { + reg.val[0] = vec_xl(0, ptr); + reg.val[1] = vec_xl(16, ptr); + reg.val[2] = vec_xl(32, ptr); + reg.val[3] = vec_xl(48, ptr); + } + + explicit FP32Vec16(f32x4x4_t data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16 &data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[2]; + reg.val[3] = data.reg.val[3]; + } + + explicit FP32Vec16(const FP32Vec4 &data) { + reg.val[0] = data.reg; + reg.val[1] = data.reg; + reg.val[2] = data.reg; + reg.val[3] = data.reg; + } + + explicit FP32Vec16(const FP32Vec8 &data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[0]; + reg.val[3] = data.reg.val[1]; + } + + explicit FP32Vec16(const BF16Vec16 &v) { + reg.val[0] = (__vector float)vec_mergeh(zero, 
v.reg.val[0]); + reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); + reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); + reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); + } + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_mul(reg.val[0], b.reg.val[0]), + vec_mul(reg.val[1], b.reg.val[1]), + vec_mul(reg.val[2], b.reg.val[2]), + vec_mul(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_add(reg.val[0], b.reg.val[0]), + vec_add(reg.val[1], b.reg.val[1]), + vec_add(reg.val[2], b.reg.val[2]), + vec_add(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_sub(reg.val[0], b.reg.val[0]), + vec_sub(reg.val[1], b.reg.val[1]), + vec_sub(reg.val[2], b.reg.val[2]), + vec_sub(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_div(reg.val[0], b.reg.val[0]), + vec_div(reg.val[1], b.reg.val[1]), + vec_div(reg.val[2], b.reg.val[2]), + vec_div(reg.val[3], b.reg.val[3])})); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + + AliasReg ar; + ar.reg = reg; + float result = 0; + const int start = idx * group_size; + unroll_loop( + [&result, &start, ar](int i) { result += ar.values[start + i]; }); + + return result; + } + + void save(float *ptr) const { + vec_xst(reg.val[0], 0, ptr); + vec_xst(reg.val[1], 16, ptr); + vec_xst(reg.val[2], 32, ptr); + vec_xst(reg.val[3], 48, ptr); + } +}; + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + acc = acc + a * b; +} + +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +#ifndef __VEC_CLASS_FP_NAN +#define __VEC_CLASS_FP_NAN (1 << 6) +#endif + +const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 }; +#ifndef _ARCH_PWR10 +const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff }; +const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 }; +const static __vector unsigned int sh16 = { 16, 16, 16, 16 }; +const static __vector unsigned int one = { 1, 1, 1, 1 }; +#endif + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { +#ifdef _ARCH_PWR10 + __vector signed short ret[2]; + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); + reg = vec_perm(ret[0], ret[1], omask); +#elif defined(_ARCH_PWR9) + __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); + __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + __vector unsigned int lsb0 = vec_sr(inp0, sh16); + __vector unsigned int lsb1 = vec_sr(inp1, sh16); + lsb0 
= vec_and(lsb0, one); + lsb1 = vec_and(lsb1, one); + __vector unsigned int rnd0 = vec_add(lsb0, bias); + __vector unsigned int rnd1 = vec_add(lsb1, bias); + inp0 = vec_add(inp0, rnd0); + inp1 = vec_add(inp1, rnd1); + __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + inp0 = vec_sel(inp0, nan, sel0); + inp1 = vec_sel(inp1, nan, sel1); + inp0 = vec_sr(inp0, sh16); + inp1 = vec_sr(inp1, sh16); + reg = (__vector signed short)vec_perm(inp0, inp1, omask); +#endif +} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { +#ifdef _ARCH_PWR10 + __vector signed short ret[4]; + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); + ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]); + ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]); + reg.val[0] = vec_perm(ret[0], ret[1], omask); + reg.val[1] = vec_perm(ret[2], ret[3], omask); +#elif defined(_ARCH_PWR9) + __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); + __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + __vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]); + __vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]); + __vector unsigned int lsb0 = vec_sr(inp0, sh16); + __vector unsigned int lsb1 = vec_sr(inp1, sh16); + __vector unsigned int lsb2 = vec_sr(inp2, sh16); + __vector unsigned int lsb3 = vec_sr(inp3, sh16); + lsb0 = vec_and(lsb0, one); + lsb1 = vec_and(lsb1, one); + lsb2 = vec_and(lsb2, one); + lsb3 = vec_and(lsb3, one); + __vector unsigned int rnd0 = vec_add(lsb0, bias); + __vector unsigned int rnd1 = vec_add(lsb1, bias); + __vector unsigned int rnd2 = vec_add(lsb2, bias); + __vector unsigned int rnd3 = vec_add(lsb3, bias); + inp0 = vec_add(inp0, rnd0); + inp1 = vec_add(inp1, rnd1); + inp2 = vec_add(inp2, rnd2); + inp3 = vec_add(inp3, rnd3); + __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + __vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); + __vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); + inp0 = vec_sel(inp0, nan, sel0); + inp1 = vec_sel(inp1, nan, sel1); + inp2 = vec_sel(inp2, nan, sel2); + inp3 = vec_sel(inp3, nan, sel3); + inp0 = vec_sr(inp0, sh16); + inp1 = vec_sr(inp1, sh16); + inp2 = vec_sr(inp2, sh16); + inp3 = vec_sr(inp3, sh16); + reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask); + reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask); +#endif +} + +inline void prefetch(const void *addr) { + __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory"); +} + +}; // namespace vec_op + +#endif diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp new file mode 100644 index 0000000000000..f50620a5287d4 --- /dev/null +++ b/csrc/cpu/cpu_types_x86.hpp @@ -0,0 +1,515 @@ + +#ifndef CPU_TYPES_X86_HPP +#define CPU_TYPES_X86_HPP + +#include +#include + +#ifndef __AVX2__ +static_assert(false, "AVX2 must be supported for the current implementation."); +#endif + +namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) 
\ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +struct FP32Vec8; +struct FP32Vec16; + +#ifdef __AVX512FP16__ +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128h reg; + + explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} + + explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} + + explicit FP16Vec8(__m128h data) : reg(data) {} + + FP16Vec8 operator*(const FP16Vec8 &b) const { + return FP16Vec8(_mm_mul_ph(reg, b.reg)); + } + + FP16Vec8 operator+(const FP16Vec8 &b) const { + return FP16Vec8(_mm_add_ph(reg, b.reg)); + } + + FP16Vec8 operator-(const FP16Vec8 &b) const { + return FP16Vec8(_mm_sub_ph(reg, b.reg)); + } + + FP16Vec8 operator/(const FP16Vec8 &b) const { + return FP16Vec8(_mm_div_ph(reg, b.reg)); + } + + void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } +}; +#endif + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit BF16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + explicit BF16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } +}; + +#ifdef __AVX512F__ +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m512i reg; + + explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} + + explicit BF16Vec32(__m512i data) : reg(data) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg((__m512i)_mm512_inserti32x4( + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( + (__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1), + (__m128i)vec8_data.reg, 2), + (__m128i)vec8_data.reg, 3)) {} + + void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } +}; +#else +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m256i reg_low; + __m256i reg_high; + + explicit BF16Vec32(const void *ptr) + : reg_low(_mm256_loadu_si256((__m256i const *)ptr)), + reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {} + + explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low), + reg_high(high) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg_low((__m256i)_mm256_inserti32x4( + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)), + reg_high((__m256i)_mm256_inserti32x4( + 
_mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)) {} + + void save(void *ptr) const { + *reinterpret_cast<__m256i *>(ptr) = reg_low; + *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high; + } +}; +#endif + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __m128 reg; + float values[VEC_ELEM_NUM]; + }; + + __m128 reg; + + explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} + + explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} + + explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + + explicit FP32Vec4(__m128 data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + __m256 reg; + float values[VEC_ELEM_NUM]; + }; + + __m256 reg; + + explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} + + explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + + explicit FP32Vec8(__m256 data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + +#ifdef __AVX512FP16__ + explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} +#endif + + explicit FP32Vec8(const BF16Vec8 &v) + : reg(_mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), + expf(ar.values[5]), expf(ar.values[4]), + expf(ar.values[3]), expf(ar.values[2]), + expf(ar.values[1]), expf(ar.values[0]))); + } + + FP32Vec8 tanh() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), + tanhf(ar.values[5]), tanhf(ar.values[4]), + tanhf(ar.values[3]), tanhf(ar.values[2]), + tanhf(ar.values[1]), tanhf(ar.values[0]))); + } + + FP32Vec8 er() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), + erf(ar.values[5]), erf(ar.values[4]), + erf(ar.values[3]), erf(ar.values[2]), + erf(ar.values[1]), erf(ar.values[0]))); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_mul_ps(reg, b.reg)); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_add_ps(reg, b.reg)); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_sub_ps(reg, b.reg)); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_div_ps(reg, b.reg)); + } + + void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } +}; + +#ifdef __AVX512F__ +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512 reg; + float values[VEC_ELEM_NUM]; + }; + + __m512 reg; + + explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} + + explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} + + explicit FP32Vec16(__m512 data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg((__m512)_mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), + (__m128i)data.reg, 1), + (__m128i)data.reg, 2), + (__m128i)data.reg, 3)) {} + + explicit 
FP32Vec16(const FP32Vec8 &data) + : reg((__m512)_mm512_inserti32x8( + _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} + + explicit FP32Vec16(const BF16Vec16 &v) + : reg(_mm512_castsi512_ps( + _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_mul_ps(reg, b.reg)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_add_ps(reg, b.reg)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_sub_ps(reg, b.reg)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_div_ps(reg, b.reg)); + } + + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); + return _mm512_mask_reduce_add_ps(mask, reg); + } + + void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } +}; +#else +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + union AliasReg { + __m256 reg; + float values[8]; + }; + + __m256 reg_low; + __m256 reg_high; + + explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)), + reg_high(_mm256_set1_ps(v)) {} + + explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)), + reg_high(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)), + reg_high(_mm256_loadu_ps(ptr + 8)) {} + + explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low), + reg_high(data.reg_high) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg_low((__m256)_mm256_inserti128_si256( + _mm256_castsi128_si256((__m128i)data.reg), + (__m128i)data.reg, 1)), + reg_high((__m256)_mm256_inserti128_si256( + _mm256_castsi128_si256((__m128i)data.reg), + (__m128i)data.reg, 1)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg_low(data.reg), reg_high(data.reg) {} + + explicit FP32Vec16(const BF16Vec16 &v) { + __m128i low = _mm256_extractf128_si256(v.reg, 0); + __m128i high = _mm256_extractf128_si256(v.reg, 1); + + __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low); + __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high); + + __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2); + __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2); + + reg_low = _mm256_castsi256_ps(v_low_shifted); + reg_high = _mm256_castsi256_ps(v_high_shifted); + } + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low), + _mm256_mul_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low), + _mm256_add_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low), + _mm256_sub_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low), + _mm256_div_ps(reg_high, b.reg_high)); + } + + float reduce_sum() const { + FP32Vec8 low = FP32Vec8(reg_low); + FP32Vec8 high = FP32Vec8(reg_high); + return low.reduce_sum() + high.reduce_sum(); + } + + template float reduce_sub_sum(int 
idx) { + float sum = 0.0; + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + uint32_t mask = base_mask << (idx * group_size); + + AliasReg ar; + + auto func = [&sum, &mask, &ar](int i) { + int flag = mask & 0x1; + mask = mask >> 1; + if (flag != 0) sum += ar.values[i]; + }; + + ar.reg = reg_low; + unroll_loop(func); + + ar.reg = reg_high; + unroll_loop(func); + + return sum; + } + + void save(float *ptr) const { + _mm256_storeu_ps(ptr, reg_low); + _mm256_storeu_ps(ptr + 8, reg_high); + } +}; +#endif + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +#ifdef __AVX512FP16__ +template <> struct VecType { using vec_type = FP16Vec16; }; +#endif + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +#ifdef __AVX512FP16__ +template <> inline void storeFP32(float v, c10::Half *ptr) { + *reinterpret_cast<_Float16 *>(ptr) = v; +} +#endif + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + acc = acc + a * b; +} + +#ifdef __AVX512BF16__ +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} + +inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { + acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); +} +#else +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +#ifdef __AVX512F__ +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(_mm256_cvtepi32_epi16( + _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg(_mm512_cvtepi32_epi16( + _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} +#else +namespace{ +__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { + __m256i ai = _mm256_castps_si256(a); + ai = _mm256_srli_epi32(ai, 16); + ai = _mm256_packus_epi32(ai, ai); + ai = _mm256_permute4x64_epi64(ai, 0b00111001); + return _mm256_extracti128_si256(ai, 0); +} +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { + BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low)); + BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high)); + reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1); +} +#endif // __AVX512F__ +#endif // __AVX512BF16__ + +inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } + +}; // namespace vec_op + +#endif diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp index 467f0dc84982c..a76ad08928a2c 100644 --- a/csrc/cpu/layernorm.cpp +++ b/csrc/cpu/layernorm.cpp @@ -2,10 +2,10 @@ namespace { template -void rms_norm_impl(scalar_t *__restrict__ out, - const scalar_t *__restrict__ input, - const scalar_t *__restrict__ weight, const float epsilon, - const int num_tokens, const int hidden_size) { +void rms_norm_impl(scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int 
VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out, } template -void fused_add_rms_norm_impl(scalar_t *__restrict__ input, - scalar_t *__restrict__ residual, - const scalar_t *__restrict__ weight, - const float epsilon, const int num_tokens, - const int hidden_size) { +void fused_add_rms_norm_impl(scalar_t* __restrict__ input, + scalar_t* __restrict__ residual, + const scalar_t* __restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input, } } } -} // namespace +} // namespace -void rms_norm(torch::Tensor &out, torch::Tensor &input, - torch::Tensor &weight, float epsilon) { +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { CPU_KERNEL_GUARD_IN(rms_norm_impl) rms_norm_impl(out.data_ptr(), input.data_ptr(), - weight.data_ptr(), epsilon, num_tokens, - hidden_size); + weight.data_ptr(), epsilon, num_tokens, + hidden_size); CPU_KERNEL_GUARD_OUT(rms_norm_impl) }); } -void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, - torch::Tensor &weight, float epsilon) { +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index e9b3992204bb2..96bce7dda0132 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -4,107 +4,107 @@ namespace { template void rotary_embedding_impl( - const int64_t - *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t - *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or - /// [num_tokens, num_heads, head_size] - scalar_t - *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or - // [num_tokens, num_kv_heads, head_size] - const scalar_t - *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); - constexpr int ELEM_SIZE = sizeof(scalar_t); const int embed_dim = rot_dim / 2; - TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); + bool flag = (embed_dim % VEC_ELEM_NUM == 0); + const int loop_upper = flag ? 
embed_dim : embed_dim - VEC_ELEM_NUM; -#pragma omp parallel for - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - - for (int i = 0; i < num_heads; ++i) { - const int head_idx = i; - const int64_t token_head = - token_idx * query_stride + head_idx * head_size; - for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { - const int rot_offset = j; - const int x_index = rot_offset; - const int y_index = embed_dim + rot_offset; + auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr, + scalar_t* qk) { + int j = 0; + for (; j < loop_upper; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; - const int64_t out_x = token_head + x_index; - const int64_t out_y = token_head + y_index; + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; - const scalar_vec_t cos(cache_ptr + x_index); - const scalar_vec_t sin(cache_ptr + y_index); + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); - const scalar_vec_t q_x(query + out_x); - const scalar_vec_t q_y(query + out_y); + const scalar_vec_t q_x(qk + out_x); + const scalar_vec_t q_y(qk + out_y); - vec_op::FP32Vec8 fp32_cos(cos); - vec_op::FP32Vec8 fp32_sin(sin); + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); - vec_op::FP32Vec8 fp32_q_x(q_x); - vec_op::FP32Vec8 fp32_q_y(q_y); + vec_op::FP32Vec8 fp32_q_x(q_x); + vec_op::FP32Vec8 fp32_q_y(q_y); - auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; - scalar_vec_t(out1).save(query + out_x); + auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + scalar_vec_t(out1).save(qk + out_x); - auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; - scalar_vec_t(out2).save(query + out_y); - } + auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + scalar_vec_t(out2).save(qk + out_y); } - - for (int i = 0; i < num_kv_heads; ++i) { - const int head_idx = i; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; - for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { - const int rot_offset = j; - const int x_index = rot_offset; - const int y_index = embed_dim + rot_offset; + if (!flag) { + for (; j < embed_dim; ++j) { + const int x_index = j; + const int y_index = embed_dim + j; const int64_t out_x = token_head + x_index; const int64_t out_y = token_head + y_index; - const scalar_vec_t cos(cache_ptr + x_index); - const scalar_vec_t sin(cache_ptr + y_index); + const float fp32_cos = cache_ptr[x_index]; + const float fp32_sin = cache_ptr[y_index]; - const scalar_vec_t k_x(key + out_x); - const scalar_vec_t k_y(key + out_y); + const float fp32_q_x = qk[out_x]; + const float fp32_q_y = qk[out_y]; - vec_op::FP32Vec8 fp32_cos(cos); - vec_op::FP32Vec8 fp32_sin(sin); + qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + } + } + }; - vec_op::FP32Vec8 fp32_k_x(k_x); - vec_op::FP32Vec8 fp32_k_y(k_y); +#pragma omp parallel for + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin; - scalar_vec_t(out1).save(key + out_x); - auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin; - scalar_vec_t(out2).save(key + out_y); - } + for (int i = 0; i < num_heads; ++i) { + const int head_idx = 
i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; + compute_loop(token_head, cache_ptr, query); + } + + for (int i = 0; i < num_kv_heads; ++i) { + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + compute_loop(token_head, cache_ptr, key); } } } template void rotary_embedding_gptj_impl( - const int64_t - *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t - *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or - /// [num_tokens, num_heads, head_size] - scalar_t - *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or - // [num_tokens, num_kv_heads, head_size] - const scalar_t - *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { @@ -114,13 +114,13 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t *cos_cache_ptr = cache_ptr; - const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * query_stride + head_idx * head_size; - scalar_t *head_query = token_head + query; + scalar_t* head_query = token_head + query; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -142,12 +142,12 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_kv_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t *cos_cache_ptr = cache_ptr; - const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * key_stride + head_idx * head_size; - scalar_t *head_key = key + token_head; + scalar_t* head_key = key + token_head; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -165,11 +165,11 @@ void rotary_embedding_gptj_impl( } } } -}; // namespace +}; // namespace -void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, - torch::Tensor &key, int head_size, - torch::Tensor &cos_sin_cache, bool is_neox) { +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int64_t head_size, + torch::Tensor& cos_sin_cache, bool is_neox) { int num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; diff 
--git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp deleted file mode 100644 index bba044087f37c..0000000000000 --- a/csrc/cpu/pybind.cpp +++ /dev/null @@ -1,73 +0,0 @@ -#include "cache.h" -#include "cuda_utils.h" -#include "ops.h" -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // vLLM custom ops - pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); - - // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); - - // Activation ops - ops.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - ops.def( - "gelu_and_mul", - &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def( - "gelu_tanh_and_mul", - &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - ops.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); - - // Layernorm - ops.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); - - ops.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); - - // Rotary embedding - ops.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); - - // Cache ops - pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); -} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp new file mode 100644 index 0000000000000..39e8cf3ed3c10 --- /dev/null +++ b/csrc/cpu/torch_bindings.cpp @@ -0,0 +1,110 @@ +#include "cache.h" +#include "ops.h" +#include "registration.h" + +#include + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + // vLLM custom ops + + // Attention ops + // Compute the attention between an input query and the cached keys/values + // using PagedAttention. + ops.def( + "paged_attention_v1(" + " Tensor! out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); + + // PagedAttention V2. + ops.def( + "paged_attention_v2(" + " Tensor! out, Tensor exp_sums, Tensor max_logits," + " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? 
alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); + + // Activation ops + + // Activation function used in SwiGLU. + ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul); + + // Activation function used in GeGLU with `none` approximation. + ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul); + + // Activation function used in GeGLU with `tanh` approximation. + ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul); + + // GELU implementation used in GPT-2. + ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_new", torch::kCPU, &gelu_new); + + // Approximate GELU implementation. + ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_fast", torch::kCPU, &gelu_fast); + + // Quick GELU implementation. + ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_quick", torch::kCPU, &gelu_quick); + + // Layernorm + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def( + "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " + "()"); + ops.impl("rms_norm", torch::kCPU, &rms_norm); + + // In-place fused Add and RMS Normalization. + ops.def( + "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " + "float epsilon) -> ()"); + ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); + + // Rotary embedding + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. + ops.def( + "rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); +} + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { + // Cache ops + // Swap in (out) the cache blocks from src to dst. + cache_ops.def( + "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); + cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks); + + // Copy the cache blocks from src to dst. + cache_ops.def( + "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " + "block_mapping) -> ()"); + cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); + + // Reshape the key and value tensors and cache them. + cache_ops.def( + "reshape_and_cache(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! 
value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float kv_scale) -> ()"); + cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index c711d8d1b24b9..82e55613d915a 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -1,7 +1,7 @@ #pragma once #ifdef USE_ROCM -#include + #include #endif #ifndef USE_ROCM @@ -17,9 +17,14 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) #endif #ifndef USE_ROCM @@ -28,6 +33,13 @@ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #endif +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + #ifndef USE_ROCM #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) @@ -35,4 +47,3 @@ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) #endif - diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 1483484faeb4a..73944f4c14890 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,10 +1,5 @@ #pragma once -#include +int64_t get_device_attribute(int64_t attribute, int64_t device_id); -int get_device_attribute( - int attribute, - int device_id); - -int get_max_shared_memory_per_block_device_attribute( - int device_id); +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 1a443ef3620cc..d6f9eb646fad5 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -2,34 +2,28 @@ #include #include #endif -int get_device_attribute( - int attribute, - int device_id) -{ - int device, value; - if (device_id < 0) { - cudaGetDevice(&device); - } - else { - device = device_id; - } - cudaDeviceGetAttribute(&value, static_cast(attribute), device); - return value; +int64_t get_device_attribute(int64_t attribute, int64_t device_id) { + int device, value; + if (device_id < 0) { + cudaGetDevice(&device); + } else { + device = device_id; + } + cudaDeviceGetAttribute(&value, static_cast(attribute), + device); + return value; } - -int get_max_shared_memory_per_block_device_attribute( - int device_id) -{ -int attribute; -// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html -// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) { + int64_t attribute; + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html + // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 #ifdef USE_ROCM - attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; + attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; #else - attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; + attribute = 
cudaDevAttrMaxSharedMemoryPerBlockOptin; #endif - return get_device_attribute(attribute, device_id); + return get_device_attribute(attribute, device_id); } diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 3906dcfc80dbf..82a3563979f16 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -1,17 +1,17 @@ #include #include #include -#include +#include #include "custom_all_reduce.cuh" -// fake pointer type -using fptr_t = uint64_t; -static_assert(sizeof(void *) == sizeof(fptr_t)); +// fake pointer type, must match fptr_t type in ops.h +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int64_t rank, bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) @@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); } return (fptr_t) new vllm::CustomAllreduce( - reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } @@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, * 5. A[None].expand(2, -1, -1, -1): Not OK * 6. A[:, 1:, 1:]: Not OK */ -bool _is_weak_contiguous(torch::Tensor &t) { +bool _is_weak_contiguous(torch::Tensor& t) { return t.is_contiguous() || (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, bool full_nvlink) { auto inp_size = inp.numel() * inp.element_size(); // custom allreduce requires input byte size to be multiples of 16 @@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, return false; } -void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, +void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, cudaStream_t stream) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); TORCH_CHECK(_is_weak_contiguous(out)); switch (out.scalar_type()) { case at::ScalarType::Float: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } case at::ScalarType::Half: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), - out.numel()); + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case at::ScalarType::BFloat16: { fa->allreduce( - stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #endif @@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, } } -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { const 
at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); @@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { _all_reduce(_fa, inp, out, stream); } -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out) { +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); @@ -122,27 +121,33 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, } void dispose(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); delete fa; } -int meta_size() { return sizeof(vllm::Signal); } +int64_t meta_size() { return sizeof(vllm::Signal); } -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets) { - auto fa = reinterpret_cast(_fa); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets) { + auto fa = reinterpret_cast(_fa); fa->register_buffer(handles, offsets, t.data_ptr()); } -std::pair, std::vector> get_graph_buffer_ipc_meta( +std::tuple> get_graph_buffer_ipc_meta( fptr_t _fa) { - auto fa = reinterpret_cast(_fa); - return fa->get_graph_buffer_ipc_meta(); + auto fa = reinterpret_cast(_fa); + auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handles = + torch::empty({static_cast(handle_bytes.size())}, options); + std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); + return {handles, std::move(offsets)}; } -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector> &offsets) { - auto fa = reinterpret_cast(_fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); fa->register_graph_buffers(handles, offsets); } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 750e68d42f6c6..1ed49b8aa9cae 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -31,9 +31,9 @@ struct Signal { alignas(128) uint32_t end[kMaxBlocks][8]; }; -struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; +struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; -struct __align__(16) RankSignals { volatile Signal *signals[8]; }; +struct __align__(16) RankSignals { volatile Signal* signals[8]; }; // like std::array, but aligned template @@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) { // scalar add functions // for some reason when compiling with Pytorch, the + operator for half and // bfloat is disabled so we call the intrinsics directly -DINLINE half &assign_add(half &a, half b) { +DINLINE half& assign_add(half& a, half b) { a = __hadd(a, b); return a; } -DINLINE float &assign_add(float &a, float b) { return a += b; } +DINLINE float& assign_add(float& a, float b) { return a += b; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } @@ -80,14 +80,14 @@ template <> DINLINE nv_bfloat16 downcast_s(float val) { return __float2bfloat16(val); } -DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { +DINLINE nv_bfloat16& 
assign_add(nv_bfloat16& a, nv_bfloat16 b) { a = __hadd(a, b); return a; } #endif template -DINLINE array_t &packed_assign_add(array_t &a, array_t b) { +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { #pragma unroll for (int i = 0; i < N; i++) { assign_add(a.data[i], b.data[i]); @@ -128,7 +128,7 @@ DINLINE O downcast(array_t val) { // prior memory accesses. Note: volatile writes will not be reordered against // other volatile writes. template -DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, +DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, int rank) { if (threadIdx.x < ngpus) { // reset flag for next time @@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]) - ; + while (!self_sg->start[blockIdx.x][threadIdx.x]); } __syncthreads(); } @@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, // barrier in the all reduce kernel. If it's the final synchronization barrier, // we don't need to make any visibility guarantees for prior memory accesses. template -DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, +DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, int rank) { __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. + // the memory model. if constexpr (!final_sync) __threadfence_system(); if (threadIdx.x < ngpus) { // reset flag for next time @@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]) - ; + while (!self_sg->end[blockIdx.x][threadIdx.x]); } if constexpr (!final_sync) __syncthreads(); } template -DINLINE P packed_reduce(const P *ptrs[], int idx) { +DINLINE P packed_reduce(const P* ptrs[], int idx) { A tmp = upcast(ptrs[0][idx]); #pragma unroll for (int i = 1; i < ngpus; i++) { @@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData *_dp, RankSignals sg, - volatile Signal *self_sg, T *__restrict__ result, + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; @@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1) // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { - ((P *)result)[idx] = - packed_reduce((const P **)&dp.ptrs[0], idx); + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } end_sync(sg, self_sg, rank); } template -DINLINE P *get_tmp_buf(volatile Signal *sg) { - return (P *)(((Signal *)sg) + 1); +DINLINE P* get_tmp_buf(volatile Signal* sg) { + return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData *_dp, RankSignals sg, - volatile Signal *self_sg, T *__restrict__ result, + 
cross_device_reduce_2stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; @@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1) int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; - const P *ptrs[ngpus]; - P *tmps[ngpus]; + const P* ptrs[ngpus]; + P* tmps[ngpus]; #pragma unroll for (int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; - ptrs[i] = (const P *)_dp->ptrs[target]; + ptrs[i] = (const P*)_dp->ptrs[target]; tmps[i] = get_tmp_buf
(sg.signals[target]); } auto tmp_out = tmps[0]; @@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1) int gather_from_rank = ((rank + i) % ngpus); if (gather_from_rank == ngpus - 1 || idx < part) { int dst_idx = gather_from_rank * part + idx; - ((P *)result)[dst_idx] = tmps[i][idx]; + ((P*)result)[dst_idx] = tmps[i][idx]; } } } @@ -261,14 +258,14 @@ class CustomAllreduce { // below are device pointers RankSignals sg_; - std::unordered_map buffers_; - Signal *self_sg_; + std::unordered_map buffers_; + Signal* self_sg_; // stores the registered device pointers from all ranks RankData *d_rank_data_base_, *d_rank_data_end_; - std::vector graph_unreg_buffers_; + std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers - std::map ipc_handles_; + std::map ipc_handles_; /** * meta is a pointer to device metadata and temporary buffer for allreduce. @@ -279,22 +276,22 @@ class CustomAllreduce { * note: this class does not own any device memory. Any required buffers * are passed in from the constructor */ - CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, - const cudaIpcMemHandle_t *handles, - const std::vector &offsets, int rank, + CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t* handles, + const std::vector& offsets, int rank, bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink), self_sg_(meta), - d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { - Signal *rank_sg; + Signal* rank_sg; if (i != rank_) { - char *handle = open_ipc_handle(&handles[i]); + char* handle = open_ipc_handle(&handles[i]); handle += offsets[i]; - rank_sg = (Signal *)handle; + rank_sg = (Signal*)handle; } else { rank_sg = self_sg_; } @@ -302,13 +299,13 @@ class CustomAllreduce { } } - char *open_ipc_handle(const void *ipc_handle) { + char* open_ipc_handle(const void* ipc_handle) { auto [it, new_handle] = - ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); + ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); if (new_handle) { - char *ipc_ptr; - CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, - *((const cudaIpcMemHandle_t *)ipc_handle), + char* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } @@ -323,7 +320,7 @@ class CustomAllreduce { std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; - void *base_ptr; + void* base_ptr; // note: must share the base address of each allocation, or we get wrong // address if (cuPointerGetAttribute(&base_ptr, @@ -331,8 +328,8 @@ class CustomAllreduce { (CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(cudaIpcGetMemHandle( - (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); - offsets[i] = ((char *)ptr) - ((char *)base_ptr); + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); } return std::make_pair(handles, offsets); } @@ -344,13 +341,13 @@ class CustomAllreduce { std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } - void register_buffer(const std::vector &handles, - const std::vector &offsets, void *self) { + void register_buffer(const std::vector& handles, + const std::vector& offsets, void* 
self) { check_rank_data_capacity(); RankData data; for (int i = 0; i < world_size_; i++) { if (i != rank_) { - char *handle = open_ipc_handle(handles[i].data()); + char* handle = open_ipc_handle(handles[i].data()); handle += offsets[i]; data.ptrs[i] = handle; } else { @@ -371,17 +368,17 @@ class CustomAllreduce { // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. void register_graph_buffers( - const std::vector &handles, - const std::vector> &offsets) { + const std::vector& handles, + const std::vector>& offsets) { auto num_buffers = graph_unreg_buffers_.size(); check_rank_data_capacity(num_buffers); std::vector rank_data(num_buffers); for (int i = 0; i < num_buffers; i++) { auto self_ptr = graph_unreg_buffers_[i]; - auto &rd = rank_data[i]; + auto& rd = rank_data[i]; for (int j = 0; j < world_size_; j++) { if (j != rank_) { - char *handle = + char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); handle += offsets[j][i]; rd.ptrs[j] = handle; @@ -405,7 +402,7 @@ class CustomAllreduce { * will cause contention on NVLink bus. */ template - void allreduce(cudaStream_t stream, T *input, T *output, int size, + void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) { auto d = packed_t::P::size; if (size % d != 0) @@ -418,7 +415,7 @@ class CustomAllreduce { std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit)); - RankData *ptrs; + RankData* ptrs; cudaStreamCaptureStatus status; CUDACHECK(cudaStreamIsCapturing(stream, &status)); if (status == cudaStreamCaptureStatusActive) { diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index c34a50389c21c..f7868233076cd 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -48,7 +48,7 @@ __global__ void dummy_kernel() { } template -__global__ void set_data(T *data, int size, int myRank) { +__global__ void set_data(T* data, int size, int myRank) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { data[idx] = myRank * 0.11f; @@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) { } template -__global__ void convert_data(const T *data1, const T *data2, double *fdata1, - double *fdata2, int size) { +__global__ void convert_data(const T* data1, const T* data2, double* fdata1, + double* fdata2, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { fdata1[idx] = data1[idx]; @@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1, } } -__global__ void init_rand(curandState_t *state, int size, int nRanks) { +__global__ void init_rand(curandState_t* state, int size, int nRanks) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { for (int i = 0; i < nRanks; i++) { @@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) { } template -__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, +__global__ void gen_data(curandState_t* state, T* data, double* ground_truth, int myRank, int nRanks, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { @@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, } template -void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, +void run(int 
myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, int data_size, bool performance_test) { - T *result; + T* result; cudaStream_t stream; CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); @@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t data_handles[8]; - vllm::Signal *buffer; - T *self_data_copy; + vllm::Signal* buffer; + T* self_data_copy; /** * Allocate IPC buffer * @@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, MPI_COMM_WORLD)); - void *rank_data; + void* rank_data; size_t rank_data_sz = 16 * 1024 * 1024; CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); std::vector offsets(nRanks, 0); vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, offsets, myRank); - auto *self_data = - reinterpret_cast(reinterpret_cast(buffer) + - sizeof(vllm::Signal) + data_size * sizeof(T)); + auto* self_data = + reinterpret_cast(reinterpret_cast(buffer) + + sizeof(vllm::Signal) + data_size * sizeof(T)); // hack buffer registration { std::vector handles; handles.reserve(nRanks); for (int i = 0; i < nRanks; i++) { - char *begin = (char *)&data_handles[i]; - char *end = (char *)&data_handles[i + 1]; + char* begin = (char*)&data_handles[i]; + char* end = (char*)&data_handles[i + 1]; handles.emplace_back(begin, end); } std::vector offsets(nRanks, @@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, fa.register_buffer(handles, offsets, self_data); } - double *ground_truth; + double* ground_truth; CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); - curandState_t *states; + curandState_t* states; CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, @@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, CUDACHECK(cudaStreamDestroy(stream)); } -int main(int argc, char **argv) { +int main(int argc, char** argv) { int nRanks, myRank; MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); @@ -296,7 +296,7 @@ int main(int argc, char **argv) { ncclUniqueId id; ncclComm_t comm; if (myRank == 0) ncclGetUniqueId(&id); - MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, + MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 91abd9e85b4bb..a634e1c3d4886 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -4,34 +4,32 @@ */ #pragma once -#include +#include -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) - -#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) -#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index e56b4d2204005..ca1c04bd880d9 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -11,26 +11,24 @@ #include #include - using __nv_bfloat16 = __hip_bfloat16; - using __nv_bfloat162 = __hip_bfloat162; +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; #endif namespace vllm { // TODO(woosuk): Further optimize this kernel. 
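For reference, the reformatted VLLM_DISPATCH_* macros above are thin wrappers around ATen's AT_DISPATCH_SWITCH / AT_DISPATCH_CASE, which define scalar_t inside the lambda for the dtype that matched. A minimal host-side usage sketch follows; my_sum and my_sum_impl are hypothetical names, the csrc/ include path is assumed, and none of this is part of the diff.

// Illustrative sketch (not part of the diff): dispatching a host function
// over float / half / bfloat16 with VLLM_DISPATCH_FLOATING_TYPES.
#include <torch/all.h>
#include "dispatch_utils.h"

template <typename scalar_t>
float my_sum_impl(const scalar_t* data, int64_t n) {
  float acc = 0.f;
  for (int64_t i = 0; i < n; ++i) acc += static_cast<float>(data[i]);
  return acc;
}

float my_sum(const torch::Tensor& t) {
  float out = 0.f;
  // The AT_DISPATCH_CASE machinery defines `scalar_t` inside the lambda.
  VLLM_DISPATCH_FLOATING_TYPES(t.scalar_type(), "my_sum", [&] {
    out = my_sum_impl<scalar_t>(t.data_ptr<scalar_t>(), t.numel());
  });
  return out;
}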
-template +template __global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float) input[blockIdx.x * hidden_size + idx]; + const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } variance = blockReduceSum(variance); @@ -40,12 +38,12 @@ __global__ void rms_norm_kernel( __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) input[blockIdx.x * hidden_size + idx]; - out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; } } - /* Converter structs for the conversion from torch types to HIP/CUDA types, and the associated type conversions within HIP/CUDA. These helpers need to be implemented for now because the relevant type conversion @@ -54,51 +52,68 @@ __global__ void rms_norm_kernel( Each struct should have the member static constexpr bool `exists`: If false, the optimized kernel is not used for the corresponding torch type. - If true, the struct should be fully defined as shown in the examples below. + If true, the struct should be fully defined as shown in the examples below. 
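The converter comment above describes the pattern being introduced: a primary _typeConvert template with exists = false, plus per-type specializations that supply convert() overloads in both directions so the kernel can accumulate in float and only touch the 16-bit type at load and store time. The self-contained sketch below reproduces the same trait pattern with a toy Fixed16 type standing in for __half / __nv_bfloat16; every name in it is illustrative, not a type used by the kernels.

// Illustrative, self-contained sketch of the converter-trait pattern described
// above; the toy Fixed16 type stands in for the 16-bit CUDA/HIP types.
#include <cstdint>
#include <cstdio>

struct Fixed16 { int16_t raw; };  // hypothetical 16-bit storage type

template <typename T>
struct TypeConvert { static constexpr bool exists = false; };

template <>
struct TypeConvert<Fixed16> {
  static constexpr bool exists = true;
  static float convert(Fixed16 x) { return x.raw / 256.0f; }         // widen on load
  static Fixed16 convert(float x) { return {int16_t(x * 256.0f)}; }  // narrow on store
};

template <typename T>
float mean_in_float(const T* data, int n) {
  static_assert(TypeConvert<T>::exists, "no converter for this type");
  float acc = 0.f;
  for (int i = 0; i < n; ++i) acc += TypeConvert<T>::convert(data[i]);
  return acc / n;
}

int main() {
  Fixed16 v[4] = {{256}, {512}, {768}, {1024}};  // 1.0, 2.0, 3.0, 4.0
  std::printf("%f\n", mean_in_float(v, 4));      // prints 2.5
}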
*/ -template -struct _typeConvert { static constexpr bool exists = false; }; +template +struct _typeConvert { + static constexpr bool exists = false; +}; #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion -template<> +template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __half; using packed_hip_type = __half2; __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } - __device__ static inline hip_type convert(float x) { return __float2half_rn(x); } - __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } }; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // CUDA_ARCH < 800 does not have BF16 support // TODO: Add in ROCm support once public headers handle bf16 maturely -template<> +template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; - __device__ static inline float convert(hip_type x) { return __bfloat162float(x); } - __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } - __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } - __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } }; -#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) + #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) /* Vector POD struct to generate vectorized and packed FP16/BF16 ops for appropriate specializations of fused_add_rms_norm_kernel. Only functions that are necessary in that kernel are implemented. Alignment to 16 bytes is required to use 128-bit global memory ops. 
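The _f16Vec comment above boils down to: pack `width` 16-bit elements into a 16-byte-aligned POD so each load and store becomes a single 128-bit memory operation, then treat the scalar buffers as arrays of that POD. Below is a host-side sketch of the same grouping idea, using float and memcpy for a well-defined load; the kernel itself reinterpret_casts aligned pointers instead, and Vec4 / add_inplace are illustrative names only.

// Illustrative sketch of the packed-vector idea behind _f16Vec: group elements
// into a 16-byte-aligned POD and process a whole group per loop iteration.
#include <cstring>
#include <cstdio>

struct alignas(16) Vec4 {
  float data[4];
  Vec4& operator+=(const Vec4& o) {
    for (int i = 0; i < 4; ++i) data[i] += o.data[i];
    return *this;
  }
};

void add_inplace(float* x, const float* y, int n) {
  // n is assumed to be a multiple of 4, mirroring the hidden_size % width == 0
  // requirement of the vectorized kernel.
  for (int i = 0; i < n; i += 4) {
    Vec4 a, b;
    std::memcpy(&a, x + i, sizeof(Vec4));  // the kernel uses reinterpret_cast instead
    std::memcpy(&b, y + i, sizeof(Vec4));
    a += b;
    std::memcpy(x + i, &a, sizeof(Vec4));
  }
}

int main() {
  float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float y[8] = {8, 7, 6, 5, 4, 3, 2, 1};
  add_inplace(x, y, 8);
  std::printf("%g %g\n", x[0], x[7]);  // 9 9
}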
*/ -template +template struct alignas(16) _f16Vec { - /* Not theoretically necessary that width is a power of 2 but should - almost always be the case for optimization purposes */ + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ static_assert(width > 0 && (width & (width - 1)) == 0, "Width is not a positive power of 2!"); using Converter = _typeConvert; @@ -108,51 +123,49 @@ struct alignas(16) _f16Vec { __device__ _f16Vec& operator+=(const _f16Vec& other) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i+1]}; - temp += T2{other.data[i], other.data[i+1]}; + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll - for (int i = 0; i < width; ++i) - data[i] += other.data[i]; +#pragma unroll + for (int i = 0; i < width; ++i) data[i] += other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const _f16Vec& other) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i+1]}; - temp *= T2{other.data[i], other.data[i+1]}; + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll - for (int i = 0; i < width; ++i) - data[i] *= other.data[i]; +#pragma unroll + for (int i = 0; i < width; ++i) data[i] *= other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const float scale) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); temp_f.x *= scale; temp_f.y *= scale; T2 temp = Converter::convert(temp_f); data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < width; ++i) { float temp = Converter::convert(data[i]) * scale; data[i] = Converter::convert(temp); @@ -164,13 +177,13 @@ struct alignas(16) _f16Vec { __device__ float sum_squares() const { float result = 0.0f; if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i+1]}); + float2 z = Converter::convert(T2{data[i], data[i + 1]}); result += z.x * z.x + z.y * z.y; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < width; ++i) { float x = Converter::convert(data[i]); result += x * x; @@ -184,15 +197,13 @@ struct alignas(16) _f16Vec { Additional optimizations we can make in this case are packed and vectorized operations, which help with the memory latency bottleneck. 
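Both fused_add_rms_norm_kernel specializations (the packed one below and the generic fallback after it) compute the same thing per token: add the input into the residual, take the mean of squares over the hidden dimension, and scale the residual by rsqrt(mean + epsilon) times the weight. A scalar reference sketch for a single token row, useful for reading the kernels against; the function name is illustrative and this is not launch code from the diff.

// Scalar reference for what fused_add_rms_norm computes on one token row;
// the CUDA kernels parallelize this over tokens (blocks) and elements (threads).
#include <cmath>
#include <cstdio>

void fused_add_rms_norm_ref(float* input, float* residual, const float* weight,
                            float epsilon, int hidden_size) {
  float sum_sq = 0.f;
  for (int i = 0; i < hidden_size; ++i) {
    residual[i] += input[i];               // z = input + residual, kept in residual
    sum_sq += residual[i] * residual[i];
  }
  const float s = 1.0f / std::sqrt(sum_sq / hidden_size + epsilon);
  for (int i = 0; i < hidden_size; ++i) {
    input[i] = residual[i] * s * weight[i];  // normalized output written to input
  }
}

int main() {
  float input[4] = {1, 2, 3, 4}, residual[4] = {0, 0, 0, 0};
  float weight[4] = {1, 1, 1, 1};
  fused_add_rms_norm_ref(input, residual, weight, 1e-6f, 4);
  std::printf("%f\n", input[0]);
}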
*/ -template -__global__ std::enable_if_t< - (width > 0) && _typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { // Sanity checks on our vector struct and type-punned pointer arithmetic static_assert(std::is_pod_v<_f16Vec>); static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); @@ -203,9 +214,12 @@ __global__ std::enable_if_t< /* These and the argument pointers are all declared `restrict` as they are not aliased in practice. Argument pointers should not be dereferenced in this kernel as that would be undefined behavior */ - auto* __restrict__ input_v = reinterpret_cast<_f16Vec*>(input); - auto* __restrict__ residual_v = reinterpret_cast<_f16Vec*>(residual); - auto* __restrict__ weight_v = reinterpret_cast*>(weight); + auto* __restrict__ input_v = + reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = + reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; @@ -215,10 +229,11 @@ __global__ std::enable_if_t< residual_v[id] = temp; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else variance = blockReduceSum(variance); + } else + variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -233,52 +248,50 @@ __global__ std::enable_if_t< } } - /* Generic fused_add_rms_norm_kernel The width field is not used here but necessary for other specializations. 
*/ -template -__global__ std::enable_if_t< - (width == 0) || !_typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { scalar_t z = input[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx]; - float x = (float) z; + float x = (float)z; variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else variance = blockReduceSum(variance); + } else + variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) residual[blockIdx.x * hidden_size + idx]; - input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)residual[blockIdx.x * hidden_size + idx]; + input[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; } } -} // namespace vllm +} // namespace vllm -void rms_norm( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +void rms_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -286,40 +299,27 @@ void rms_norm( dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "rms_norm_kernel", - [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + }); } -#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "fused_add_rms_norm_kernel", \ - [&] { \ - vllm::fused_add_rms_norm_kernel \ - <<>>( \ - input.data_ptr(), \ - residual.data_ptr(), \ - weight.data_ptr(), \ - epsilon, \ - num_tokens, \ - hidden_size); \ - }); - -void fused_add_rms_norm( - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& residual, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + 
vllm::fused_add_rms_norm_kernel \ + <<>>(input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), epsilon, \ + num_tokens, hidden_size); \ + }); + +void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -342,8 +342,8 @@ void fused_add_rms_norm( auto inp_ptr = reinterpret_cast(input.data_ptr()); auto res_ptr = reinterpret_cast(residual.data_ptr()); auto wt_ptr = reinterpret_cast(weight.data_ptr()); - bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ - && wt_ptr % 16 == 0; + bool ptrs_are_aligned = + inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; if (ptrs_are_aligned && hidden_size % 8 == 0) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp deleted file mode 100644 index 35c328499a22d..0000000000000 --- a/csrc/moe/moe_ops.cpp +++ /dev/null @@ -1,7 +0,0 @@ -#include "moe_ops.h" - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); -} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index a01be3e426d72..a251730aa765a 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -1,9 +1,7 @@ #pragma once -#include +#include -void topk_softmax( - torch::Tensor& topk_weights, - torch::Tensor& topk_indices, - torch::Tensor& token_expert_indices, - torch::Tensor& gating_output); +void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 8c65f40fe836a..de9747b602524 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -16,18 +16,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include #include #include +#include "../cuda_compat.h" -#include -#include +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) namespace vllm { namespace moe { -static constexpr int WARP_SIZE = 32; - /// Aligned array type template < typename T, @@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); } // From this point, thread max in all the threads have the max within the row. 
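For reference, the math implemented by the rms_norm and fused_add_rms_norm kernels above can be sketched in eager PyTorch. This is a minimal sketch of the semantics only: the real kernels work in place, and the vectorized width-8 specialization is selected only when the input, residual, and weight pointers are all 16-byte aligned and hidden_size is a multiple of 8.

import torch

def fused_add_rms_norm_ref(x, residual, weight, eps):
    # The kernel adds x into residual in place and overwrites x with the
    # normalized result; this sketch returns both values instead.
    z = x.float() + residual.float()
    inv_rms = torch.rsqrt(z.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (z * inv_rms).to(x.dtype) * weight
    return out, z.to(x.dtype)

# rms_norm is the same computation without the residual add-and-writeback.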
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); } // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables @@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); - int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); // We want lower indices to "win" in every thread so we break ties this way if (other_max > max_val || (other_max == max_val && other_expert < expert)) @@ -383,7 +390,7 @@ struct TopkConstants { static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); - static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; static constexpr int THREADS_PER_ROW = EXPERTS / VPT; static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; @@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; - static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp new file mode 100644 index 0000000000000..243752b9a9e8c --- /dev/null +++ b/csrc/moe/torch_bindings.cpp @@ -0,0 +1,12 @@ +#include "registration.h" +#include "moe_ops.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + // Apply topk softmax to the gating outputs. + m.def( + "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! 
" + "token_expert_indices, Tensor gating_output) -> ()"); + m.impl("topk_softmax", torch::kCUDA, &topk_softmax); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index e01b23685ef4e..1f8d75da83bb8 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -7,119 +7,128 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) namespace vllm { namespace { -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { - // don't worry about overflow because num_experts is relatively small - return row * total_col + col; -} +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; } +} // namespace template -__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, - int32_t *sorted_token_ids, - int32_t *expert_ids, - int32_t *total_tokens_post_pad, - int32_t num_experts, - int32_t block_size, - size_t numel) { - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - extern __shared__ int32_t shared_mem[]; - - int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) - int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are assigned - * to expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding blocks - * and stores the corresponding expert_id for each block. 
- */ - for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { - expert_ids[i / block_size] = threadIdx.x; +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + + int32_t* tokens_cnts = + shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + int32_t* cumsum = + shared_mem + (num_experts + 1) * + num_experts; // 1d tensor with shape (num_experts + 1) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; } - - /** - * Each thread processes a token shard, calculating the index of each token after - * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and - * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], - * where * represents a padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] - * stores the indices of the tokens processed by the expert with expert_id within - * the current thread's token shard. - */ - int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } -} + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). 
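Two eager references for the MoE pieces in this area may help: the topk_softmax op registered in the new torch_bindings.cpp above (now exposed through torch.ops rather than a pybind11 module) computes a softmax over the gating logits followed by top-k selection, and moe_align_block_size produces the padded, expert-sorted token layout that the comment above walks through. A CPU sketch (function and variable names here are illustrative):

import torch

def topk_softmax_ref(gating_output, topk):
    # Per token: softmax over experts, then pick the top-k probabilities.
    probs = torch.softmax(gating_output.float(), dim=-1)
    return probs.topk(topk, dim=-1)

def moe_align_block_size_ref(topk_ids, num_experts, block_size, pad=-1):
    # `pad` stands in for the padding value preset by the Python caller.
    counts = [0] * num_experts
    for e in topk_ids:
        counts[e] += 1
    cumsum = [0]
    for e in range(num_experts):
        padded = -(-counts[e] // block_size) * block_size   # CEILDIV * block_size
        cumsum.append(cumsum[-1] + padded)
    sorted_token_ids = [pad] * cumsum[-1]
    expert_ids = [e for e in range(num_experts)
                  for _ in range((cumsum[e + 1] - cumsum[e]) // block_size)]
    fill = list(cumsum[:-1])
    for i, e in enumerate(topk_ids):
        sorted_token_ids[fill[e]] = i
        fill[e] += 1
    return sorted_token_ids, expert_ids, cumsum[-1]

# moe_align_block_size_ref([0, 1, 2, 1, 2, 3, 0, 3, 4], num_experts=5, block_size=4)
# reproduces the example in the comment above:
# [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *] with * == pad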
+ */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } } - -void moe_align_block_size( - torch::Tensor topk_ids, - int num_experts, - int block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors - const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); +} // namespace vllm + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t shared_mem = + ((num_experts + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); // set dynamic shared mem auto kernel = vllm::moe_align_block_size_kernel; - AT_CUDA_CHECK( - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); kernel<<<1, num_experts, shared_mem, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - num_experts, - block_size, + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, topk_ids.numel()); - }); + }); } diff --git a/csrc/ops.h b/csrc/ops.h index 9541adcb3de88..8a92afdc81a9b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,206 +1,152 @@ #pragma once -#include +#include +#include void paged_attention_v1( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale); + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); void paged_attention_v2( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& 
query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale); - -void rms_norm( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& weight, - float epsilon); - -void fused_add_rms_norm( - torch::Tensor& input, - torch::Tensor& residual, - torch::Tensor& weight, - float epsilon); - -void rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox); - -void batched_rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox, - int rot_dim, - torch::Tensor& cos_sin_cache_offsets); - -void silu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_tanh_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_new( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_fast( - torch::Tensor& out, - torch::Tensor& input); + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); + +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + double epsilon); + +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, double epsilon); + +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int64_t head_size, + torch::Tensor& cos_sin_cache, bool is_neox); + +void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int64_t head_size, + torch::Tensor& cos_sin_cache, bool is_neox, + int64_t rot_dim, + torch::Tensor& cos_sin_cache_offsets); + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_new(torch::Tensor& out, torch::Tensor& input); + +void gelu_fast(torch::Tensor& out, torch::Tensor& input); + +void gelu_quick(torch::Tensor& out, torch::Tensor& input); #ifndef USE_ROCM -torch::Tensor aqlm_gemm( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias -); - -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes -); - -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy); - 
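One pattern running through these rewritten declarations is the scalar type change from int/float to int64_t/double: TORCH_LIBRARY schemas only expose 64-bit `int` and `float` (double) scalars, so the bound C++ signatures must match. From the Python side nothing changes; ordinary ints and floats are passed through torch.ops. A hedged sketch (the `_C` namespace and the exact schema string are assumptions, not taken from this diff):

import torch

# Hypothetical: assuming the main extension registers itself under the `_C`
# namespace with a schema along the lines of
#   "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> ()"
# the Python-side epsilon is a plain float that binds to the C++ double.
out = torch.empty(4, 4096, device="cuda", dtype=torch.float16)
x = torch.randn_like(out)
w = torch.ones(4096, device="cuda", dtype=torch.float16)
torch.ops._C.rms_norm(out, x, w, 1e-6)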
-torch::Tensor marlin_gemm( - torch::Tensor& a, - torch::Tensor& b_q_weight, - torch::Tensor& b_scales, - torch::Tensor& workspace, - int64_t size_m, - int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_gemm( - torch::Tensor &a, - torch::Tensor &b_q_weight, - torch::Tensor &b_scales, - torch::Tensor &g_idx, - torch::Tensor &perm, - torch::Tensor &workspace, - int64_t num_bits, - int64_t size_m, - int64_t size_n, - int64_t size_k, - bool is_k_full); - -torch::Tensor gptq_marlin_repack( - torch::Tensor &b_q_weight, - torch::Tensor &perm, - int64_t size_k, - int64_t size_n, - int64_t num_bits); +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias); + +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes); + +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int64_t split_k_iters); + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy); + +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t size_m, int64_t size_n, int64_t size_k); + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k); + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full); + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits); + +bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); + +void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias); + #endif -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table); - -torch::Tensor gptq_gemm( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama, - int bit); - -void gptq_shuffle( - torch::Tensor q_weight, - torch::Tensor q_perm, - int bit); - -void static_scaled_fp8_quant( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& scale); - -void dynamic_scaled_fp8_quant( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& scale); - -void moe_align_block_size( - torch::Tensor topk_ids, - int num_experts, - int block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); +void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& scale); + +void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor& scales); + +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table); + +torch::Tensor gptq_gemm(torch::Tensor 
a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int64_t bit); + +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); + +void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + +void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM -using fptr_t = uint64_t; -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, - bool full_nvlink); -bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +using fptr_t = int64_t; +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int64_t rank, + bool full_nvlink); +bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, bool full_nvlink); -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out); +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out); void dispose(fptr_t _fa); -int meta_size(); -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets); -std::pair, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector> &offsets); +int64_t meta_size(); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets); +std::tuple> get_graph_buffer_ipc_meta( + fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets); #endif diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index d80cb6973fad6..97184a8735593 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -7,14 +7,10 @@ namespace vllm { -template +template inline __device__ void apply_token_rotary_embedding( - scalar_t* __restrict__ arr, - const scalar_t* __restrict__ cos_ptr, - const scalar_t* __restrict__ sin_ptr, - int rot_offset, - int embed_dim) -{ + scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { int x_index, y_index; scalar_t cos, sin; if (IS_NEOX) { @@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding( arr[y_index] = y * cos + x * sin; } -template +template inline __device__ void apply_rotary_embedding( - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* cache_ptr, - const int head_size, - const int num_heads, - const int num_kv_heads, - const int rot_dim, - const int token_idx, - const int64_t query_stride, - const int64_t key_stride) -{ + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // 
head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, const int head_size, const int num_heads, + const int num_kv_heads, const int rot_dim, const int token_idx, + const int64_t query_stride, const int64_t key_stride) { const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding(query + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding( + query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } const int nk = num_kv_heads * embed_dim; @@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding(key + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } } -template +template __global__ void rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, - const int num_heads, - const int num_kv_heads, - const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. 
const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); } -template +template __global__ void batched_rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, - const int num_heads, - const int num_kv_heads, - const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] + // or [num_tokens] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. 
const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + const scalar_t* cache_ptr = + cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); } -} // namespace vllm +} // namespace vllm void rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] - int head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox) { + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int64_t head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { int64_t num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; @@ -132,39 +138,24 @@ void rotary_embedding( int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), - "rotary_embedding", - [&] { - if (is_neox) { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } else { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, + query_stride, key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size); + } + }); } /* @@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together and process in batched manner. 
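Before the batched variant below, a semantic reference for the NeoX path of rotary_embedding above may help. This is a sketch only: it rotates the query in fp32, skips the key (which is transformed identically), and ignores the interleaved non-NeoX layout.

import torch

def rotary_embedding_neox_ref(positions, query, cos_sin_cache, head_size, rot_dim):
    # query: [num_tokens, num_heads * head_size]
    # cos_sin_cache: [max_position, rot_dim], first half cos, second half sin.
    num_tokens = positions.numel()
    q = query.reshape(num_tokens, -1, head_size).float().clone()
    cos, sin = cos_sin_cache[positions].float().chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)     # broadcast over heads
    half = rot_dim // 2
    x, y = q[..., :half], q[..., half:rot_dim]        # only rot_dim dims rotate
    new_x = x * cos - y * sin
    new_y = y * cos + x * sin
    q[..., :half], q[..., half:rot_dim] = new_x, new_y
    return q.reshape(query.shape).to(query.dtype)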
*/ void batched_rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] - int head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, - int rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int64_t head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox, int64_t rot_dim, + torch::Tensor& cos_sin_cache_offsets // [num_tokens] ) { int64_t num_tokens = cos_sin_cache_offsets.size(0); int num_heads = query.size(-1) / head_size; @@ -188,39 +180,24 @@ void batched_rotary_embedding( int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), - "rotary_embedding", - [&] { - if (is_neox) { - vllm::batched_rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } else { - vllm::batched_rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } + }); } diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 19c058cacfbc4..2c8d007d8719f 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -16,44 +16,68 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 512) \ f(in_T, out_T, W_T, narrow, 640) \ f(in_T, out_T, W_T, narrow, 768) \ + f(in_T, out_T, W_T, narrow, 896) \ f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1152) \ + f(in_T, out_T, W_T, narrow, 1216) \ f(in_T, out_T, W_T, narrow, 1280) \ f(in_T, out_T, W_T, narrow, 1536) \ + f(in_T, out_T, W_T, narrow, 1664) \ f(in_T, out_T, W_T, narrow, 1728) \ f(in_T, out_T, W_T, narrow, 1792) \ f(in_T, out_T, W_T, narrow, 2048) \ + f(in_T, out_T, W_T, narrow, 2240) \ f(in_T, out_T, W_T, narrow, 2304) \ + f(in_T, out_T, W_T, narrow, 2368) \ + f(in_T, out_T, W_T, narrow, 2432) \ 
f(in_T, out_T, W_T, narrow, 2560) \ f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 3072) \ + f(in_T, out_T, W_T, narrow, 3328) \ f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3584) \ + f(in_T, out_T, W_T, narrow, 3712) \ f(in_T, out_T, W_T, narrow, 4096) \ + f(in_T, out_T, W_T, narrow, 4480) \ f(in_T, out_T, W_T, narrow, 4608) \ + f(in_T, out_T, W_T, narrow, 4736) \ + f(in_T, out_T, W_T, narrow, 4864) \ f(in_T, out_T, W_T, narrow, 5120) \ f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5632) \ + f(in_T, out_T, W_T, narrow, 5888) \ f(in_T, out_T, W_T, narrow, 6144) \ + f(in_T, out_T, W_T, narrow, 6400) \ f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 7168) \ + f(in_T, out_T, W_T, narrow, 7424) \ f(in_T, out_T, W_T, narrow, 8192) \ + f(in_T, out_T, W_T, narrow, 8960) \ f(in_T, out_T, W_T, narrow, 9216) \ + f(in_T, out_T, W_T, narrow, 9472) \ f(in_T, out_T, W_T, narrow, 10240) \ f(in_T, out_T, W_T, narrow, 11008) \ + f(in_T, out_T, W_T, narrow, 11264) \ f(in_T, out_T, W_T, narrow, 12288) \ f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 14784) \ + f(in_T, out_T, W_T, narrow, 14848) \ f(in_T, out_T, W_T, narrow, 15360) \ f(in_T, out_T, W_T, narrow, 16384) \ + f(in_T, out_T, W_T, narrow, 18944) \ f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 22016) \ + f(in_T, out_T, W_T, narrow, 22528) \ f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 27392) \ + f(in_T, out_T, W_T, narrow, 27648) \ f(in_T, out_T, W_T, narrow, 28672) \ + f(in_T, out_T, W_T, narrow, 29568) \ + f(in_T, out_T, W_T, narrow, 29696) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ f(in_T, out_T, W_T, narrow, 32512) \ @@ -62,6 +86,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 43264) \ f(in_T, out_T, W_T, narrow, 49152) \ + f(in_T, out_T, W_T, narrow, 49408) \ + f(in_T, out_T, W_T, narrow, 60544) \ + f(in_T, out_T, W_T, narrow, 60672) \ f(in_T, out_T, W_T, narrow, 64000) \ f(in_T, out_T, W_T, narrow, 64256) \ f(in_T, out_T, W_T, narrow, 64512) \ @@ -71,12 +98,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 128000) \ f(in_T, out_T, W_T, narrow, 128256) \ f(in_T, out_T, W_T, narrow, 128512) \ + + // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA // and vllm/tests/lora/test_punica.py -// Used for defining kernels going from the variety of +// Used for defining kernels going from the variety of // dim in to the narrow dim out - // Using it for the fully sharded column + // Using it for the fully sharded column // parallel LoRA A which splits the rank dim #define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \ f(in_T, out_T, W_T, 128, narrow) \ @@ -84,44 +113,68 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 512, narrow) \ f(in_T, out_T, W_T, 640, narrow) \ f(in_T, out_T, W_T, 768, narrow) \ + f(in_T, out_T, W_T, 896, narrow) \ f(in_T, out_T, W_T, 1024, narrow) \ f(in_T, out_T, W_T, 1152, narrow) \ + f(in_T, out_T, W_T, 1216, narrow) \ f(in_T, out_T, W_T, 1280, narrow) \ f(in_T, out_T, W_T, 1536, narrow) \ + f(in_T, out_T, W_T, 1664, narrow) \ f(in_T, out_T, W_T, 1728, narrow) \ f(in_T, out_T, W_T, 1792, narrow) \ f(in_T, out_T, W_T, 2048, 
narrow) \ + f(in_T, out_T, W_T, 2240, narrow) \ f(in_T, out_T, W_T, 2304, narrow) \ + f(in_T, out_T, W_T, 2368, narrow) \ + f(in_T, out_T, W_T, 2432, narrow) \ f(in_T, out_T, W_T, 2560, narrow) \ f(in_T, out_T, W_T, 2752, narrow) \ f(in_T, out_T, W_T, 2816, narrow) \ f(in_T, out_T, W_T, 3072, narrow) \ + f(in_T, out_T, W_T, 3328, narrow) \ f(in_T, out_T, W_T, 3456, narrow) \ f(in_T, out_T, W_T, 3584, narrow) \ + f(in_T, out_T, W_T, 3712, narrow) \ f(in_T, out_T, W_T, 4096, narrow) \ + f(in_T, out_T, W_T, 4480, narrow) \ f(in_T, out_T, W_T, 4608, narrow) \ + f(in_T, out_T, W_T, 4736, narrow) \ + f(in_T, out_T, W_T, 4864, narrow) \ f(in_T, out_T, W_T, 5120, narrow) \ f(in_T, out_T, W_T, 5504, narrow) \ f(in_T, out_T, W_T, 5632, narrow) \ + f(in_T, out_T, W_T, 5888, narrow) \ f(in_T, out_T, W_T, 6144, narrow) \ + f(in_T, out_T, W_T, 6400, narrow) \ f(in_T, out_T, W_T, 6848, narrow) \ f(in_T, out_T, W_T, 6912, narrow) \ f(in_T, out_T, W_T, 7168, narrow) \ + f(in_T, out_T, W_T, 7424, narrow) \ f(in_T, out_T, W_T, 8192, narrow) \ + f(in_T, out_T, W_T, 8960, narrow) \ f(in_T, out_T, W_T, 9216, narrow) \ + f(in_T, out_T, W_T, 9472, narrow) \ f(in_T, out_T, W_T, 10240, narrow) \ f(in_T, out_T, W_T, 11008, narrow) \ + f(in_T, out_T, W_T, 11264, narrow) \ f(in_T, out_T, W_T, 12288, narrow) \ f(in_T, out_T, W_T, 13696, narrow) \ f(in_T, out_T, W_T, 13824, narrow) \ f(in_T, out_T, W_T, 14336, narrow) \ + f(in_T, out_T, W_T, 14784, narrow) \ + f(in_T, out_T, W_T, 14848, narrow) \ f(in_T, out_T, W_T, 15360, narrow) \ f(in_T, out_T, W_T, 16384, narrow) \ + f(in_T, out_T, W_T, 18944, narrow) \ f(in_T, out_T, W_T, 20480, narrow) \ f(in_T, out_T, W_T, 22016, narrow) \ + f(in_T, out_T, W_T, 22528, narrow) \ f(in_T, out_T, W_T, 24576, narrow) \ f(in_T, out_T, W_T, 27392, narrow) \ + f(in_T, out_T, W_T, 27648, narrow) \ f(in_T, out_T, W_T, 28672, narrow) \ + f(in_T, out_T, W_T, 29568, narrow) \ + f(in_T, out_T, W_T, 29696, narrow) \ f(in_T, out_T, W_T, 32000, narrow) \ f(in_T, out_T, W_T, 32256, narrow) \ f(in_T, out_T, W_T, 32512, narrow) \ @@ -130,6 +183,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 36864, narrow) \ f(in_T, out_T, W_T, 43264, narrow) \ f(in_T, out_T, W_T, 49152, narrow) \ + f(in_T, out_T, W_T, 49408, narrow) \ + f(in_T, out_T, W_T, 60544, narrow) \ + f(in_T, out_T, W_T, 60672, narrow) \ f(in_T, out_T, W_T, 64000, narrow) \ f(in_T, out_T, W_T, 64256, narrow) \ f(in_T, out_T, W_T, 64512, narrow) \ diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index dad8805c750cb..8a3b8403b4a6f 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -1,8 +1,14 @@ #pragma once #include +#ifndef USE_ROCM #include +#else +#include +#endif +#ifndef USE_ROCM #include +#endif #include #include #include @@ -11,6 +17,24 @@ namespace cg = cooperative_groups; +#ifdef USE_ROCM +template +__host__ __device__ +inline void* memcpy_blocking(void *dst, const void *src) { + // Does not handle the case of long datatypes + char *d = reinterpret_cast(dst); + const char *s = reinterpret_cast(src); + size_t i = 0; +#pragma unroll + for (i = 0; i < len; ++i) { + d[i] = s[i]; + } + return dst; +} +#endif + +#ifndef USE_ROCM + // nthrs = (32, 4) template +__global__ void +bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = 
blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + if (idx < 0) { + return; + } + + size_t j = blockIdx.x; + constexpr size_t tile_size = tx * ty * vec_size; + constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; + __shared__ float y_warpwise[ty]; + + float y = 0; + vec_t x_vec; + vec_t w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { + x_vec.load(X + (batch_idx * feat_in) + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + } + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += VLLM_SHFL_DOWN_SYNC(sum, offset); + } + + __syncthreads(); + + if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { + y += sum; + } + } + + if (threadIdx.x == 0) { + y_warpwise[threadIdx.y] = y; + } + __syncthreads(); + + float y_write = 0.f; +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y_write += y_warpwise[i]; + } + + // write Y; + if (threadIdx.x == 0 && threadIdx.y == 0) { + size_t y_idx = batch_idx * full_y_size + y_offset + j; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(y_write)); + } +} + +#endif + // nthrs = (2, 16, 4) template @@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, float sum = 0.f; #pragma unroll for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif } cg::thread_block_tile g = cg::tiled_partition(block); @@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, sum = g.shfl(sum, 0); if (threadIdx.x == 0) { +#ifndef USE_ROCM Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + threadIdx.z * ty + threadIdx.y] += static_cast(sum); +#else + size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); +#endif } } @@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, scale); } } else { +#ifndef USE_ROCM static_assert(feat_in % (vec_size * 32) == 0 || feat_in % (vec_size * 16) == 0 || feat_in % (vec_size * 8) == 0); @@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, full_y_size, num_layers, layer_idx, scale); } +#else + constexpr size_t rocm_warp_size = warpSize; + +#define CHECK_INPUT_TILEABLE_BY(vec_size_) \ + feat_in % (rocm_warp_size * vec_size_) == 0 + +#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \ + if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \ + constexpr size_t vec_size_shrink = vec_size_; \ + constexpr int tx = tx_; \ + constexpr int ty = ty_; \ + dim3 nblks(feat_out, batch_size); \ + dim3 nthrs(tx, ty); \ + bgmv_shrink_kernel \ + <<>>(Y, X, W, indicies, y_offset, \ + full_y_size, num_layers, layer_idx, \ + scale); \ + } + + static_assert(CHECK_INPUT_TILEABLE_BY(32) || + CHECK_INPUT_TILEABLE_BY(16) || + CHECK_INPUT_TILEABLE_BY( 8) || + CHECK_INPUT_TILEABLE_BY( 4) || + CHECK_INPUT_TILEABLE_BY( 2) || + CHECK_INPUT_TILEABLE_BY( 1)); 
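Setting aside the CUDA/ROCm tiling details, both the shrink and expand paths of bgmv compute the same batched gather-matrix-vector product. A schematic eager version, with the weight layout and the skip-on-negative-index behavior inferred from the indexing above (treat it as an approximation, not the kernel contract):

import torch

def bgmv_ref(Y, X, W, indices, layer_idx, scale, y_offset):
    # W: [num_loras, num_layers, feat_out, feat_in]
    # indices[b] < 0 means "no LoRA for row b", which the kernels skip.
    feat_out = W.shape[2]
    for b in range(X.shape[0]):
        if indices[b] < 0:
            continue
        w = W[indices[b], layer_idx].float()          # gathered LoRA slice
        Y[b, y_offset:y_offset + feat_out] += (
            scale * (w @ X[b].float())
        ).to(Y.dtype)
    return Y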
+ + LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) + +#undef CHECK_INPUT_TILEABLE_BY +#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM +#endif } } diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh index cf00d869cf635..2738892e6dc4a 100644 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -1,8 +1,6 @@ #ifndef VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_ -#include -#include #ifdef FLASHINFER_USE_FP8 #include #endif @@ -10,6 +8,9 @@ #include +#include "../type_convert.h" +#include "../../cuda_compat.h" + #define FLASHINFER_INLINE \ inline __attribute__((always_inline)) __device__ __host__ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cu similarity index 97% rename from csrc/punica/punica_ops.cc rename to csrc/punica/punica_ops.cu index 8797fde85744a..dd29820144b34 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cu @@ -1,12 +1,11 @@ -#include -#include -#include +#include #include #include +#include "type_convert.h" +#include "../cuda_compat.h" #include "bgmv/bgmv_config.h" -namespace { //====== utils ====== @@ -89,7 +88,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, } void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, float scale) { + torch::Tensor indicies, int64_t layer_idx, double scale) { CHECK_INPUT(y); CHECK_INPUT(x); CHECK_INPUT(w); @@ -321,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, - float scale, int64_t h_in, int64_t h_out, + double scale, int64_t h_in, int64_t h_out, int64_t y_offset) { CHECK_INPUT(y); CHECK_INPUT(x); @@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); } - -} // namespace - -//====== pybind ====== - -#define DEFINE_pybind(name) m.def(#name, &name, #name); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); - m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, - "dispatch_bgmv_low_level"); -} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h new file mode 100644 index 0000000000000..5d625d0564f75 --- /dev/null +++ b/csrc/punica/punica_ops.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, double scale); + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + double scale, int64_t h_in, int64_t h_out, + int64_t y_offset); diff --git a/csrc/punica/torch_bindings.cpp b/csrc/punica/torch_bindings.cpp new file mode 100644 index 0000000000000..894e229b6d9db --- /dev/null +++ b/csrc/punica/torch_bindings.cpp @@ -0,0 +1,18 @@ +#include 
"registration.h" +#include "punica_ops.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + m.def( + "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int " + "layer_idx, float scale) -> ()"); + m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); + + m.def( + "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w," + "Tensor indicies, int layer_idx," + "float scale, int h_in, int h_out," + "int y_offset) -> ()"); + m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h new file mode 100644 index 0000000000000..dff7ce49283d7 --- /dev/null +++ b/csrc/punica/type_convert.h @@ -0,0 +1,82 @@ +#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ +#define CSRC__PUNICA__TYPE_CONVERT_H__ + +#ifndef USE_ROCM + +#include +#include + +#else + +#include +#include + +#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ + +typedef __half nv_half; +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { + return __hip_bfloat162{val, val}; +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { + return __hip_bfloat162{vall, valr}; +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T_dst convert_type(T_src val) { + return static_cast(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__half, float>(__half val) { + return __half2float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half convert_type(float val) { + return __float2half(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 convert_type(float val) { + return __float2bfloat16(val); +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T vllm_add(T a, T b) { + return a + b; +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half vllm_add<__half>(__half a, __half b) { + return __hadd(a, b); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { + return __hadd(a, b); +} + +#undef __TYPE_CONVERT__HOST_DEVICE__ + +#endif // USE_ROCM + +#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp deleted file mode 100644 index 173e0b1732e13..0000000000000 --- a/csrc/pybind.cpp +++ /dev/null @@ -1,136 +0,0 @@ -#include "cache.h" -#include "cuda_utils.h" -#include "ops.h" -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // vLLM custom ops - pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); - - // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); - - // Activation ops - ops.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - ops.def( - "gelu_and_mul", - &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def( - "gelu_tanh_and_mul", - &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - 
ops.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); - - // Layernorm - ops.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); - - ops.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); - - // Rotary embedding - ops.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); - - ops.def( - "batched_rotary_embedding", - &batched_rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"); - -// Quantization ops -#ifndef USE_ROCM - ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); - ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); - ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); - ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); -#endif - - ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); - ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); - ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); - ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); - ops.def( - "moe_align_block_size", - &moe_align_block_size, - "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); - - // Cache ops - pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); - cache_ops.def( - "reshape_and_cache_flash", - &reshape_and_cache_flash, - "Reshape the key and value tensors and cache them"); - cache_ops.def( - "convert_fp8", - &convert_fp8, - "Convert the key and value cache to fp8 data type"); - - // Cuda utils - pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); - cuda_utils.def( - "get_device_attribute", - &get_device_attribute, - "Gets the specified device attribute."); - - cuda_utils.def( - "get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute, - "Gets the maximum shared memory per block device attribute."); - -#ifndef USE_ROCM - // Custom all-reduce kernels - pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); - custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); - custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar"); - custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg"); - custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg"); - custom_ar.def("dispose", &dispose, "dispose"); - custom_ar.def("meta_size", &meta_size, "meta_size"); - custom_ar.def("register_buffer", ®ister_buffer, "register_buffer"); - custom_ar.def("get_graph_buffer_ipc_meta", 
&get_graph_buffer_ipc_meta, - "get_graph_buffer_ipc_meta"); - custom_ar.def("register_graph_buffers", ®ister_graph_buffers, - "register_graph_buffers"); -#endif - -} diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 4415316e1e8cd..8fb9856800867 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -18,39 +18,35 @@ #include #include #include -#include +#include #include #include #include #include - namespace vllm { namespace aqlm { __global__ void Code1x16MatVec( - const int4* __restrict__ A, - const int4* __restrict__ B, - int4* __restrict__ C, - const int4* __restrict__ codebook, - const int prob_m, - const int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, + const int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -67,8 +63,7 @@ __global__ void Code1x16MatVec( // We pad shared memory to avoid bank conflicts during reads __syncthreads(); for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) - sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; } __syncthreads(); b_gl_rd += 32 * 8; @@ -76,22 +71,19 @@ __global__ void Code1x16MatVec( int b_sh_rd = 9 * (threadIdx.x % 32); if (pred && a_gl_rd < a_gl_end) { const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { uint32_t dec[4]; - // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't - // actually help us; this brings > 2x speedup. - asm volatile ( - "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*) &codebook[enc[i]]) - ); + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. 
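// Illustrative device-side sketch: the ld.cg.global.v4.u32 inline PTX used here
// streams each codebook entry through L2 only. Assuming the int4 overload of
// the __ldcg() intrinsic on the target toolkit, a rough equivalent is:
#include <cstdint>

__device__ inline int4 load_codebook_entry_l2_only(const int4* codebook,
                                                   uint16_t code) {
  // .cg loads cache at the global (L2) level and bypass L1, so one-shot
  // codebook reads do not evict reusable data from L1.
  return __ldcg(codebook + code);
}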
+ asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); half2* a = reinterpret_cast(&dec); half2* b = reinterpret_cast(&sh_b[b_sh_rd]); half2 res2 = {}; - #pragma unroll - for (int j = 0; j < 4; j++) - res2 = __hfma2(a[j], b[j], res2); +#pragma unroll + for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); res += __half2float(res2.x) + __half2float(res2.y); b_sh_rd++; } @@ -100,37 +92,33 @@ __global__ void Code1x16MatVec( } if (pred) { - #pragma unroll - for (int i = 16; i > 0; i /= 2) - res += __shfl_down_sync(0xffffffff, res, i); +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); if (threadIdx.x % 32 == 0) reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); } } __global__ void Code2x8MatVec( - const int4* __restrict__ A, - const int4* __restrict__ B, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. 
auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -148,9 +136,8 @@ __global__ void Code2x8MatVec( for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { int4 dec = codebook[i]; - #pragma unroll - for (int j = 0; j < 8; j++) - sh_code[8 * i + (j + lane) % 8] = dec; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; } __syncthreads(); @@ -161,8 +148,7 @@ __global__ void Code2x8MatVec( // We pad shared memory to avoid bank conflicts during reads __syncthreads(); for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) - sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; } __syncthreads(); b_gl_rd += 32 * 8; @@ -170,13 +156,15 @@ __global__ void Code2x8MatVec( int b_sh_rd = 9 * (threadIdx.x % 32); if (pred && a_gl_rd < a_gl_end) { const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { - half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); half2 res2 = {}; - #pragma unroll +#pragma unroll for (int j = 0; j < 4; j++) res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); res += __half2float(res2.x) + __half2float(res2.y); @@ -187,36 +175,31 @@ __global__ void Code2x8MatVec( } if (pred) { - #pragma unroll - for (int i = 16; i > 0; i /= 2) - res += __shfl_down_sync(0xffffffff, res, i); +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); if (threadIdx.x % 32 == 0) reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); } } - __global__ void Code1x16Dequant( - const int4* __restrict__ A, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m. - const int codebook_stride // as int4 + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long, sums to m. + const int codebook_stride // as int4 ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. 
auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -231,17 +214,15 @@ __global__ void Code1x16Dequant( while (iters--) { if (pred && a_gl_rd < a_gl_end) { const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { int4 chunk; auto dec = reinterpret_cast(&chunk); - // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't - // actually help us; this brings > 2x speedup. - asm volatile ( - "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*) &codebook[enc[i]]) - ); + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. + asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); C[a_gl_rd * 8 + i] = chunk; } @@ -250,28 +231,25 @@ __global__ void Code1x16Dequant( } } - __global__ void Code2x8Dequant( - const int4* __restrict__ A, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. - const int codebook_stride // as int4 + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. 
auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -290,9 +268,8 @@ __global__ void Code2x8Dequant( for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { int4 dec = codebook[i]; - #pragma unroll - for (int j = 0; j < 8; j++) - sh_code[8 * i + (j + lane) % 8] = dec; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; } __syncthreads(); @@ -302,12 +279,14 @@ __global__ void Code2x8Dequant( while (iters--) { if (pred && a_gl_rd < a_gl_end) { const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { int4 chunk; - half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - #pragma unroll + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); +#pragma unroll for (int j = 0; j < 4; j++) reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); C[a_gl_rd * 8 + i] = chunk; @@ -317,22 +296,15 @@ __global__ void Code2x8Dequant( } } -inline int ceildiv(int a, int b) { - return (a + b - 1) / b; -} +inline int ceildiv(int a, int b) { return (a + b - 1) / b; } const int THREAD_M = 16; -void code1x16_matvec_cuda( - const void* __restrict__ A, - const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, - const int codebook_stride -) { +void code1x16_matvec_cuda(const void* __restrict__ A, + const void* __restrict__ B, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); int waves = 0; @@ -345,28 +317,16 @@ void code1x16_matvec_cuda( int blocks = ceildiv(prob_m, thread_m); int threads = 32 * thread_m; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code1x16MatVec<<>>( - (const int4*) A, - (const int4*) B, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + Code1x16MatVec<<>>( + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); } -void code2x8_matvec_cuda( - const void* __restrict__ A, - const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, - const int codebook_stride -) { +void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); int waves = 0; @@ -379,30 +339,20 @@ void code2x8_matvec_cuda( int blocks = ceildiv(prob_m, thread_m); int threads = 32 * thread_m; int shared = 16 * (2 * 256 * 8 + 32 * 9); - cudaFuncSetAttribute( - Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared - ); + cudaFuncSetAttribute(Code2x8MatVec, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); Code2x8MatVec<<>>( - (const int4*) A, - (const int4*) B, - (int4*) C, - (const int4*) codebook, - prob_m, - 
prob_k, - codebook_a_sizes, - codebook_stride - ); + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); } void code1x16_dequant_cuda( - const void* __restrict__ A, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); @@ -417,25 +367,21 @@ void code1x16_dequant_cuda( int threads = 32 * thread_m; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); Code1x16Dequant<<>>( - (const int4*) A, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - codebook_stride // as int4. + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long. + codebook_stride // as int4. ); } // Dequantizes the code and codebook into weights. -void code2x8_dequant_cuda( - const void* __restrict__ A, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. - const int codebook_stride // as int4 +void code2x8_dequant_cuda( + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 ) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); @@ -451,74 +397,50 @@ void code2x8_dequant_cuda( int shared = 16 * (2 * 256 * 8 + 32 * 9); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cudaFuncSetAttribute( - Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared - ); + cudaFuncSetAttribute(Code2x8Dequant, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); Code2x8Dequant<<>>( - (const int4*) A, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, codebook_stride); } -int codebook_stride(const torch::Tensor& codebooks) -{ +int codebook_stride(const torch::Tensor& codebooks) { return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); } void code1x16_matvec( - const torch::Tensor& A, - const torch::Tensor& B, - torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long. + const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes // cumulative sizes of A spanning each + // codebook, at most 3 long. 
) { const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); int prob_m = C.size(0); int prob_k = B.size(0); - code1x16_matvec_cuda( - A.data_ptr(), - B.data_ptr(), - C.data_ptr(), - codebook.data_ptr(), - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride(codebook) - ); + code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + codebook_stride(codebook)); } -torch::Tensor code1x16_matmat( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias) { +torch::Tensor code1x16_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { auto input_sizes = input.sizes(); auto out_features = codes.size(0) * codebooks.size(2); auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty({flat_input.size(0), out_features}, - torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()) - ); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); for (int i = 0; i < flat_input.size(0); ++i) { auto input_vec = flat_input.index({i}); auto output_vec = flat_output.index({i}); - code1x16_matvec( - codes.squeeze(2), - input_vec, - output_vec, - codebooks, - codebook_a_sizes - ); + code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); } flat_output *= scales.flatten().unsqueeze(0); @@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat( return output; } -void code2x8_matvec( - const torch::Tensor& A, - const torch::Tensor& B, - torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes -) { +void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, + torch::Tensor& C, const torch::Tensor& codebook, + const int4 codebook_a_sizes) { const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); int prob_m = C.size(0); int prob_k = B.size(0); - code2x8_matvec_cuda( - A.data_ptr(), - B.data_ptr(), - C.data_ptr(), - codebook.data_ptr(), - prob_m, - prob_k, - codebook_a_sizes, - 2 * codebook_stride(codebook) - ); + code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + 2 * codebook_stride(codebook)); } -torch::Tensor code2x8_matmat( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias -) { +torch::Tensor code2x8_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { auto input_sizes = input.sizes(); auto out_features = codes.size(0) * codebooks.size(2); auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty({flat_input.size(0), out_features}, - torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()) - ); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); for (int i = 0; i < flat_input.size(0); ++i) { auto input_vec = flat_input.index({i}); auto output_vec = flat_output.index({i}); - code2x8_matvec( - codes.squeeze(2), - 
input_vec, - output_vec, - codebooks, - codebook_a_sizes - ); + code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); } flat_output *= scales.flatten().unsqueeze(0); if (bias.has_value()) { @@ -596,64 +498,56 @@ torch::Tensor code2x8_matmat( } // Accumulate the partition sizes. -int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) -{ +int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { int4 cumulative_sizes; auto cumulative_size = &cumulative_sizes.x; int i = 0; int last = 0; assert(codebook_partition_sizes.size(0) <= 4); - for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) - { + for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) { *cumulative_size = codebook_partition_sizes[i].item() + last; last = *cumulative_size; } // fill in the rest with unreachable. - for (; i < 4; ++i, ++cumulative_size) - { - *cumulative_size = last*10; + for (; i < 4; ++i, ++cumulative_size) { + *cumulative_size = last * 10; } return cumulative_sizes; } -} // namespace aqlm -} // namespace vllm - +} // namespace aqlm +} // namespace vllm -torch::Tensor aqlm_gemm( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias -) -{ - int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const entries = codebooks.size(1); - if (nbooks == 1 && entries == (1 << 16)) - { - return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + if (nbooks == 1 && entries == (1 << 16)) { + return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); } - if (nbooks == 2 && entries == (1 << 8)) - { - return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + if (nbooks == 2 && entries == (1 << 8)) { + return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); } - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") return {}; } -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes -) -{ - int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const entries = codebooks.size(1); @@ -668,45 +562,37 @@ torch::Tensor aqlm_dequant( assert(out_features = codebook_partition_sizes.sum().item()); auto weights = torch::empty({out_features, in_features}, - torch::TensorOptions() - .dtype(codebooks.dtype()) - .device(codebooks.device()) - ); + torch::TensorOptions() + .dtype(codebooks.dtype()) 
+ .device(codebooks.device())); + + if (nbooks == 1 && entries == (1 << 16)) { + vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); - if (nbooks == 1 && entries == (1 << 16)) - { - vllm::aqlm::code1x16_dequant_cuda( - codes.data_ptr(), - weights.data_ptr(), - codebooks.data_ptr(), - out_features, - in_features, - cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) - // weights *= scales.index({"...", 0, 0}); - - return weights; + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation.) weights *= + // scales.index({"...", 0, 0}); + + return weights; } - if (nbooks == 2 && entries == (1 << 8)) - { - vllm::aqlm::code2x8_dequant_cuda( - codes.data_ptr(), - weights.data_ptr(), - codebooks.data_ptr(), - out_features, - in_features, - cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) - // weights *= scales.index({"...", 0, 0}); - - return weights; + if (nbooks == 2 && entries == (1 << 8)) { + vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation) weights *= + // scales.index({"...", 0, 0}); + + return weights; } - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") return {}; } diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh index d1d926de18d78..813ec6716cf54 100644 --- a/csrc/quantization/awq/dequantize.cuh +++ b/csrc/quantization/awq/dequantize.cuh @@ -1,11 +1,11 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq -Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +Modified from NVIDIA FasterTransformer: +https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ @@ -14,74 +14,88 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor namespace vllm { namespace awq { -__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) -{ +__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 assert(false); #else - uint4 result; + uint4 result; - uint32_t* h = 
reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. + // Note that the entire sequence only requires 1 shift instruction. This is + // thanks to the register packing format and the fact that we force our + // integers to be unsigned, and account for this in the fp16 subtractions. In + // addition, I exploit the fact that sub and fma have the same throughput in + // order to convert elt_23 and elt_67 to fp16 without having to shift them to + // the bottom bits before hand. - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. - const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW + // dependency if we issue immediately before required. 
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. + // I use inline PTX below because I am not sure if the compiler will emit + // float2half instructions if I use the half2 ctor. In this case, I chose + // performance reliability over code readability. - // This is the half2 {1032, 1032} represented as an integer. - // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - // static constexpr uint32_t NEG_72 = 0xd480d480; - // Haotian: Let's use {-64, -64}. - static constexpr uint32_t NEG_64 = 0xd400d400; + // This is the half2 {1032, 1032} represented as an integer. + // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + // static constexpr uint32_t NEG_72 = 0xd480d480; + // Haotian: Let's use {-64, -64}. + static constexpr uint32_t NEG_64 = 0xd400d400; - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Finally, we construct the output numbers. 
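// Illustrative host-side sketch of the magic-number arithmetic applied just
// below (a minimal check of the constants, assuming the host-side __half
// helpers from <cuda_fp16.h>): OR-ing a low nibble n into 0x6400 yields the
// fp16 value 1024 + n, so subtracting 1024 recovers n; a high nibble lands as
// 1024 + 16*n, so fma(x, 1/16, -64) recovers n.
#include <cstdint>
#include <cstring>
#include <cuda_fp16.h>

static float decode_low_nibble(uint32_t nibble) {
  uint16_t bits = static_cast<uint16_t>(0x6400u | (nibble & 0xFu));  // 1024 + n
  __half h;
  std::memcpy(&h, &bits, sizeof(h));
  return __half2float(h) - 1024.0f;  // == float(n), the sub.f16x2 path
}

static float decode_high_nibble(uint32_t byte) {
  uint16_t bits = static_cast<uint16_t>(0x6400u | (byte & 0xF0u));  // 1024 + 16*n
  __half h;
  std::memcpy(&h, &bits, sizeof(h));
  return __half2float(h) * (1.0f / 16.0f) - 64.0f;  // == float(n), the fma path
}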
+ // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - return result; + return result; #endif } -} // namespace awq -} // namespace vllm +} // namespace awq +} // namespace vllm diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 5aefb0bd16aef..6d6da5f3d8746 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -1,15 +1,13 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ - -#include +#include #include #include "dequantize.cuh" @@ -20,26 +18,20 @@ namespace vllm { namespace awq { // Pack two half values. -static inline __device__ __host__ unsigned -__pack_half2(const half x, const half y) { - unsigned v0 = *((unsigned short *)&x); - unsigned v1 = *((unsigned short *)&y); +static inline __device__ __host__ unsigned __pack_half2(const half x, + const half y) { + unsigned v0 = *((unsigned short*)&x); + unsigned v1 = *((unsigned short*)&y); return (v1 << 16) | v0; } -template -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( - int G, - int split_k_iters, - half* __restrict__ A, - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - int M, - int IC, - int OC, - half* __restrict__ C) -{ +template +__global__ void __launch_bounds__(64) + gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, + half* __restrict__ A, int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, int M, int IC, + int OC, half* __restrict__ C) { // Only support matrix n = 64 or 128 assert(N == 64 || N == 128); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 @@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( static constexpr int row_stride = 2 * 32 * 8 / N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + bool ld_A_flag = + (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * (256 / N) - + (((int)threadIdx.x) / (N / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + (((int)threadIdx.x) % (N / 8)) * 1; -// Why * 1 in the above 
line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) - + (((int)threadIdx.x) / (N / 8)) * (N + 8) - + (((int)threadIdx.x) % (N / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + ((int)threadIdx.x) % (N / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * N - + (((int)threadIdx.x) % (N / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * N - + ((int)threadIdx.y) * (N / 2) - + (((int)threadIdx.x) % 4) * 2; + half* A_ptr = + A + + (((int)blockIdx_y) / j_factors1 * 16 + + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * + IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; + // Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8)) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = + C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; // preload s.f. 
and zeros int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; @@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; __syncthreads(); // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { + if (ld_A_flag) { *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { + } else { *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); } // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + uint4 B_loaded_scale = + *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && + threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, + B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, + B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); } */ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { - // B: 32 x 136 (128+8) float16 // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus + // zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * + // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) + // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * + // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * + // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * + // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = + *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / + // 8)) * 8); - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x + // % (cta_N / 8)) * 8); // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = + // q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == + 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", + B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); } */ // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = + B_loaded_fp16; } __syncthreads(); @@ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + + (((((int)threadIdx.x) & 15) * 40) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), 
"=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(A_shared_warp + 0))[0]), + "=r"(((unsigned*)(A_shared_warp + 0))[1]), + "=r"(((unsigned*)(A_shared_warp + 0))[2]), + "=r"(((unsigned*)(A_shared_warp + 0))[3]) + : "r"(addr)); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) - ); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + + (((int)threadIdx.y) * (N / 2))) + + (ax1_0 * 16))])) + + (((((int)threadIdx.x) & 15) * (N + 8)) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr)); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), 
"f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#else + #else { __asm__ __volatile__( - 
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#endif + #endif } } } -// TODO: Shang: Hoist loop invariance. + // TODO: Shang: Hoist loop invariance. 
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); } } } #endif } -__global__ void __launch_bounds__(64) dequantize_weights( - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - half* __restrict__ C, - int G -) -{ +__global__ void __launch_bounds__(64) + dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, + int* __restrict__ zeros, half* __restrict__ C, int G) { int j_factors1 = 4; int row_stride2 = 4; int split_k_iters = 1; @@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights( uint32_t B_loaded = *(uint32_t*)B_ptr2; uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); *(uint4*)B_shared_ptr2 = B_loaded_fp16; @@ -326,58 +430,57 @@ __global__ void __launch_bounds__(64) dequantize_weights( } } -} // namespace awq -} // namespace vllm - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy) -{ - int 
in_c = _kernel.size(0); - int qout_c = _kernel.size(1); - int out_c = qout_c * 8; - int G = in_c / _scaling_factors.size(0); - - int x_thread = thx; - int y_thread = thy; - - int x_blocks = 1; - int y_blocks = 1; - if (thx==0) { - x_thread = qout_c; - } - if (thy==0) { - y_thread = in_c; - } - if (thx==0 && thy==0) { - x_thread = 8; - y_thread = 8; - x_blocks = (int)(qout_c / 8); - y_blocks = (int)(in_c / 8); - } +} // namespace awq +} // namespace vllm + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy) { + int in_c = _kernel.size(0); + int qout_c = _kernel.size(1); + int out_c = qout_c * 8; + int G = in_c / _scaling_factors.size(0); + + int x_thread = thx; + int y_thread = thy; + + int x_blocks = 1; + int y_blocks = 1; + if (thx == 0) { + x_thread = qout_c; + } + if (thy == 0) { + y_thread = in_c; + } + if (thx == 0 && thy == 0) { + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + } - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + auto options = torch::TensorOptions() + .dtype(_scaling_factors.dtype()) + .device(_scaling_factors.device()); + at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); - dim3 num_blocks(x_blocks, y_blocks); - dim3 threads_per_block(x_thread, y_thread); + dim3 num_blocks(x_blocks, y_blocks); + dim3 threads_per_block(x_thread, y_thread); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::awq::dequantize_weights<<>>( - kernel, scaling_factors, zeros, de_kernel, G); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::awq::dequantize_weights<<>>( + kernel, scaling_factors, zeros, de_kernel, G); - return _de_kernel; + return _de_kernel; } // in_feats: M, IC [float16] @@ -386,61 +489,61 @@ torch::Tensor awq_dequantize( // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = 
reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 128 == 0) - { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - else if (num_out_channels % 64 == 0) - { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - return _out_feats.sum(0); +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int64_t split_k_iters) { + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); + at::Tensor _out_feats = + torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } else if (num_out_channels % 64 == 0) { + int j_factors1 
= num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * + split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } + return _out_feats.sum(0); } diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu new file mode 100644 index 0000000000000..aa9511daa2772 --- /dev/null +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -0,0 +1,115 @@ +#include +#include +#include + +#include "../../dispatch_utils.h" +#include "../../reduction_utils.cuh" + +static inline __device__ int8_t float_to_int8_rn(float x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +namespace vllm { + +template +__global__ void static_scaled_int8_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = float_to_int8_rn( + static_cast(input[token_idx * hidden_size + i]) / scale); + } +} + +template +__global__ void dynamic_scaled_int8_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + float absmax_val = 0.0f; + float const zero = 0.0f; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + float val = static_cast(input[token_idx * hidden_size + i]); + val = val > zero ? val : -val; + absmax_val = val > absmax_val ? 
val : absmax_val; + } + + float const block_absmax_val_maybe = blockReduceMax(absmax_val); + __shared__ float block_absmax_val; + if (tid == 0) { + block_absmax_val = block_absmax_val_maybe; + scale[token_idx] = block_absmax_val / 127.0f; + } + __syncthreads(); + + float const tmp_scale = 127.0f / block_absmax_val; + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = float_to_int8_rn( + static_cast(input[token_idx * hidden_size + i]) * tmp_scale); + } +} + +} // namespace vllm + +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& scale) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scale.numel() == 1); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { + vllm::static_scaled_int8_quant_kernel + <<>>(input.data_ptr(), + out.data_ptr(), + scale.data_ptr(), hidden_size); + }); +} + +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor& scales) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { + vllm::dynamic_scaled_int8_quant_kernel + <<>>(input.data_ptr(), + out.data_ptr(), + scales.data_ptr(), hidden_size); + }); +} diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp new file mode 100644 index 0000000000000..c4c6b18654eed --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp @@ -0,0 +1,346 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either +// row/column or scalar broadcasting where the tensor being loaded from is +// always passed in via a device pointer. This lets one compiled kernel handle +// all cases of per-tensor or per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graph +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cute/tensor.hpp" + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast. 
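Concretely, a caller fills the Arguments struct below with a device pointer that holds either the whole scale vector or a single per-tensor scale, plus a flag saying which case applies. A minimal sketch with hypothetical, simplified types (the real helpers are the prepare_args functions in scaled_mm_c2x.cu further down, which pass scales.numel() != 1 as the flag):

#include <cstdint>

// Hypothetical, simplified stand-in for the Arguments struct below: one device
// pointer is always supplied, and row_broadcast records whether it addresses a
// full vector of per-column scales or a single per-tensor scale.
struct RowOrScalarArgs {
  const float* ptr_row = nullptr;  // device pointer: N scales, or exactly one
  bool row_broadcast = true;       // true: per-column vector, false: scalar
};

// device_scales points at device memory holding either `numel` scales or 1.
RowOrScalarArgs make_row_args(const float* device_scales, int64_t numel) {
  // Per-tensor quantization ships a single element; anything larger is a
  // per-column vector that the epilogue loads element-wise.
  return RowOrScalarArgs{device_scales, numel != 1};
}

Keeping even the scalar case behind a device pointer is what lets one compiled kernel cover per-tensor and per-channel/per-token quantization without shuttling scales to the CPU, as the file header above explains.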
+ struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->row_broadcast) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = *(params_ptr->ptr_row); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class 
StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast. + struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->col_broadcast) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if (pred(i)) { + dst_v(i) = *(params_ptr->ptr_col); + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp new file mode 100644 index 0000000000000..877a9f5b9e5de --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp @@ -0,0 +1,389 @@ 
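Both the 2.x visitors above and their SM90 (CUTLASS 3.x) counterparts in the file added below serve the same purpose: applying quantization scales inside the GEMM epilogue. As the ScaledEpilogue comment later in this diff states, the fused operation is D = (a_scales * A) (b_scales * B) with numpy-style broadcasting; because the scales factor out of the k-sum, they can equivalently be applied per-row and per-column to the accumulator. A host-side numerical reference with illustrative names (either scale vector may collapse to a single element, which is exactly the scalar case these visitors handle):

#include <cstdint>
#include <vector>

// acc holds the raw GEMM accumulator for an m x n output (row-major).
// a_scale has m entries (per-row / per-token) or exactly 1 (per-tensor);
// b_scale has n entries (per-column / per-channel) or exactly 1.
std::vector<float> apply_scales(const std::vector<int32_t>& acc, int m, int n,
                                const std::vector<float>& a_scale,
                                const std::vector<float>& b_scale) {
  std::vector<float> d(static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i) {
    const float sa = (a_scale.size() == 1) ? a_scale[0] : a_scale[i];
    for (int j = 0; j < n; ++j) {
      const float sb = (b_scale.size() == 1) ? b_scale[0] : b_scale[j];
      d[static_cast<size_t>(i) * n + j] =
          sa * sb * static_cast<float>(acc[static_cast<size_t>(i) * n + j]);
    }
  }
  return d;
}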
+/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. 
+// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least + // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcast { + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // row vector broadcast, e.g. per-col alpha/bias + (cute::is_same_v>)); // batched row vector broadcast + + // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem + struct SharedStorage { + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params), + smem_row(const_cast(shared_storage.smem_row.data())) { } + + Params params; + Element* smem_row; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return true; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row) == Element(0)); + } + + template + struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { + CUTLASS_DEVICE + ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) + : gRow(cute::forward(gRow)), + sRow(cute::forward(sRow)), + params(params) {} + + GTensor gRow; // (CTA_M,CTA_N) + STensor sRow; // (CTA_M,CTA_N,PIPE) + Params const& params; + + CUTLASS_DEVICE void + begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { + if (!params.row_broadcast) { + return; + } + + if (issue_tma_load) { + // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size + constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v / 8; + cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); + // Issue the TMA bulk copy + auto bulk_copy = Copy_Atom{}.with(*full_mbarrier_ptr); + // Filter so we don't issue redundant copies over 
stride-0 modes + int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; + copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); + } + } + }; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + + constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; + return ProducerLoadCallbacks( + cute::move(gRow), cute::move(sRow), params); + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) + : tCrRow(cute::forward(tCrRow)), + tCsRow(cute::forward(tCsRow)), + params(params) {} + + RTensor tCrRow; // (CPY,CPY_M,CPY_N) + STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + Params const& params; + + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if (!params.row_broadcast) { + fill(tCrRow, *(params.ptr_row)); + return; + } + + if (epi_m == 0) { // Assumes M-major subtile loop + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; + copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tCrRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) + + constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; + return ConsumerStoreCallbacks( + cute::move(tCrRow), cute::move(tCsRow), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcast { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. + struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params) + : tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col)); + return; + } + + // Filter so we don't 
issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_aligned(filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks( + cute::move(tCgCol), cute::move(tCrCol), params); + } +}; + +} diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp new file mode 100644 index 0000000000000..bf04bb400790f --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/common.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + cutlassGetStatusString(status)) \ + } + +inline uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device); + return max_shared_mem_per_block_opt_in; +} + diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu new file mode 100644 index 0000000000000..6ce25c5ac897b --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -0,0 +1,609 @@ +#include +#include + +#include + +// clang-format will break include orders +// clang-format off +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/util/device_memory.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm_coord.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" + +#include "broadcast_load_epilogue_c2x.hpp" +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This file defines quantized GEMM operations using the CUTLASS 2.x API, for + NVIDIA GPUs with SM versions prior to sm90 (Hopper). + + Epilogue functions can be defined to post-process the output before it is + written to GPU memory. + Epilogues must contain a public type named EVTCompute of type Sm80EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. 
+*/ + +namespace { + +// Wrappers for the GEMM kernel that is used to guard against compilation on +// architectures that will never use the kernel. The purpose of this is to +// reduce the size of the compiled binary. +// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef +// into code that will be executed on the device where it is defined. +template +struct enable_sm75_to_sm80 : Kernel { + template + CUTLASS_DEVICE static void invoke(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800 + Kernel::invoke(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm80_to_sm89 : Kernel { + template + CUTLASS_DEVICE static void invoke(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890 + Kernel::invoke(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm89_to_sm90 : Kernel { + template + CUTLASS_DEVICE static void invoke(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900 + Kernel::invoke(std::forward(args)...); +#endif + } +}; + +/* + * This class provides the common ScaleA and ScaleB descriptors for the + * ScaledEpilogue and ScaledEpilogueBias classes. + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, float, Stride, Int<0>, Int<0>>>; + + using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, float, Stride, Int<1>, Int<0>>>; +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch._scaled_mm. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. 
+*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::ScaleA; + using ScaleB = typename SUPER::ScaleB; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using ScaleAArgs = typename ScaleA::Arguments; + using ScaleBArgs = typename ScaleB::Arguments; + + ScaleBArgs b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}}; + ScaleAArgs a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}}; + + typename EVTCompute0::Arguments evt0_compute_args{b_args}; + + typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args}; + return evt_compute_args; + } +}; + +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::ScaleA; + using ScaleB = typename SUPER::ScaleB; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, ElementD, Stride, Int<1>, Int<0>>>; + + public: + using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + using ScaleAArgs = typename ScaleA::Arguments; + using ScaleBArgs = typename ScaleB::Arguments; + using BiasArgs = typename Bias::Arguments; + + ScaleBArgs b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}}; + ScaleAArgs a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}}; + BiasArgs bias_args{static_cast(bias.data_ptr()), {}}; + + typename EVTCompute0::Arguments evt0_compute_args{b_args}; + + typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args, + bias_args}; + return evt_compute_args; + } +}; + +template typename ArchGuard, + typename ElementAB_, typename ElementD_, + template typename Epilogue_, typename TileShape, + typename WarpShape, typename InstructionShape, int32_t MainLoopStages> +struct cutlass_2x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using Operator = + typename std::conditional, + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type; + + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + TileShape, WarpShape, float, 4, 1 /* epilogue stages */ + >; + + using Epilogue = Epilogue_; + using EVTCompute = typename Epilogue::EVTCompute; + + using D = 
cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, + Stride, Int<0>>>; + + using EVTD = cutlass::epilogue::threadblock::Sm80EVT; + + // clang-format off + using RowMajor = typename cutlass::layout::RowMajor; + using ColumnMajor = typename cutlass::layout::ColumnMajor; + using KernelType = + ArchGuard::GemmKernel>; + // clang-format on + + using Op = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + cutlass::gemm::GemmCoord problem_size{m, n, k}; + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideC = Stride, Int<0>>; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); + + typename Gemm::D::Arguments d_args{c_ptr, c_stride}; + + using Epilogue = typename Gemm::Epilogue; + auto evt_args = + Epilogue::prepare_args(std::forward(epilogue_params)...); + + typename Gemm::EVTD::Arguments epilogue_args{ + evt_args, + d_args, + }; + + typename Gemm::Op::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode + problem_size, // problem size + 1, // batch count + epilogue_args, + a_ptr, + b_ptr, + nullptr, + nullptr, + 0, + 0, + 0, + 0, + lda, + ldb, + ldc, + ldc}; + + // Launch the CUTLASS GEMM kernel. + typename Gemm::Op gemm_op; + size_t workspace_size = gemm_op.get_workspace_size(args); + cutlass::device_memory::allocation workspace(workspace_size); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + CUTLASS_CHECK(gemm_op.can_implement(args)); + cutlass::Status status = gemm_op(args, workspace.get(), stream); + CUTLASS_CHECK(status); +} + +template +void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... args) { + // In some cases, the GPU isn't able to accommodate the + // shared memory requirements of the Gemm. In such cases, use + // the FallbackGemm instead. 
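The selection logic that follows boils down to comparing sizeof(Gemm::KernelType::SharedStorage) against the device's opt-in shared-memory limit. A standalone sketch of that check, using the same cudaDevAttrMaxSharedMemoryPerBlockOptin attribute that common.hpp wraps; the hard-coded byte count is only an example taken from the sm80_config_M64 comment below:

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdio>

// Returns true when the preferred kernel's shared-memory footprint exceeds the
// device's opt-in per-block limit, i.e. when the FallbackGemm must be used.
// preferred_smem_bytes stands in for sizeof(Gemm::KernelType::SharedStorage).
bool needs_fallback(size_t preferred_smem_bytes, int device = 0) {
  int max_opt_in = 0;
  cudaDeviceGetAttribute(&max_opt_in, cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device);
  return preferred_smem_bytes > static_cast<size_t>(max_opt_in);
}

int main() {
  // 122880 bytes is the requirement quoted for sm80_config_M64 below; devices
  // whose opt-in limit is smaller are exactly the ones that take the fallback.
  std::printf("fallback needed: %s\n", needs_fallback(122880) ? "yes" : "no");
  return 0;
}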
+ static const int max_shared_mem_per_block_opt_in = + get_cuda_max_shared_memory_per_block_opt_in(0); + + size_t const gemm_shared_mem_size = + sizeof(typename Gemm::KernelType::SharedStorage); + size_t const fallback_gemm_shared_mem_size = + sizeof(typename FallbackGemm::KernelType::SharedStorage); + + if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) { + return cutlass_gemm_caller(out, a, b, + std::forward(args)...); + } else { + TORCH_CHECK(fallback_gemm_shared_mem_size <= + max_shared_mem_per_block_opt_in); + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } +} + +template typename Epilogue> +struct sm80_config_default { + // This config is used in 2 cases, + // - M in (128, inf) + // - M in (64, 128] and N >= 8192 + // Shared Memory required by this Gemm - 81920 bytes + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + using Cutlass2xGemm = + cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm80_config_M64 { + // This config is used in 2 cases, + // - M in (32, 64] + // - M in (64, 128] and N < 8192 + // Shared Memory required by this Gemm - 122880 bytes + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + using Cutlass2xGemm = + cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm80_config_M32 { + // M in (16, 32] + // Shared Memory required by this Gemm - 61440 bytes + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + using Cutlass2xGemm = + cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm80_config_M16 { + // M in [1, 16] + // Shared Memory required by this Gemm - 51200 bytes + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + using Cutlass2xGemm = + cutlass_2x_gemm; +}; + +} // namespace + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + + using Cutlass2xGemmDefault = + typename sm80_config_default::Cutlass2xGemm; + using Cutlass2xGemmM128BigN = + typename sm80_config_default::Cutlass2xGemm; + using Cutlass2xGemmM128SmallN = + typename sm80_config_M64::Cutlass2xGemm; + using Cutlass2xGemmM64 = + typename sm80_config_M64::Cutlass2xGemm; + using Cutlass2xGemmM32 = + typename sm80_config_M32::Cutlass2xGemm; + using Cutlass2xGemmM16 = + typename sm80_config_M16::Cutlass2xGemm; + + // Due to shared memory requirements, some Gemms may fail to run on some + // GPUs. As the name indicates, the Fallback Gemm is used as an alternative + // in such cases. + // sm80_config_M16 has the least shared-memory requirement. 
However, + // based on some profiling, we select sm80_config_M32 as a better alternative + // performance wise. + using FallbackGemm = + typename sm80_config_M32::Cutlass2xGemm; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(16), next_pow_2(m)); // next power of 2 + if (mp2 <= 16) { + // M in [1, 16] + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 32) { + // M in (16, 32] + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 64) { + // M in (32, 64] + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 128) { + // M in (64, 128] + uint32_t const n = out.size(1); + bool const small_n = n < 8192; + if (small_n) { + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else { + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } + } else { + // M in (128, inf) + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } +} + +template